fix lazy domain loading and generation process
This commit is contained in:
parent
7f1d13658f
commit
6e7dc1297c
34
dataset.py
34
dataset.py
@ -152,6 +152,7 @@ def create_dataset_from_lists(chunks, vocab, max_len):
|
||||
:param max_len:
|
||||
:return:
|
||||
"""
|
||||
|
||||
def get_domain_features_reduced(d):
|
||||
return get_domain_features(d[0], vocab, max_len)
|
||||
|
||||
@ -230,33 +231,40 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
||||
return load_h5dataset(h5data)
|
||||
|
||||
|
||||
# TODO: implement csv loading if already generated
|
||||
def load_or_generate_domains(train_data, domain_length):
|
||||
fn = f"{train_data}_domains.gz"
|
||||
|
||||
try:
|
||||
user_flow_df = pd.read_csv(fn)
|
||||
except Exception:
|
||||
char_dict = get_character_dict()
|
||||
user_flow_df = get_user_flow_data(train_data)
|
||||
user_flow_df.reset_index(inplace=True)
|
||||
user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0,
|
||||
how="any")
|
||||
user_flow_df = user_flow_df.groupby(user_flow_df.domain).mean()
|
||||
user_flow_df.reset_index(inplace=True)
|
||||
|
||||
user_flow_df["clientLabel"] = np.where(
|
||||
np.logical_or(user_flow_df.trustedHits > 0, user_flow_df.virusTotalHits >= 3), True, False)
|
||||
user_flow_df[["serverLabel", "clientLabel"]] = user_flow_df[["serverLabel", "clientLabel"]].astype(bool)
|
||||
user_flow_df = user_flow_df[["domain", "serverLabel", "clientLabel"]]
|
||||
|
||||
user_flow_df.to_csv(fn, compression="gzip")
|
||||
|
||||
domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, char_dict, domain_length))
|
||||
domain_encs = np.stack(domain_encs)
|
||||
|
||||
user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0, how="any")
|
||||
user_flow_df.reset_index(inplace=True)
|
||||
user_flow_df["clientLabel"] = np.where(
|
||||
np.logical_or(user_flow_df.trustedHits > 0, user_flow_df.virusTotalHits >= 3), 1.0, 0.0)
|
||||
user_flow_df = user_flow_df[["domain", "serverLabel", "clientLabel"]]
|
||||
user_flow_df.groupby(user_flow_df.domain).mean()
|
||||
|
||||
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix()
|
||||
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix().astype(bool)
|
||||
|
||||
|
||||
def save_predictions(path, c_pred, s_pred, embd, labels):
|
||||
def save_predictions(path, c_pred, s_pred):
|
||||
f = h5py.File(path, "w")
|
||||
f.create_dataset("client", data=c_pred)
|
||||
f.create_dataset("server", data=s_pred)
|
||||
f.create_dataset("embedding", data=embd)
|
||||
f.create_dataset("labels", data=labels)
|
||||
f.close()
|
||||
|
||||
|
||||
def load_predictions(path):
|
||||
f = h5py.File(path, "r")
|
||||
return f["client"], f["server"], f["embedding"], f["labels"]
|
||||
return f["client"], f["server"]
|
||||
|
14
main.py
14
main.py
@ -194,13 +194,12 @@ def main_test():
|
||||
else:
|
||||
c_pred = np.zeros(0)
|
||||
s_pred = pred
|
||||
dataset.save_predictions(args.future_prediction, c_pred, s_pred)
|
||||
|
||||
model = load_model(args.embedding_model)
|
||||
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||
domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||
|
||||
dataset.save_predictions(args.future_prediction, c_pred, s_pred, domain_embedding, labels)
|
||||
|
||||
np.save(args.model_path + "/domain_embds.npy", domain_embedding)
|
||||
|
||||
|
||||
def main_visualization():
|
||||
@ -213,12 +212,15 @@ def main_visualization():
|
||||
try:
|
||||
logger.info("plot training curve")
|
||||
logs = pd.read_csv(args.train_log)
|
||||
visualize.plot_training_curve(logs, "client", "{}/client_train.png".format(args.model_path))
|
||||
visualize.plot_training_curve(logs, "server", "{}/server_train.png".format(args.model_path))
|
||||
if args.model_output == "client":
|
||||
visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
|
||||
else:
|
||||
visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
|
||||
visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
|
||||
except Exception as e:
|
||||
logger.warning(f"could not generate training curves: {e}")
|
||||
|
||||
client_pred, server_pred, domain_embedding, labels = dataset.load_predictions(args.future_prediction)
|
||||
client_pred, server_pred = dataset.load_predictions(args.future_prediction)
|
||||
client_pred, server_pred = client_pred.value, server_pred.value
|
||||
logger.info("plot pr curve")
|
||||
visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path))
|
||||
|
@ -132,11 +132,11 @@ def plot_confusion_matrix(y_true, y_pred, path,
|
||||
|
||||
def plot_training_curve(logs, key, path, dpi=600):
|
||||
plt.clf()
|
||||
plt.plot(logs[f"{key}_acc"], label="accuracy")
|
||||
plt.plot(logs[f"{key}_f1_score"], label="f1_score")
|
||||
plt.plot(logs[f"{key}acc"], label="accuracy")
|
||||
plt.plot(logs[f"{key}f1_score"], label="f1_score")
|
||||
|
||||
plt.plot(logs[f"val_{key}_acc"], label="accuracy")
|
||||
plt.plot(logs[f"val_{key}_f1_score"], label="val_f1_score")
|
||||
plt.plot(logs[f"val_{key}acc"], label="accuracy")
|
||||
plt.plot(logs[f"val_{key}f1_score"], label="val_f1_score")
|
||||
|
||||
plt.xlabel('epoch')
|
||||
plt.ylabel('percentage')
|
||||
|
Loading…
x
Reference in New Issue
Block a user