fix lazy domain loading and generation process

This commit is contained in:
René Knaebel 2017-08-03 12:27:17 +02:00
parent 7f1d13658f
commit 6e7dc1297c
3 changed files with 35 additions and 25 deletions

View File

@ -152,6 +152,7 @@ def create_dataset_from_lists(chunks, vocab, max_len):
:param max_len:
:return:
"""
def get_domain_features_reduced(d):
return get_domain_features(d[0], vocab, max_len)
@ -230,33 +231,40 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
return load_h5dataset(h5data)
# TODO: implement csv loading if already generated
def load_or_generate_domains(train_data, domain_length):
char_dict = get_character_dict()
user_flow_df = get_user_flow_data(train_data)
fn = f"{train_data}_domains.gz"
try:
user_flow_df = pd.read_csv(fn)
except Exception:
char_dict = get_character_dict()
user_flow_df = get_user_flow_data(train_data)
user_flow_df.reset_index(inplace=True)
user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0,
how="any")
user_flow_df = user_flow_df.groupby(user_flow_df.domain).mean()
user_flow_df.reset_index(inplace=True)
user_flow_df["clientLabel"] = np.where(
np.logical_or(user_flow_df.trustedHits > 0, user_flow_df.virusTotalHits >= 3), True, False)
user_flow_df[["serverLabel", "clientLabel"]] = user_flow_df[["serverLabel", "clientLabel"]].astype(bool)
user_flow_df = user_flow_df[["domain", "serverLabel", "clientLabel"]]
user_flow_df.to_csv(fn, compression="gzip")
domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, char_dict, domain_length))
domain_encs = np.stack(domain_encs)
user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0, how="any")
user_flow_df.reset_index(inplace=True)
user_flow_df["clientLabel"] = np.where(
np.logical_or(user_flow_df.trustedHits > 0, user_flow_df.virusTotalHits >= 3), 1.0, 0.0)
user_flow_df = user_flow_df[["domain", "serverLabel", "clientLabel"]]
user_flow_df.groupby(user_flow_df.domain).mean()
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix()
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix().astype(bool)
def save_predictions(path, c_pred, s_pred, embd, labels):
def save_predictions(path, c_pred, s_pred):
f = h5py.File(path, "w")
f.create_dataset("client", data=c_pred)
f.create_dataset("server", data=s_pred)
f.create_dataset("embedding", data=embd)
f.create_dataset("labels", data=labels)
f.close()
def load_predictions(path):
f = h5py.File(path, "r")
return f["client"], f["server"], f["embedding"], f["labels"]
return f["client"], f["server"]

14
main.py
View File

@ -194,13 +194,12 @@ def main_test():
else:
c_pred = np.zeros(0)
s_pred = pred
dataset.save_predictions(args.future_prediction, c_pred, s_pred)
model = load_model(args.embedding_model)
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
dataset.save_predictions(args.future_prediction, c_pred, s_pred, domain_embedding, labels)
np.save(args.model_path + "/domain_embds.npy", domain_embedding)
def main_visualization():
@ -213,12 +212,15 @@ def main_visualization():
try:
logger.info("plot training curve")
logs = pd.read_csv(args.train_log)
visualize.plot_training_curve(logs, "client", "{}/client_train.png".format(args.model_path))
visualize.plot_training_curve(logs, "server", "{}/server_train.png".format(args.model_path))
if args.model_output == "client":
visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
else:
visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
except Exception as e:
logger.warning(f"could not generate training curves: {e}")
client_pred, server_pred, domain_embedding, labels = dataset.load_predictions(args.future_prediction)
client_pred, server_pred = dataset.load_predictions(args.future_prediction)
client_pred, server_pred = client_pred.value, server_pred.value
logger.info("plot pr curve")
visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path))

View File

@ -132,11 +132,11 @@ def plot_confusion_matrix(y_true, y_pred, path,
def plot_training_curve(logs, key, path, dpi=600):
plt.clf()
plt.plot(logs[f"{key}_acc"], label="accuracy")
plt.plot(logs[f"{key}_f1_score"], label="f1_score")
plt.plot(logs[f"{key}acc"], label="accuracy")
plt.plot(logs[f"{key}f1_score"], label="f1_score")
plt.plot(logs[f"val_{key}_acc"], label="accuracy")
plt.plot(logs[f"val_{key}_f1_score"], label="val_f1_score")
plt.plot(logs[f"val_{key}acc"], label="accuracy")
plt.plot(logs[f"val_{key}f1_score"], label="val_f1_score")
plt.xlabel('epoch')
plt.ylabel('percentage')