refactor visualization; start plotting server results
This commit is contained in:
parent
0b26c6125c
commit
e8473048cb
116
main.py
116
main.py
@ -342,7 +342,7 @@ def main_visualization():
|
|||||||
visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
|
visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
|
||||||
|
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save("{}/{}_{}.png".format(model_path, aggregation, curve))
|
visualize.plot_save("{}/{}_{}.pdf".format(model_path, aggregation, curve))
|
||||||
|
|
||||||
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
||||||
args.data,
|
args.data,
|
||||||
@ -388,17 +388,17 @@ def main_visualization():
|
|||||||
|
|
||||||
# absolute values
|
# absolute values
|
||||||
visualize.plot_confusion_matrix(df.client_val.as_matrix(), df.client_pred.as_matrix().round(),
|
visualize.plot_confusion_matrix(df.client_val.as_matrix(), df.client_pred.as_matrix().round(),
|
||||||
"{}/client_cov.png".format(args.model_path),
|
"{}/client_cov.pdf".format(args.model_path),
|
||||||
normalize=False, title="Client Confusion Matrix")
|
normalize=False, title="Client Confusion Matrix")
|
||||||
visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(),
|
visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(),
|
||||||
"{}/user_cov.png".format(args.model_path),
|
"{}/user_cov.pdf".format(args.model_path),
|
||||||
normalize=False, title="User Confusion Matrix")
|
normalize=False, title="User Confusion Matrix")
|
||||||
# normalized
|
# normalized
|
||||||
visualize.plot_confusion_matrix(df.client_val.as_matrix(), df.client_pred.as_matrix().round(),
|
visualize.plot_confusion_matrix(df.client_val.as_matrix(), df.client_pred.as_matrix().round(),
|
||||||
"{}/client_cov_norm.png".format(args.model_path),
|
"{}/client_cov_norm.pdf".format(args.model_path),
|
||||||
normalize=True, title="Client Confusion Matrix")
|
normalize=True, title="Client Confusion Matrix")
|
||||||
visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(),
|
visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(),
|
||||||
"{}/user_cov_norm.png".format(args.model_path),
|
"{}/user_cov_norm.pdf".format(args.model_path),
|
||||||
normalize=True, title="User Confusion Matrix")
|
normalize=True, title="User Confusion Matrix")
|
||||||
plot_embedding(args.model_path, results["domain_embds"], args.data, args.domain_length)
|
plot_embedding(args.model_path, results["domain_embds"], args.data, args.domain_length)
|
||||||
|
|
||||||
@ -406,7 +406,7 @@ def main_visualization():
|
|||||||
def plot_embedding(model_path, domain_embedding, data, domain_length):
|
def plot_embedding(model_path, domain_embedding, data, domain_length):
|
||||||
logger.info("visualize embedding")
|
logger.info("visualize embedding")
|
||||||
domain_encs, labels = dataset.load_or_generate_domains(data, domain_length)
|
domain_encs, labels = dataset.load_or_generate_domains(data, domain_length)
|
||||||
visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(model_path), method="svd")
|
visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.pdf".format(model_path), method="svd")
|
||||||
|
|
||||||
|
|
||||||
def main_visualize_all():
|
def main_visualize_all():
|
||||||
@ -423,6 +423,8 @@ def main_visualize_all():
|
|||||||
})
|
})
|
||||||
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
|
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
dfs = [(model_args["model_name"], load_df(model_args["model_path"])) for model_args in get_model_args(args)]
|
||||||
|
|
||||||
paul = dataset.load_predictions("results/paul/")
|
paul = dataset.load_predictions("results/paul/")
|
||||||
df_paul = pd.DataFrame(data={
|
df_paul = pd.DataFrame(data={
|
||||||
@ -430,46 +432,38 @@ def main_visualize_all():
|
|||||||
"hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten()
|
"hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten()
|
||||||
})
|
})
|
||||||
df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3)
|
df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3)
|
||||||
df_paul_user = df_paul.groupby(df_paul.names).max()
|
|
||||||
|
def vis(output_prefix, dfs, df_paul, aggregation, curve):
|
||||||
|
visualize.plot_clf()
|
||||||
|
if curve == "prc":
|
||||||
|
for model_name, df in dfs:
|
||||||
|
if aggregation == "user":
|
||||||
|
df = df.groupby(df.names).max()
|
||||||
|
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_name)
|
||||||
|
if aggregation == "user":
|
||||||
|
df_paul = df_paul.groupby(df_paul.names).max()
|
||||||
|
visualize.plot_precision_recall(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
|
||||||
|
elif curve == "roc":
|
||||||
|
for model_name, df in dfs:
|
||||||
|
if aggregation == "user":
|
||||||
|
df = df.groupby(df.names).max()
|
||||||
|
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_name)
|
||||||
|
if aggregation == "user":
|
||||||
|
df_paul = df_paul.groupby(df_paul.names).max()
|
||||||
|
visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
|
||||||
|
visualize.plot_legend()
|
||||||
|
visualize.plot_save("{}_{}_{}.pdf".format(output_prefix, aggregation, curve))
|
||||||
|
|
||||||
logger.info("plot pr curves")
|
logger.info("plot pr curves")
|
||||||
visualize.plot_clf()
|
vis(args.output_prefix, dfs, df_paul, "window", "prc")
|
||||||
for model_args in get_model_args(args):
|
|
||||||
df = load_df(model_args["model_path"])
|
|
||||||
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
|
|
||||||
visualize.plot_precision_recall(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
|
|
||||||
visualize.plot_legend()
|
|
||||||
visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")
|
|
||||||
|
|
||||||
logger.info("plot roc curves")
|
logger.info("plot roc curves")
|
||||||
visualize.plot_clf()
|
vis(args.output_prefix, dfs, df_paul, "window", "roc")
|
||||||
for model_args in get_model_args(args):
|
|
||||||
df = load_df(model_args["model_path"])
|
|
||||||
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
|
|
||||||
visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
|
|
||||||
visualize.plot_legend()
|
|
||||||
visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")
|
|
||||||
|
|
||||||
logger.info("plot user pr curves")
|
logger.info("plot user pr curves")
|
||||||
visualize.plot_clf()
|
vis(args.output_prefix, dfs, df_paul, "user", "prc")
|
||||||
for model_args in get_model_args(args):
|
|
||||||
df = load_df(model_args["model_path"])
|
|
||||||
df = df.groupby(df.names).max()
|
|
||||||
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
|
|
||||||
visualize.plot_precision_recall(df_paul_user.client_val.as_matrix(), df_paul_user.client_pred.as_matrix(), "paul")
|
|
||||||
visualize.plot_legend()
|
|
||||||
visualize.plot_save(f"{args.output_prefix}_user_client_prc.png")
|
|
||||||
|
|
||||||
logger.info("plot user roc curves")
|
logger.info("plot user roc curves")
|
||||||
visualize.plot_clf()
|
vis(args.output_prefix, dfs, df_paul, "user", "roc")
|
||||||
for model_args in get_model_args(args):
|
|
||||||
df = load_df(model_args["model_path"])
|
|
||||||
df = df.groupby(df.names).max()
|
|
||||||
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
|
|
||||||
visualize.plot_roc_curve(df_paul_user.client_val.as_matrix(), df_paul_user.client_pred.as_matrix(), "paul")
|
|
||||||
visualize.plot_legend()
|
|
||||||
visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")
|
|
||||||
|
|
||||||
|
|
||||||
import joblib
|
import joblib
|
||||||
|
|
||||||
@ -515,7 +509,7 @@ def main_beta():
|
|||||||
visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
|
visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
|
||||||
visualize.plot_pr_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
|
visualize.plot_pr_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_window_client_prc_all.png")
|
visualize.plot_save(f"{args.output_prefix}_window_client_prc_all.pdf")
|
||||||
|
|
||||||
logger.info("plot roc curves")
|
logger.info("plot roc curves")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
@ -529,7 +523,7 @@ def main_beta():
|
|||||||
visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
|
visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
|
||||||
visualize.plot_roc_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
|
visualize.plot_roc_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_window_client_roc_all.png")
|
visualize.plot_save(f"{args.output_prefix}_window_client_roc_all.pdf")
|
||||||
|
|
||||||
logger.info("plot user pr curves")
|
logger.info("plot user pr curves")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
@ -544,7 +538,7 @@ def main_beta():
|
|||||||
visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
|
visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
|
||||||
visualize.plot_pr_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
|
visualize.plot_pr_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_user_client_prc_all.png")
|
visualize.plot_save(f"{args.output_prefix}_user_client_prc_all.pdf")
|
||||||
|
|
||||||
logger.info("plot user roc curves")
|
logger.info("plot user roc curves")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
@ -557,7 +551,7 @@ def main_beta():
|
|||||||
visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
|
visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
|
||||||
visualize.plot_roc_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
|
visualize.plot_roc_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.png")
|
visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.pdf")
|
||||||
|
|
||||||
joblib.dump(results, f"{path}/curves.joblib")
|
joblib.dump(results, f"{path}/curves.joblib")
|
||||||
|
|
||||||
@ -589,13 +583,13 @@ def plot_overall_result():
|
|||||||
plt.ylim([0.0, 1.0])
|
plt.ylim([0.0, 1.0])
|
||||||
plt.xlim([0.0, 1.0])
|
plt.xlim([0.0, 1.0])
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{path}/{vis}_all.png")
|
visualize.plot_save(f"{path}/{vis}_all.pdf")
|
||||||
|
|
||||||
for cat, models in results.items():
|
for cat, models in results.items():
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_error_bars(models)
|
visualize.plot_error_bars(models)
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{path}/error_bars_{cat}.png")
|
visualize.plot_save(f"{path}/error_bars_{cat}.pdf")
|
||||||
|
|
||||||
|
|
||||||
def train_server_only():
|
def train_server_only():
|
||||||
@ -678,6 +672,34 @@ def test_server_only():
|
|||||||
dataset.save_predictions(model_args["model_path"], results)
|
dataset.save_predictions(model_args["model_path"], results)
|
||||||
|
|
||||||
|
|
||||||
|
def vis_server():
|
||||||
|
def load_model(m, c):
|
||||||
|
from keras.models import load_model
|
||||||
|
clf = load_model(m, custom_objects=c)
|
||||||
|
emdb = clf.layers[1]
|
||||||
|
return emdb, clf
|
||||||
|
|
||||||
|
domain_raw, flow_raw, name_raw, hits_vt_raw, hits_trusted_raw, server_raw = dataset.load_or_generate_raw_h5data(
|
||||||
|
args.data,
|
||||||
|
args.data,
|
||||||
|
args.domain_length,
|
||||||
|
args.window)
|
||||||
|
|
||||||
|
results = dataset.load_predictions(args.clf_model)
|
||||||
|
|
||||||
|
visualize.plot_clf()
|
||||||
|
visualize.plot_precision_recall(server_raw.flatten(), results["server_pred"].flatten(), "server")
|
||||||
|
visualize.plot_legend()
|
||||||
|
visualize.plot_save("results/server_model/windows_prc.pdf")
|
||||||
|
visualize.plot_clf()
|
||||||
|
visualize.plot_precision_recall(server_raw.flatten(), results["server_pred"].flatten(), "server")
|
||||||
|
visualize.plot_legend()
|
||||||
|
visualize.plot_save("results/server_model/windows_prc.pdf")
|
||||||
|
visualize.plot_clf()
|
||||||
|
visualize.plot_roc_curve(server_raw.flatten(), results["server_pred"].flatten(), "server")
|
||||||
|
visualize.plot_legend()
|
||||||
|
visualize.plot_save("results/server_model/windows_roc.pdf")
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if "train" == args.mode:
|
if "train" == args.mode:
|
||||||
main_train()
|
main_train()
|
||||||
|
Loading…
Reference in New Issue
Block a user