From 27f4d086ebd1309b32737f38ddc4cde7ed65e490 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Fri, 10 Nov 2017 11:04:27 +0100 Subject: [PATCH] add matplotlib agg mode; update beta vis function according to test results --- main.py | 78 ++++++++++++++++++++++++++++------------------------ visualize.py | 4 +++ 2 files changed, 46 insertions(+), 36 deletions(-) diff --git a/main.py b/main.py index ee9c7f8..bf6830c 100644 --- a/main.py +++ b/main.py @@ -610,19 +610,21 @@ def main_beta(): domain_val, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window) - path, model_prefix = os.path.split(os.path.normpath(args.output_prefix)) + path, model_prefix = os.path.split(os.path.normpath(args.model_path)) + print(path, model_prefix) try: - results = joblib.load(f"{path}/curves.joblib") + curves = joblib.load(f"{path}/curves.joblib") + logger.info(f"load file {path}/curves.joblib successfully") except Exception: - results = {} - results[model_prefix] = {"all": {}} + curves = {} + logger.info(f"currently {len(curves)} models in file: {curves.keys()}") + curves[model_prefix] = {"all": {}} domains = domain_val.value.reshape(-1, 40) - domains = np.apply_along_axis(lambda d: "".join(map(dataset.decode_char, d)), 1, domains) + domains = np.apply_along_axis(lambda d: dataset.decode_domain(d), 1, domains) - def load_df(path): + def load_df(res): df_server = None - res = dataset.load_predictions(path) data = { "names": name_val, "client_pred": res["client_pred"].flatten(), "hits_vt": hits_vt, "hits_trusted": hits_trusted, @@ -645,6 +647,9 @@ def main_beta(): res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3) return res, df_server + + res = dataset.load_predictions(path) + model_keys = sorted(filter(lambda x: x.startswith("clf"), res.keys()), key=lambda x: int(x[4:-3])) client_preds = [] server_preds = [] @@ -653,8 +658,8 @@ def main_beta(): server_user_preds = [] server_domain_preds = [] server_domain_avg_preds = [] - for model_args in get_model_args(args): - df, df_server = load_df(model_args["model_path"]) + for model_name in model_keys: + df, df_server = load_df(res[model_name]) client_preds.append(df.client_pred.as_matrix()) if "server_val" in df.columns: server_preds.append(df.server_pred.as_matrix()) @@ -664,56 +669,57 @@ def main_beta(): server_domain_preds.append(df_domain.server_pred.as_matrix()) df_domain_avg = df_server.groupby(df_server.domain).rolling(10).mean() server_domain_avg_preds.append(df_domain_avg.server_pred.as_matrix()) - - results[model_prefix][model_args["model_name"]] = confusion_matrix(df.client_val.as_matrix(), - df.client_pred.as_matrix().round()) + + curves[model_prefix][model_name] = confusion_matrix(df.client_val.as_matrix(), + df.client_pred.as_matrix().round()) df_user = df.groupby(df.names).max() client_user_preds.append(df_user.client_pred.as_matrix()) if "server_val" in df.columns: server_user_preds.append(df_user.server_pred.as_matrix()) logger.info("plot client curves") - results[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds) - results[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds) - results[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(), + curves[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds) + curves[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds) + curves[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(), + client_user_preds) + curves[model_prefix]["all"]["client_user_roc"] = visualize.calc_roc_mean(df_user.client_val.as_matrix(), client_user_preds) - results[model_prefix]["all"]["client_user_roc"] = visualize.calc_roc_mean(df_user.client_val.as_matrix(), - client_user_preds) if "server_val" in df.columns: logger.info("plot server curves") - results[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(), + curves[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(), + server_preds) + curves[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(), server_preds) - results[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(), - server_preds) - results[model_prefix]["all"]["server_user_prc"] = visualize.calc_pr_mean(df_user.server_val.as_matrix(), + curves[model_prefix]["all"]["server_user_prc"] = visualize.calc_pr_mean(df_user.server_val.as_matrix(), + server_user_preds) + + curves[model_prefix]["all"]["server_user_roc"] = visualize.calc_roc_mean(df_user.server_val.as_matrix(), server_user_preds) - - results[model_prefix]["all"]["server_user_roc"] = visualize.calc_roc_mean(df_user.server_val.as_matrix(), - server_user_preds) if df_server is not None: logger.info("plot server flow curves") - results[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(), + curves[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(), + server_flow_preds) + curves[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(), server_flow_preds) - results[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(), - server_flow_preds) - results[model_prefix]["all"]["server_domain_prc"] = visualize.calc_pr_mean(df_domain.server_val.as_matrix(), + curves[model_prefix]["all"]["server_domain_prc"] = visualize.calc_pr_mean(df_domain.server_val.as_matrix(), + server_domain_preds) + curves[model_prefix]["all"]["server_domain_roc"] = visualize.calc_roc_mean(df_domain.server_val.as_matrix(), server_domain_preds) - results[model_prefix]["all"]["server_domain_roc"] = visualize.calc_roc_mean(df_domain.server_val.as_matrix(), - server_domain_preds) - results[model_prefix]["all"]["server_domain_avg_prc"] = visualize.calc_pr_mean( + curves[model_prefix]["all"]["server_domain_avg_prc"] = visualize.calc_pr_mean( df_domain_avg.server_val.as_matrix(), server_domain_avg_preds) - results[model_prefix]["all"]["server_domain_avg_roc"] = visualize.calc_roc_mean( + curves[model_prefix]["all"]["server_domain_avg_roc"] = visualize.calc_roc_mean( df_domain_avg.server_val.as_matrix(), server_domain_avg_preds) - - joblib.dump(results, f"{path}/curves.joblib") - - # plot_overall_result() + + joblib.dump(curves, f"{path}/curves.joblib") +import matplotlib + +matplotlib.use("agg") import matplotlib.pyplot as plt diff --git a/visualize.py b/visualize.py index bc96167..22fec29 100644 --- a/visualize.py +++ b/visualize.py @@ -1,6 +1,10 @@ import os +import matplotlib + +matplotlib.use("agg") import matplotlib.pyplot as plt + import numpy as np import pandas as pd import seaborn as sns