diff --git a/Makefile b/Makefile index 786ea62..0b283d3 100644 --- a/Makefile +++ b/Makefile @@ -66,4 +66,5 @@ hyper: clean: rm -r results/test/test* + rm data/rk_mini.csv.gz_raw.h5 rm data/rk_mini.csv.gz.h5 diff --git a/arguments.py b/arguments.py index d6272c4..7b03890 100644 --- a/arguments.py +++ b/arguments.py @@ -105,9 +105,9 @@ def get_model_args(args): "embedding_model": os.path.join(model_path, "embd.h5"), "clf_model": os.path.join(model_path, "clf.h5"), "train_log": os.path.join(model_path, "train.log.csv"), - "train_h5data": args.train_data + ".h5", - "test_h5data": args.test_data + ".h5", - "future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred.h5") + "train_h5data": args.train_data, + "test_h5data": args.test_data, + "future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred") } for model_path in args.model_paths] def parse(): @@ -115,7 +115,7 @@ def parse(): args.embedding_model = os.path.join(args.model_path, "embd.h5") args.clf_model = os.path.join(args.model_path, "clf.h5") args.train_log = os.path.join(args.model_path, "train.log.csv") - args.train_h5data = args.train_data + ".h5" - args.test_h5data = args.test_data + ".h5" - args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred.h5") + args.train_h5data = args.train_data + args.test_h5data = args.test_data + args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred") return args diff --git a/dataset.py b/dataset.py index 9cf22f0..ba6a4aa 100644 --- a/dataset.py +++ b/dataset.py @@ -4,6 +4,7 @@ import string from multiprocessing import Pool import h5py +import joblib import numpy as np import pandas as pd from tqdm import tqdm @@ -139,14 +140,18 @@ def create_raw_dataset_from_flows(user_flow_df, char_dict, max_len, window_size= def store_h5dataset(path, data: dict): - f = h5py.File(path, "w") + f = h5py.File(path + ".h5", "w") for key, val in data.items(): f.create_dataset(key, data=val) f.close() +def check_h5dataset(path): + return open(path + ".h5", "r") + + def load_h5dataset(path): - f = h5py.File(path, "r") + f = h5py.File(path + ".h5", "r") data = {} for k in f.keys(): data[k] = f[k] @@ -225,17 +230,17 @@ def get_flow_per_user(df): def load_or_generate_h5data(h5data, train_data, domain_length, window_size): - char_dict = get_character_dict() logger.info(f"check for h5data {h5data}") try: - open(h5data, "r") + check_h5dataset(h5data) except FileNotFoundError: - logger.info("h5 data not found - load csv file") - user_flow_df = get_user_flow_data(train_data) - logger.info("create training dataset") - domain, flow, name, client, server = create_dataset_from_flows(user_flow_df, char_dict, - max_len=domain_length, - window_size=window_size) + logger.info("load raw training dataset") + domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(h5data + "_raw", train_data, + domain_length, window_size) + logger.info("filter training dataset") + domain, flow, name, client, server = filter_window_dataset_by_hits(domain.value, flow.value, + name.value, hits.value, + trusted_hits.value, server.value) logger.info("store training dataset as h5 file") data = { "domain": domain.astype(np.int8), @@ -250,6 +255,32 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size): return data["domain"], data["flow"], data["name"], data["client"], data["server"] +def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size): + char_dict = get_character_dict() + logger.info(f"check for h5data {h5data}") + try: + check_h5dataset(h5data) + except FileNotFoundError: + logger.info("h5 data not found - load csv file") + user_flow_df = get_user_flow_data(train_data) + logger.info("create raw training dataset") + domain, flow, name, hits, trusted_hits, server = create_raw_dataset_from_flows(user_flow_df, char_dict, + domain_length, window_size) + logger.info("store raw training dataset as h5 file") + data = { + "domain": domain.astype(np.int8), + "flow": flow, + "name": name, + "hits_vt": hits.astype(np.int8), + "hits_trusted": hits.astype(np.int8), + "server": server.astype(np.bool) + } + store_h5dataset(h5data, data) + logger.info("load h5 dataset") + data = load_h5dataset(h5data) + return data["domain"], data["flow"], data["name"], data["hits_vt"], data["hits_trusted"], data["server"] + + def generate_names(train_data, window_size): user_flow_df = get_user_flow_data(train_data) with Pool() as pool: @@ -291,13 +322,9 @@ def load_or_generate_domains(train_data, domain_length): return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix().astype(bool) -def save_predictions(path, c_pred, s_pred): - f = h5py.File(path, "w") - f.create_dataset("client", data=c_pred) - f.create_dataset("server", data=s_pred) - f.close() +def save_predictions(path, results): + joblib.dump(results, path + "/results.joblib", compress=3) def load_predictions(path): - f = h5py.File(path, "r") - return f["client"], f["server"] + return joblib.load(path + "/results.joblib") diff --git a/main.py b/main.py index 08f88af..b5b8d41 100644 --- a/main.py +++ b/main.py @@ -1,13 +1,12 @@ import json import logging import os -import joblib import numpy as np import pandas as pd import tensorflow as tf -from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping -from keras.models import load_model, Model +from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint +from keras.models import Model, load_model import arguments import dataset @@ -15,9 +14,8 @@ import hyperband import models # create logger import visualize -from dataset import load_or_generate_h5data -from utils import exists_or_make_path, get_custom_class_weights from arguments import get_model_args +from utils import exists_or_make_path, get_custom_class_weights logger = logging.getLogger('logger') logger.setLevel(logging.DEBUG) @@ -115,8 +113,9 @@ def main_hyperband(): } logger.info("create training dataset") - domain_tr, flow_tr, name_tr, client_tr, server_tr = load_or_generate_h5data(args.train_h5data, args.train_data, - args.domain_length, args.window) + domain_tr, flow_tr, name_tr, client_tr, server_tr = dataset.load_or_generate_h5data(args.train_h5data, + args.train_data, + args.domain_length, args.window) hp = hyperband.Hyperband(params, [domain_tr, flow_tr], [client_tr, server_tr]) @@ -129,10 +128,10 @@ def main_train(param=None): exists_or_make_path(args.model_path) logger.info(f"Use command line arguments: {args}") - domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = load_or_generate_h5data(args.train_h5data, - args.train_data, - args.domain_length, - args.window) + domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.train_h5data, + args.train_data, + args.domain_length, + args.window) logger.info("define callbacks") callbacks = [] callbacks.append(ModelCheckpoint(filepath=args.clf_model, @@ -245,12 +244,12 @@ def main_train(param=None): def main_test(): logger.info("start test: load data") - domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, - args.test_data, - args.domain_length, - args.window) - domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length) - + domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.test_h5data, + args.test_data, + args.domain_length, + args.window) + domain_encs, _ = dataset.load_or_generate_domains(args.test_data, args.domain_length) + for model_args in get_model_args(args): results = {} logger.info(f"process model {model_args['model_path']}") @@ -268,73 +267,67 @@ def main_test(): results["client_pred"] = pred else: results["server_pred"] = pred - # dataset.save_predictions(model_args["future_prediction"], c_pred, s_pred) embd_model = load_model(model_args["embedding_model"]) domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1) - # np.save(model_args["model_path"] + "/domain_embds.npy", domain_embeddings) - results["domain_embds"] = domain_embeddings - joblib.dump(results, model_args["model_path"] + "/results.joblib", compress=3) + + dataset.save_predictions(model_args["model_path"], results) def main_visualization(): - domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, - args.test_data, - args.domain_length, - args.window) - # client_val, server_val = client_val.value, server_val.value + domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data, + args.test_data, + args.domain_length, + args.window) client_val = client_val.value - + logger.info("plot model") model = load_model(args.clf_model, custom_objects=models.get_metrics()) visualize.plot_model_as(model, os.path.join(args.model_path, "model.png")) - - try: - logger.info("plot training curve") - logs = pd.read_csv(args.train_log) - if args.model_output == "client": - visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path)) - else: - visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path)) - visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path)) - except Exception as e: - logger.warning(f"could not generate training curves: {e}") - - client_pred, server_pred = dataset.load_predictions(args.future_prediction) - client_pred, server_pred = client_pred.value.flatten(), server_pred.value.flatten() + + logger.info("plot training curve") + logs = pd.read_csv(args.train_log) + if "acc" in logs.keys(): + visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path)) + elif "client_acc" in logs.keys() and "server_acc" in logs.keys(): + visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path)) + visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path)) + else: + logger.warning("Error while plotting training curves") + + results = dataset.load_predictions(args.future_prediction) + client_pred = results["client_pred"].flatten() + logger.info("plot pr curve") visualize.plot_clf() visualize.plot_precision_recall(client_val, client_pred, args.model_path) visualize.plot_legend() visualize.plot_save("{}/window_client_prc.png".format(args.model_path)) - # visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path)) - # visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path)) - # visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path)) + logger.info("plot roc curve") visualize.plot_clf() visualize.plot_roc_curve(client_val, client_pred, args.model_path) visualize.plot_legend() visualize.plot_save("{}/window_client_roc.png".format(args.model_path)) - # visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path)) - + print(f"names {name_val.shape} vals {client_val.shape} preds {client_pred.shape}") - + df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val}) user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float) df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred}) user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float) - + visualize.plot_clf() visualize.plot_precision_recall(user_vals, user_preds, args.model_path) visualize.plot_legend() visualize.plot_save("{}/user_client_prc.png".format(args.model_path)) - + visualize.plot_clf() visualize.plot_roc_curve(user_vals, user_preds, args.model_path) visualize.plot_legend() visualize.plot_save("{}/user_client_roc.png".format(args.model_path)) - + visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(), "{}/client_cov.png".format(args.model_path), normalize=False, title="Client Confusion Matrix") @@ -348,44 +341,48 @@ def main_visualization(): def main_visualize_all(): - domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, - args.test_data, - args.domain_length, - args.window) + domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data, + args.test_data, + args.domain_length, + args.window) logger.info("plot pr curves") visualize.plot_clf() for model_args in get_model_args(args): - client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"]) - visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"]) + results = dataset.load_predictions(model_args["future_prediction"]) + client_pred = results["client_pred"].flatten() + visualize.plot_precision_recall(client_val, client_pred, model_args["model_path"]) visualize.plot_legend() visualize.plot_save(f"{args.output_prefix}_window_client_prc.png") - + logger.info("plot roc curves") visualize.plot_clf() for model_args in get_model_args(args): - client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"]) - visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"]) + results = dataset.load_predictions(model_args["future_prediction"]) + client_pred = results["client_pred"].flatten() + visualize.plot_roc_curve(client_val, client_pred, model_args["model_path"]) visualize.plot_legend() visualize.plot_save(f"{args.output_prefix}_window_client_roc.png") - + df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val}) user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float) - + logger.info("plot user pr curves") visualize.plot_clf() for model_args in get_model_args(args): - client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"]) - df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()}) + results = dataset.load_predictions(model_args["future_prediction"]) + client_pred = results["client_pred"].flatten() + df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred}) user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float) visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"]) visualize.plot_legend() visualize.plot_save(f"{args.output_prefix}_user_client_prc.png") - + logger.info("plot user roc curves") visualize.plot_clf() for model_args in get_model_args(args): - client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"]) - df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()}) + results = dataset.load_predictions(model_args["future_prediction"]) + client_pred = results["client_pred"].flatten() + df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred}) user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float) visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"]) visualize.plot_legend()