diff --git a/main.py b/main.py index f561b16..35871ad 100644 --- a/main.py +++ b/main.py @@ -629,7 +629,6 @@ def main_beta(): "hits_vt": hits_vt, "hits_trusted": hits_trusted, } if "server_pred" in res: - print(res["server_pred"].shape, server_val.value.shape) server = res["server_pred"] if len(res["server_pred"].shape) == 2 else res["server_pred"].max(axis=1) val = server_val.value.max(axis=1) data["server_pred"] = server.flatten() @@ -679,8 +678,8 @@ def main_beta(): client_user_preds.append(df_user.client_pred.as_matrix()) if "server_val" in df.columns: server_user_preds.append(df_user.server_pred.as_matrix()) - - logger.info("plot client curves") + + logger.info("compute client curves") curves[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds) curves[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds) curves[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(), @@ -689,7 +688,7 @@ def main_beta(): client_user_preds) if "server_val" in df.columns: - logger.info("plot server curves") + logger.info("compute server curves") curves[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(), server_preds) curves[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(), @@ -701,7 +700,7 @@ def main_beta(): server_user_preds) if df_server is not None: - logger.info("plot server flow curves") + logger.info("compute server flow curves") curves[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(), server_flow_preds) curves[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(), @@ -727,9 +726,11 @@ import matplotlib.pyplot as plt def plot_overall_result(): - path, model_prefix = os.path.split(os.path.normpath(args.output_prefix)) + path, model_prefix = os.path.split(os.path.normpath(args.model_path)) + exists_or_make_path(f"{path}/figs/curves/") try: results = joblib.load(f"{path}/curves.joblib") + logger.info("curves successfully loaded") except Exception: results = {}