From 70d00efb01e5683a3275c7c7e7ae29c697eb432b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20Knaebel?=
Date: Fri, 8 Sep 2017 13:55:13 +0200
Subject: [PATCH] refactor using joblib for test results, make h5py store/load
 more flexible

---
 Makefile   |  1 +
 dataset.py | 66 +++++++++++++++++++++++++++++++++++++-------------
 main.py    | 70 ++++++++++++++++++++++++------------------------------
 3 files changed, 82 insertions(+), 55 deletions(-)

diff --git a/Makefile b/Makefile
index b8c6916..4984f3e 100644
--- a/Makefile
+++ b/Makefile
@@ -34,3 +34,4 @@ hyper:
 
 clean:
 	rm -r results/test*
+	rm data/rk_mini.csv.gz.h5

diff --git a/dataset.py b/dataset.py
index b435151..a9c6851 100644
--- a/dataset.py
+++ b/dataset.py
@@ -126,22 +126,50 @@ def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
     return domain, flow, names, client_tr, server
 
 
-def store_h5dataset(path, domain, flow, name, client, server):
+def create_testset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
+    logger.info("get chunks from user data frames")
+    with Pool() as pool:
+        results = []
+        for user_flow in tqdm(get_flow_per_user(user_flow_df), total=len(user_flow_df['user_hash'].unique().tolist())):
+            results.append(pool.apply_async(get_user_chunks, (user_flow, window_size)))
+        windows = [window for res in results for window in res.get()]
+    logger.info("create test dataset")
+    domain, flow, hits, names, server, trusted_hits = create_dataset_from_lists(chunks=windows,
+                                                                                vocab=char_dict,
+                                                                                max_len=max_len)
+    # make client labels discrete with 4 different values
+    hits = np.apply_along_axis(lambda x: discretize_label(x, 3), 0, np.atleast_2d(hits))
+    # select only the 1.0- and 0.0-labelled samples
+    pos_idx = np.where(np.logical_or(hits == 1.0, trusted_hits >= 1.0))[0]
+    neg_idx = np.where(hits == 0.0)[0]
+    idx = np.concatenate((pos_idx, neg_idx))
+    # keep only the selected samples
+    domain = domain[idx]
+    flow = flow[idx]
+    client_tr = np.zeros_like(idx, float)
+    client_tr[:pos_idx.shape[-1]] = 1.0
+    server = server[idx]
+    names = names[idx]
+
+    return domain, flow, names, client_tr, server
+
+
+def store_h5dataset(path, data: dict):
     f = h5py.File(path, "w")
-    domain = domain.astype(np.int8)
-    f.create_dataset("domain", data=domain)
-    f.create_dataset("flow", data=flow)
-    f.create_dataset("name", data=name)
-    server = server.astype(np.bool)
-    client = client.astype(np.bool)
-    f.create_dataset("client", data=client)
-    f.create_dataset("server", data=server)
+    for key, val in data.items():
+        f.create_dataset(key, data=val)
     f.close()
 
 
 def load_h5dataset(path):
-    data = h5py.File(path, "r")
-    return data["domain"], data["flow"], data["name"], data["client"], data["server"]
+    # note: the file handle stays open on purpose; h5py datasets are lazy
+    # and remain readable only while the file is open
+    f = h5py.File(path, "r")
+    keys = f.keys()
+    data = {}
+    for k in keys:
+        data[k] = f[k]
+    return data
 
 
 def create_dataset_from_lists(chunks, vocab, max_len):
@@ -224,13 +250,23 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
         logger.info("h5 data not found - load csv file")
         user_flow_df = get_user_flow_data(train_data)
         logger.info("create training dataset")
-        domain, flow, names, client, server = create_dataset_from_flows(user_flow_df, char_dict,
-                                                                        max_len=domain_length,
-                                                                        window_size=window_size)
+        domain, flow, name, client, server = create_dataset_from_flows(user_flow_df, char_dict,
+                                                                       max_len=domain_length,
+                                                                       window_size=window_size)
         logger.info("store training dataset as h5 file")
-        store_h5dataset(h5data, domain, flow, names, client, server)
+        data = {
+            "domain": domain.astype(np.int8),
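+            # cast at the call site: the generic store_h5dataset no longer
+            # knows which arrays should be stored as int8 or bool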
"flow": flow, + "name": name, + "client": client.astype(np.bool), + "server": server.astype(np.bool) + } + store_h5dataset(h5data, data) logger.info("load h5 dataset") - return load_h5dataset(h5data) + data = load_h5dataset(h5data) + return data["domain"], data["flow"], data["name"], data["client"], data["server"] def generate_names(train_data, window_size): diff --git a/main.py b/main.py index 53fe16f..2fe7467 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ import json import logging import os +import joblib import numpy as np import pandas as pd @@ -78,6 +79,15 @@ PARAMS = { } +def create_model(model, output_type): + if output_type == "both": + return Model(inputs=[model.in_domains, model.in_flows], outputs=(model.out_client, model.out_server)) + elif output_type == "client": + return Model(inputs=[model.in_domains, model.in_flows], outputs=(model.out_client,)) + else: + raise Exception("unknown model output") + + def main_paul_best(): pauls_best_params = models.pauls_networks.best_config main_train(pauls_best_params) @@ -154,11 +164,7 @@ def main_train(param=None): logger.info(f"Generator model with params: {param}") embedding, model, new_model = models.get_models_by_params(param) - if args.model_output == "both": - model = Model(inputs=[new_model.in_domains, new_model.in_flows], - outputs=(new_model.out_server, new_model.out_client)) - else: - raise Exception("unknown model output") + model = create_model(new_model, args.model_output) server_tr = np.expand_dims(server_windows_tr, 2) logger.info("compile and train model") @@ -202,15 +208,8 @@ def main_train(param=None): logger.info(f"Generator model with params: {param}") embedding, model, new_model = models.get_models_by_params(param) - if args.model_output == "both": - model = Model(inputs=[model.in_domains, model.in_flows], outputs=(model.out_client, model.out_server)) - new_model = Model(inputs=[new_model.in_domains, new_model.in_flows], - outputs=(new_model.out_client, new_model.out_server)) - elif args.model_output == "client": - model = Model(inputs=[model.in_domains, model.in_flows], outputs=(model.out_client,)) - new_model = Model(inputs=[new_model.in_domains, new_model.in_flows], outputs=(new_model.out_client,)) - else: - raise Exception("unknown model output") + model = create_model(model, args.model_output) + new_model = create_model(new_model, args.model_output) if args.model_type == "inter": server_tr = np.expand_dims(server_windows_tr, 2) @@ -253,6 +252,7 @@ def main_test(): domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length) for model_args in get_model_args(args): + results = {} logger.info(f"process model {model_args['model_path']}") clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics()) @@ -262,17 +262,20 @@ def main_test(): if args.model_output == "both": c_pred, s_pred = pred + results["client_pred"] = c_pred + results["server_pred"] = s_pred elif args.model_output == "client": - c_pred = pred - s_pred = np.zeros(0) + results["client_pred"] = pred else: - c_pred = np.zeros(0) - s_pred = pred - dataset.save_predictions(model_args["future_prediction"], c_pred, s_pred) + results["server_pred"] = pred + # dataset.save_predictions(model_args["future_prediction"], c_pred, s_pred) embd_model = load_model(model_args["embedding_model"]) domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1) - np.save(model_args["model_path"] + "/domain_embds.npy", domain_embeddings) + # np.save(model_args["model_path"] + 
"/domain_embds.npy", domain_embeddings) + + results["domain_embds"] = domain_embeddings + joblib.dump(results, model_args["model_path"] + "results.joblib", compress=3) def main_visualization(): @@ -302,14 +305,16 @@ def main_visualization(): client_pred, server_pred = client_pred.value.flatten(), server_pred.value.flatten() logger.info("plot pr curve") visualize.plot_clf() - visualize.plot_precision_recall(client_val, client_pred) + visualize.plot_precision_recall(client_val, client_pred, args.model_path) + visualize.plot_legend() visualize.plot_save("{}/window_client_prc.png".format(args.model_path)) # visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path)) # visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path)) # visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path)) logger.info("plot roc curve") visualize.plot_clf() - visualize.plot_roc_curve(client_val, client_pred) + visualize.plot_roc_curve(client_val, client_pred, args.model_path) + visualize.plot_legend() visualize.plot_save("{}/window_client_roc.png".format(args.model_path)) # visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path)) @@ -321,11 +326,13 @@ def main_visualization(): user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float) visualize.plot_clf() - visualize.plot_precision_recall(user_vals, user_preds) + visualize.plot_precision_recall(user_vals, user_preds, args.model_path) + visualize.plot_legend() visualize.plot_save("{}/user_client_prc.png".format(args.model_path)) visualize.plot_clf() - visualize.plot_roc_curve(user_vals, user_preds) + visualize.plot_roc_curve(user_vals, user_preds, args.model_path) + visualize.plot_legend() visualize.plot_save("{}/user_client_roc.png".format(args.model_path)) visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(), @@ -385,19 +392,6 @@ def main_visualize_all(): visualize.plot_save(f"{args.output_prefix}_user_client_roc.png") -def main_data(): - char_dict = dataset.get_character_dict() - user_flow_df = dataset.get_user_flow_data(args.train_data) - logger.info("create training dataset") - domain_tr, flow_tr, name_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict, - max_len=args.domain_length, - window_size=args.window) - print(f"domain shape {domain_tr.shape}") - print(f"flow shape {flow_tr.shape}") - print(f"client shape {client_tr.shape}") - print(f"server shape {server_tr.shape}") - - def main(): if "train" == args.mode: main_train() @@ -411,8 +405,6 @@ def main(): main_visualize_all() if "paul" == args.mode: main_paul_best() - if "data" == args.mode: - main_data() if __name__ == "__main__":