refactor visualization; start plotting server results
This commit is contained in:
parent
0b26c6125c
commit
e8473048cb
116
main.py
116
main.py
@ -342,7 +342,7 @@ def main_visualization():
|
||||
visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
|
||||
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("{}/{}_{}.png".format(model_path, aggregation, curve))
|
||||
visualize.plot_save("{}/{}_{}.pdf".format(model_path, aggregation, curve))
|
||||
|
||||
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
||||
args.data,
|
||||
@ -388,17 +388,17 @@ def main_visualization():
|
||||
|
||||
# absolute values
|
||||
visualize.plot_confusion_matrix(df.client_val.as_matrix(), df.client_pred.as_matrix().round(),
|
||||
"{}/client_cov.png".format(args.model_path),
|
||||
"{}/client_cov.pdf".format(args.model_path),
|
||||
normalize=False, title="Client Confusion Matrix")
|
||||
visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(),
|
||||
"{}/user_cov.png".format(args.model_path),
|
||||
"{}/user_cov.pdf".format(args.model_path),
|
||||
normalize=False, title="User Confusion Matrix")
|
||||
# normalized
|
||||
visualize.plot_confusion_matrix(df.client_val.as_matrix(), df.client_pred.as_matrix().round(),
|
||||
"{}/client_cov_norm.png".format(args.model_path),
|
||||
"{}/client_cov_norm.pdf".format(args.model_path),
|
||||
normalize=True, title="Client Confusion Matrix")
|
||||
visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(),
|
||||
"{}/user_cov_norm.png".format(args.model_path),
|
||||
"{}/user_cov_norm.pdf".format(args.model_path),
|
||||
normalize=True, title="User Confusion Matrix")
|
||||
plot_embedding(args.model_path, results["domain_embds"], args.data, args.domain_length)
|
||||
|
||||
@ -406,7 +406,7 @@ def main_visualization():
|
||||
def plot_embedding(model_path, domain_embedding, data, domain_length):
|
||||
logger.info("visualize embedding")
|
||||
domain_encs, labels = dataset.load_or_generate_domains(data, domain_length)
|
||||
visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(model_path), method="svd")
|
||||
visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.pdf".format(model_path), method="svd")
|
||||
|
||||
|
||||
def main_visualize_all():
|
||||
@ -423,6 +423,8 @@ def main_visualize_all():
|
||||
})
|
||||
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
|
||||
return res
|
||||
|
||||
dfs = [(model_args["model_name"], load_df(model_args["model_path"])) for model_args in get_model_args(args)]
|
||||
|
||||
paul = dataset.load_predictions("results/paul/")
|
||||
df_paul = pd.DataFrame(data={
|
||||
@ -430,46 +432,38 @@ def main_visualize_all():
|
||||
"hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten()
|
||||
})
|
||||
df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3)
|
||||
df_paul_user = df_paul.groupby(df_paul.names).max()
|
||||
|
||||
|
||||
def vis(output_prefix, dfs, df_paul, aggregation, curve):
|
||||
visualize.plot_clf()
|
||||
if curve == "prc":
|
||||
for model_name, df in dfs:
|
||||
if aggregation == "user":
|
||||
df = df.groupby(df.names).max()
|
||||
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_name)
|
||||
if aggregation == "user":
|
||||
df_paul = df_paul.groupby(df_paul.names).max()
|
||||
visualize.plot_precision_recall(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
|
||||
elif curve == "roc":
|
||||
for model_name, df in dfs:
|
||||
if aggregation == "user":
|
||||
df = df.groupby(df.names).max()
|
||||
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_name)
|
||||
if aggregation == "user":
|
||||
df_paul = df_paul.groupby(df_paul.names).max()
|
||||
visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("{}_{}_{}.pdf".format(output_prefix, aggregation, curve))
|
||||
|
||||
logger.info("plot pr curves")
|
||||
visualize.plot_clf()
|
||||
for model_args in get_model_args(args):
|
||||
df = load_df(model_args["model_path"])
|
||||
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
|
||||
visualize.plot_precision_recall(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")
|
||||
|
||||
vis(args.output_prefix, dfs, df_paul, "window", "prc")
|
||||
logger.info("plot roc curves")
|
||||
visualize.plot_clf()
|
||||
for model_args in get_model_args(args):
|
||||
df = load_df(model_args["model_path"])
|
||||
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
|
||||
visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")
|
||||
vis(args.output_prefix, dfs, df_paul, "window", "roc")
|
||||
|
||||
logger.info("plot user pr curves")
|
||||
visualize.plot_clf()
|
||||
for model_args in get_model_args(args):
|
||||
df = load_df(model_args["model_path"])
|
||||
df = df.groupby(df.names).max()
|
||||
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
|
||||
visualize.plot_precision_recall(df_paul_user.client_val.as_matrix(), df_paul_user.client_pred.as_matrix(), "paul")
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{args.output_prefix}_user_client_prc.png")
|
||||
|
||||
vis(args.output_prefix, dfs, df_paul, "user", "prc")
|
||||
logger.info("plot user roc curves")
|
||||
visualize.plot_clf()
|
||||
for model_args in get_model_args(args):
|
||||
df = load_df(model_args["model_path"])
|
||||
df = df.groupby(df.names).max()
|
||||
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
|
||||
visualize.plot_roc_curve(df_paul_user.client_val.as_matrix(), df_paul_user.client_pred.as_matrix(), "paul")
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")
|
||||
|
||||
vis(args.output_prefix, dfs, df_paul, "user", "roc")
|
||||
|
||||
|
||||
import joblib
|
||||
|
||||
@ -515,7 +509,7 @@ def main_beta():
|
||||
visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
|
||||
visualize.plot_pr_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{args.output_prefix}_window_client_prc_all.png")
|
||||
visualize.plot_save(f"{args.output_prefix}_window_client_prc_all.pdf")
|
||||
|
||||
logger.info("plot roc curves")
|
||||
visualize.plot_clf()
|
||||
@ -529,7 +523,7 @@ def main_beta():
|
||||
visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
|
||||
visualize.plot_roc_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{args.output_prefix}_window_client_roc_all.png")
|
||||
visualize.plot_save(f"{args.output_prefix}_window_client_roc_all.pdf")
|
||||
|
||||
logger.info("plot user pr curves")
|
||||
visualize.plot_clf()
|
||||
@ -544,7 +538,7 @@ def main_beta():
|
||||
visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
|
||||
visualize.plot_pr_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{args.output_prefix}_user_client_prc_all.png")
|
||||
visualize.plot_save(f"{args.output_prefix}_user_client_prc_all.pdf")
|
||||
|
||||
logger.info("plot user roc curves")
|
||||
visualize.plot_clf()
|
||||
@ -557,7 +551,7 @@ def main_beta():
|
||||
visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
|
||||
visualize.plot_roc_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.png")
|
||||
visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.pdf")
|
||||
|
||||
joblib.dump(results, f"{path}/curves.joblib")
|
||||
|
||||
@ -589,13 +583,13 @@ def plot_overall_result():
|
||||
plt.ylim([0.0, 1.0])
|
||||
plt.xlim([0.0, 1.0])
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{path}/{vis}_all.png")
|
||||
visualize.plot_save(f"{path}/{vis}_all.pdf")
|
||||
|
||||
for cat, models in results.items():
|
||||
visualize.plot_clf()
|
||||
visualize.plot_error_bars(models)
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{path}/error_bars_{cat}.png")
|
||||
visualize.plot_save(f"{path}/error_bars_{cat}.pdf")
|
||||
|
||||
|
||||
def train_server_only():
|
||||
@ -678,6 +672,34 @@ def test_server_only():
|
||||
dataset.save_predictions(model_args["model_path"], results)
|
||||
|
||||
|
||||
def vis_server():
|
||||
def load_model(m, c):
|
||||
from keras.models import load_model
|
||||
clf = load_model(m, custom_objects=c)
|
||||
emdb = clf.layers[1]
|
||||
return emdb, clf
|
||||
|
||||
domain_raw, flow_raw, name_raw, hits_vt_raw, hits_trusted_raw, server_raw = dataset.load_or_generate_raw_h5data(
|
||||
args.data,
|
||||
args.data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
|
||||
results = dataset.load_predictions(args.clf_model)
|
||||
|
||||
visualize.plot_clf()
|
||||
visualize.plot_precision_recall(server_raw.flatten(), results["server_pred"].flatten(), "server")
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("results/server_model/windows_prc.pdf")
|
||||
visualize.plot_clf()
|
||||
visualize.plot_precision_recall(server_raw.flatten(), results["server_pred"].flatten(), "server")
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("results/server_model/windows_prc.pdf")
|
||||
visualize.plot_clf()
|
||||
visualize.plot_roc_curve(server_raw.flatten(), results["server_pred"].flatten(), "server")
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("results/server_model/windows_roc.pdf")
|
||||
|
||||
def main():
|
||||
if "train" == args.mode:
|
||||
main_train()
|
||||
|
Loading…
Reference in New Issue
Block a user