update plotting function according the test and beta results

This commit is contained in:
René Knaebel 2017-11-10 12:34:21 +01:00
parent c19d649bc4
commit 3ce385eca6
1 changed files with 7 additions and 6 deletions

13
main.py
View File

@ -629,7 +629,6 @@ def main_beta():
"hits_vt": hits_vt, "hits_trusted": hits_trusted,
}
if "server_pred" in res:
print(res["server_pred"].shape, server_val.value.shape)
server = res["server_pred"] if len(res["server_pred"].shape) == 2 else res["server_pred"].max(axis=1)
val = server_val.value.max(axis=1)
data["server_pred"] = server.flatten()
@ -679,8 +678,8 @@ def main_beta():
client_user_preds.append(df_user.client_pred.as_matrix())
if "server_val" in df.columns:
server_user_preds.append(df_user.server_pred.as_matrix())
logger.info("plot client curves")
logger.info("compute client curves")
curves[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds)
curves[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds)
curves[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(),
@ -689,7 +688,7 @@ def main_beta():
client_user_preds)
if "server_val" in df.columns:
logger.info("plot server curves")
logger.info("compute server curves")
curves[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(),
server_preds)
curves[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(),
@ -701,7 +700,7 @@ def main_beta():
server_user_preds)
if df_server is not None:
logger.info("plot server flow curves")
logger.info("compute server flow curves")
curves[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(),
server_flow_preds)
curves[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(),
@ -727,9 +726,11 @@ import matplotlib.pyplot as plt
def plot_overall_result():
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
path, model_prefix = os.path.split(os.path.normpath(args.model_path))
exists_or_make_path(f"{path}/figs/curves/")
try:
results = joblib.load(f"{path}/curves.joblib")
logger.info("curves successfully loaded")
except Exception:
results = {}