From e8473048cb195215d55bee48e99d367e3a937fff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Sun, 8 Oct 2017 11:52:10 +0200 Subject: [PATCH] refactor visualization; start plotting server results --- main.py | 116 +++++++++++++++++++++++++++++++++----------------------- 1 file changed, 69 insertions(+), 47 deletions(-) diff --git a/main.py b/main.py index a96eb2d..084c441 100644 --- a/main.py +++ b/main.py @@ -342,7 +342,7 @@ def main_visualization(): visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul") visualize.plot_legend() - visualize.plot_save("{}/{}_{}.png".format(model_path, aggregation, curve)) + visualize.plot_save("{}/{}_{}.pdf".format(model_path, aggregation, curve)) _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data, args.data, @@ -388,17 +388,17 @@ def main_visualization(): # absolute values visualize.plot_confusion_matrix(df.client_val.as_matrix(), df.client_pred.as_matrix().round(), - "{}/client_cov.png".format(args.model_path), + "{}/client_cov.pdf".format(args.model_path), normalize=False, title="Client Confusion Matrix") visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(), - "{}/user_cov.png".format(args.model_path), + "{}/user_cov.pdf".format(args.model_path), normalize=False, title="User Confusion Matrix") # normalized visualize.plot_confusion_matrix(df.client_val.as_matrix(), df.client_pred.as_matrix().round(), - "{}/client_cov_norm.png".format(args.model_path), + "{}/client_cov_norm.pdf".format(args.model_path), normalize=True, title="Client Confusion Matrix") visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(), - "{}/user_cov_norm.png".format(args.model_path), + "{}/user_cov_norm.pdf".format(args.model_path), normalize=True, title="User Confusion Matrix") plot_embedding(args.model_path, results["domain_embds"], args.data, args.domain_length) @@ -406,7 +406,7 @@ def main_visualization(): def plot_embedding(model_path, domain_embedding, data, domain_length): logger.info("visualize embedding") domain_encs, labels = dataset.load_or_generate_domains(data, domain_length) - visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(model_path), method="svd") + visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.pdf".format(model_path), method="svd") def main_visualize_all(): @@ -423,6 +423,8 @@ def main_visualize_all(): }) res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3) return res + + dfs = [(model_args["model_name"], load_df(model_args["model_path"])) for model_args in get_model_args(args)] paul = dataset.load_predictions("results/paul/") df_paul = pd.DataFrame(data={ @@ -430,46 +432,38 @@ def main_visualize_all(): "hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten() }) df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3) - df_paul_user = df_paul.groupby(df_paul.names).max() - + + def vis(output_prefix, dfs, df_paul, aggregation, curve): + visualize.plot_clf() + if curve == "prc": + for model_name, df in dfs: + if aggregation == "user": + df = df.groupby(df.names).max() + visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_name) + if aggregation == "user": + df_paul = df_paul.groupby(df_paul.names).max() + visualize.plot_precision_recall(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul") + elif curve == "roc": + for model_name, df in dfs: + if aggregation == "user": + df = df.groupby(df.names).max() + visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_name) + if aggregation == "user": + df_paul = df_paul.groupby(df_paul.names).max() + visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul") + visualize.plot_legend() + visualize.plot_save("{}_{}_{}.pdf".format(output_prefix, aggregation, curve)) + logger.info("plot pr curves") - visualize.plot_clf() - for model_args in get_model_args(args): - df = load_df(model_args["model_path"]) - visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"]) - visualize.plot_precision_recall(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul") - visualize.plot_legend() - visualize.plot_save(f"{args.output_prefix}_window_client_prc.png") - + vis(args.output_prefix, dfs, df_paul, "window", "prc") logger.info("plot roc curves") - visualize.plot_clf() - for model_args in get_model_args(args): - df = load_df(model_args["model_path"]) - visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"]) - visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul") - visualize.plot_legend() - visualize.plot_save(f"{args.output_prefix}_window_client_roc.png") + vis(args.output_prefix, dfs, df_paul, "window", "roc") logger.info("plot user pr curves") - visualize.plot_clf() - for model_args in get_model_args(args): - df = load_df(model_args["model_path"]) - df = df.groupby(df.names).max() - visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"]) - visualize.plot_precision_recall(df_paul_user.client_val.as_matrix(), df_paul_user.client_pred.as_matrix(), "paul") - visualize.plot_legend() - visualize.plot_save(f"{args.output_prefix}_user_client_prc.png") - + vis(args.output_prefix, dfs, df_paul, "user", "prc") logger.info("plot user roc curves") - visualize.plot_clf() - for model_args in get_model_args(args): - df = load_df(model_args["model_path"]) - df = df.groupby(df.names).max() - visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"]) - visualize.plot_roc_curve(df_paul_user.client_val.as_matrix(), df_paul_user.client_pred.as_matrix(), "paul") - visualize.plot_legend() - visualize.plot_save(f"{args.output_prefix}_user_client_roc.png") - + vis(args.output_prefix, dfs, df_paul, "user", "roc") + import joblib @@ -515,7 +509,7 @@ def main_beta(): visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean") visualize.plot_pr_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul") visualize.plot_legend() - visualize.plot_save(f"{args.output_prefix}_window_client_prc_all.png") + visualize.plot_save(f"{args.output_prefix}_window_client_prc_all.pdf") logger.info("plot roc curves") visualize.plot_clf() @@ -529,7 +523,7 @@ def main_beta(): visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean") visualize.plot_roc_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul") visualize.plot_legend() - visualize.plot_save(f"{args.output_prefix}_window_client_roc_all.png") + visualize.plot_save(f"{args.output_prefix}_window_client_roc_all.pdf") logger.info("plot user pr curves") visualize.plot_clf() @@ -544,7 +538,7 @@ def main_beta(): visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean") visualize.plot_pr_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul") visualize.plot_legend() - visualize.plot_save(f"{args.output_prefix}_user_client_prc_all.png") + visualize.plot_save(f"{args.output_prefix}_user_client_prc_all.pdf") logger.info("plot user roc curves") visualize.plot_clf() @@ -557,7 +551,7 @@ def main_beta(): visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean") visualize.plot_roc_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul") visualize.plot_legend() - visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.png") + visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.pdf") joblib.dump(results, f"{path}/curves.joblib") @@ -589,13 +583,13 @@ def plot_overall_result(): plt.ylim([0.0, 1.0]) plt.xlim([0.0, 1.0]) visualize.plot_legend() - visualize.plot_save(f"{path}/{vis}_all.png") + visualize.plot_save(f"{path}/{vis}_all.pdf") for cat, models in results.items(): visualize.plot_clf() visualize.plot_error_bars(models) visualize.plot_legend() - visualize.plot_save(f"{path}/error_bars_{cat}.png") + visualize.plot_save(f"{path}/error_bars_{cat}.pdf") def train_server_only(): @@ -678,6 +672,34 @@ def test_server_only(): dataset.save_predictions(model_args["model_path"], results) +def vis_server(): + def load_model(m, c): + from keras.models import load_model + clf = load_model(m, custom_objects=c) + emdb = clf.layers[1] + return emdb, clf + + domain_raw, flow_raw, name_raw, hits_vt_raw, hits_trusted_raw, server_raw = dataset.load_or_generate_raw_h5data( + args.data, + args.data, + args.domain_length, + args.window) + + results = dataset.load_predictions(args.clf_model) + + visualize.plot_clf() + visualize.plot_precision_recall(server_raw.flatten(), results["server_pred"].flatten(), "server") + visualize.plot_legend() + visualize.plot_save("results/server_model/windows_prc.pdf") + visualize.plot_clf() + visualize.plot_precision_recall(server_raw.flatten(), results["server_pred"].flatten(), "server") + visualize.plot_legend() + visualize.plot_save("results/server_model/windows_prc.pdf") + visualize.plot_clf() + visualize.plot_roc_curve(server_raw.flatten(), results["server_pred"].flatten(), "server") + visualize.plot_legend() + visualize.plot_save("results/server_model/windows_roc.pdf") + def main(): if "train" == args.mode: main_train()