diff --git a/dataset.py b/dataset.py index 2fe5fd9..80444cc 100644 --- a/dataset.py +++ b/dataset.py @@ -152,6 +152,7 @@ def create_dataset_from_lists(chunks, vocab, max_len): :param max_len: :return: """ + def get_domain_features_reduced(d): return get_domain_features(d[0], vocab, max_len) @@ -230,33 +231,40 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size): return load_h5dataset(h5data) -# TODO: implement csv loading if already generated def load_or_generate_domains(train_data, domain_length): - char_dict = get_character_dict() - user_flow_df = get_user_flow_data(train_data) + fn = f"{train_data}_domains.gz" + + try: + user_flow_df = pd.read_csv(fn) + except Exception: + char_dict = get_character_dict() + user_flow_df = get_user_flow_data(train_data) + user_flow_df.reset_index(inplace=True) + user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0, + how="any") + user_flow_df = user_flow_df.groupby(user_flow_df.domain).mean() + user_flow_df.reset_index(inplace=True) + + user_flow_df["clientLabel"] = np.where( + np.logical_or(user_flow_df.trustedHits > 0, user_flow_df.virusTotalHits >= 3), True, False) + user_flow_df[["serverLabel", "clientLabel"]] = user_flow_df[["serverLabel", "clientLabel"]].astype(bool) + user_flow_df = user_flow_df[["domain", "serverLabel", "clientLabel"]] + + user_flow_df.to_csv(fn, compression="gzip") domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, char_dict, domain_length)) domain_encs = np.stack(domain_encs) - user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0, how="any") - user_flow_df.reset_index(inplace=True) - user_flow_df["clientLabel"] = np.where( - np.logical_or(user_flow_df.trustedHits > 0, user_flow_df.virusTotalHits >= 3), 1.0, 0.0) - user_flow_df = user_flow_df[["domain", "serverLabel", "clientLabel"]] - user_flow_df.groupby(user_flow_df.domain).mean() - - return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix() + return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix().astype(bool) -def save_predictions(path, c_pred, s_pred, embd, labels): +def save_predictions(path, c_pred, s_pred): f = h5py.File(path, "w") f.create_dataset("client", data=c_pred) f.create_dataset("server", data=s_pred) - f.create_dataset("embedding", data=embd) - f.create_dataset("labels", data=labels) f.close() def load_predictions(path): f = h5py.File(path, "r") - return f["client"], f["server"], f["embedding"], f["labels"] + return f["client"], f["server"] diff --git a/main.py b/main.py index 7374dc8..dd76b04 100644 --- a/main.py +++ b/main.py @@ -194,13 +194,12 @@ def main_test(): else: c_pred = np.zeros(0) s_pred = pred + dataset.save_predictions(args.future_prediction, c_pred, s_pred) model = load_model(args.embedding_model) domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length) domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1) - - dataset.save_predictions(args.future_prediction, c_pred, s_pred, domain_embedding, labels) - + np.save(args.model_path + "/domain_embds.npy", domain_embedding) def main_visualization(): @@ -213,12 +212,15 @@ def main_visualization(): try: logger.info("plot training curve") logs = pd.read_csv(args.train_log) - visualize.plot_training_curve(logs, "client", "{}/client_train.png".format(args.model_path)) - visualize.plot_training_curve(logs, "server", "{}/server_train.png".format(args.model_path)) + if args.model_output == "client": + visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path)) + else: + visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path)) + visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path)) except Exception as e: logger.warning(f"could not generate training curves: {e}") - client_pred, server_pred, domain_embedding, labels = dataset.load_predictions(args.future_prediction) + client_pred, server_pred = dataset.load_predictions(args.future_prediction) client_pred, server_pred = client_pred.value, server_pred.value logger.info("plot pr curve") visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path)) diff --git a/visualize.py b/visualize.py index 812f171..6d1ec1f 100644 --- a/visualize.py +++ b/visualize.py @@ -132,11 +132,11 @@ def plot_confusion_matrix(y_true, y_pred, path, def plot_training_curve(logs, key, path, dpi=600): plt.clf() - plt.plot(logs[f"{key}_acc"], label="accuracy") - plt.plot(logs[f"{key}_f1_score"], label="f1_score") + plt.plot(logs[f"{key}acc"], label="accuracy") + plt.plot(logs[f"{key}f1_score"], label="f1_score") - plt.plot(logs[f"val_{key}_acc"], label="accuracy") - plt.plot(logs[f"val_{key}_f1_score"], label="val_f1_score") + plt.plot(logs[f"val_{key}acc"], label="accuracy") + plt.plot(logs[f"val_{key}f1_score"], label="val_f1_score") plt.xlabel('epoch') plt.ylabel('percentage')