From 1ab0108c783bb5fe1d33909788b004938c0be627 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Thu, 7 Sep 2017 17:38:21 +0200 Subject: [PATCH] add window to file names for visualization --- main.py | 8 ++++---- run.sh | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 96bf7e1..53fe16f 100644 --- a/main.py +++ b/main.py @@ -303,14 +303,14 @@ def main_visualization(): logger.info("plot pr curve") visualize.plot_clf() visualize.plot_precision_recall(client_val, client_pred) - visualize.plot_save("{}/client_prc.png".format(args.model_path)) + visualize.plot_save("{}/window_client_prc.png".format(args.model_path)) # visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path)) # visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path)) # visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path)) logger.info("plot roc curve") visualize.plot_clf() visualize.plot_roc_curve(client_val, client_pred) - visualize.plot_save("{}/client_roc.png".format(args.model_path)) + visualize.plot_save("{}/window_client_roc.png".format(args.model_path)) # visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path)) print(f"names {name_val.shape} vals {client_val.shape} preds {client_pred.shape}") @@ -351,7 +351,7 @@ def main_visualize_all(): client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"]) visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"]) visualize.plot_legend() - visualize.plot_save(f"{args.output_prefix}_client_prc.png") + visualize.plot_save(f"{args.output_prefix}_window_client_prc.png") logger.info("plot roc curves") visualize.plot_clf() @@ -359,7 +359,7 @@ def main_visualize_all(): client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"]) visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"]) visualize.plot_legend() - visualize.plot_save(f"{args.output_prefix}_client_roc.png") + visualize.plot_save(f"{args.output_prefix}_window_client_roc.png") df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val}) user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float) diff --git a/run.sh b/run.sh index 1a74e57..4c502f5 100644 --- a/run.sh +++ b/run.sh @@ -2,7 +2,7 @@ RESDIR=$1 -mkdir -p /tmp/rk/RESDIR +mkdir -p /tmp/rk/${RESDIR} DATADIR=$2 for output in client both