refactor visualization; start plotting server results

This commit is contained in:
René Knaebel 2017-10-08 11:52:10 +02:00
parent 0b26c6125c
commit e8473048cb

116
main.py
View File

@ -342,7 +342,7 @@ def main_visualization():
visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
visualize.plot_legend()
visualize.plot_save("{}/{}_{}.png".format(model_path, aggregation, curve))
visualize.plot_save("{}/{}_{}.pdf".format(model_path, aggregation, curve))
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
args.data,
@ -388,17 +388,17 @@ def main_visualization():
# absolute values
visualize.plot_confusion_matrix(df.client_val.as_matrix(), df.client_pred.as_matrix().round(),
"{}/client_cov.png".format(args.model_path),
"{}/client_cov.pdf".format(args.model_path),
normalize=False, title="Client Confusion Matrix")
visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(),
"{}/user_cov.png".format(args.model_path),
"{}/user_cov.pdf".format(args.model_path),
normalize=False, title="User Confusion Matrix")
# normalized
visualize.plot_confusion_matrix(df.client_val.as_matrix(), df.client_pred.as_matrix().round(),
"{}/client_cov_norm.png".format(args.model_path),
"{}/client_cov_norm.pdf".format(args.model_path),
normalize=True, title="Client Confusion Matrix")
visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(),
"{}/user_cov_norm.png".format(args.model_path),
"{}/user_cov_norm.pdf".format(args.model_path),
normalize=True, title="User Confusion Matrix")
plot_embedding(args.model_path, results["domain_embds"], args.data, args.domain_length)
@ -406,7 +406,7 @@ def main_visualization():
def plot_embedding(model_path, domain_embedding, data, domain_length):
logger.info("visualize embedding")
domain_encs, labels = dataset.load_or_generate_domains(data, domain_length)
visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(model_path), method="svd")
visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.pdf".format(model_path), method="svd")
def main_visualize_all():
@ -423,6 +423,8 @@ def main_visualize_all():
})
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
return res
dfs = [(model_args["model_name"], load_df(model_args["model_path"])) for model_args in get_model_args(args)]
paul = dataset.load_predictions("results/paul/")
df_paul = pd.DataFrame(data={
@ -430,46 +432,38 @@ def main_visualize_all():
"hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten()
})
df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3)
df_paul_user = df_paul.groupby(df_paul.names).max()
def vis(output_prefix, dfs, df_paul, aggregation, curve):
visualize.plot_clf()
if curve == "prc":
for model_name, df in dfs:
if aggregation == "user":
df = df.groupby(df.names).max()
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_name)
if aggregation == "user":
df_paul = df_paul.groupby(df_paul.names).max()
visualize.plot_precision_recall(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
elif curve == "roc":
for model_name, df in dfs:
if aggregation == "user":
df = df.groupby(df.names).max()
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_name)
if aggregation == "user":
df_paul = df_paul.groupby(df_paul.names).max()
visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
visualize.plot_legend()
visualize.plot_save("{}_{}_{}.pdf".format(output_prefix, aggregation, curve))
logger.info("plot pr curves")
visualize.plot_clf()
for model_args in get_model_args(args):
df = load_df(model_args["model_path"])
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
visualize.plot_precision_recall(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")
vis(args.output_prefix, dfs, df_paul, "window", "prc")
logger.info("plot roc curves")
visualize.plot_clf()
for model_args in get_model_args(args):
df = load_df(model_args["model_path"])
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")
vis(args.output_prefix, dfs, df_paul, "window", "roc")
logger.info("plot user pr curves")
visualize.plot_clf()
for model_args in get_model_args(args):
df = load_df(model_args["model_path"])
df = df.groupby(df.names).max()
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
visualize.plot_precision_recall(df_paul_user.client_val.as_matrix(), df_paul_user.client_pred.as_matrix(), "paul")
visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_user_client_prc.png")
vis(args.output_prefix, dfs, df_paul, "user", "prc")
logger.info("plot user roc curves")
visualize.plot_clf()
for model_args in get_model_args(args):
df = load_df(model_args["model_path"])
df = df.groupby(df.names).max()
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
visualize.plot_roc_curve(df_paul_user.client_val.as_matrix(), df_paul_user.client_pred.as_matrix(), "paul")
visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")
vis(args.output_prefix, dfs, df_paul, "user", "roc")
import joblib
@ -515,7 +509,7 @@ def main_beta():
visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
visualize.plot_pr_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_window_client_prc_all.png")
visualize.plot_save(f"{args.output_prefix}_window_client_prc_all.pdf")
logger.info("plot roc curves")
visualize.plot_clf()
@ -529,7 +523,7 @@ def main_beta():
visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
visualize.plot_roc_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_window_client_roc_all.png")
visualize.plot_save(f"{args.output_prefix}_window_client_roc_all.pdf")
logger.info("plot user pr curves")
visualize.plot_clf()
@ -544,7 +538,7 @@ def main_beta():
visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
visualize.plot_pr_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_user_client_prc_all.png")
visualize.plot_save(f"{args.output_prefix}_user_client_prc_all.pdf")
logger.info("plot user roc curves")
visualize.plot_clf()
@ -557,7 +551,7 @@ def main_beta():
visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
visualize.plot_roc_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.png")
visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.pdf")
joblib.dump(results, f"{path}/curves.joblib")
@ -589,13 +583,13 @@ def plot_overall_result():
plt.ylim([0.0, 1.0])
plt.xlim([0.0, 1.0])
visualize.plot_legend()
visualize.plot_save(f"{path}/{vis}_all.png")
visualize.plot_save(f"{path}/{vis}_all.pdf")
for cat, models in results.items():
visualize.plot_clf()
visualize.plot_error_bars(models)
visualize.plot_legend()
visualize.plot_save(f"{path}/error_bars_{cat}.png")
visualize.plot_save(f"{path}/error_bars_{cat}.pdf")
def train_server_only():
@ -678,6 +672,34 @@ def test_server_only():
dataset.save_predictions(model_args["model_path"], results)
def vis_server():
def load_model(m, c):
from keras.models import load_model
clf = load_model(m, custom_objects=c)
emdb = clf.layers[1]
return emdb, clf
domain_raw, flow_raw, name_raw, hits_vt_raw, hits_trusted_raw, server_raw = dataset.load_or_generate_raw_h5data(
args.data,
args.data,
args.domain_length,
args.window)
results = dataset.load_predictions(args.clf_model)
visualize.plot_clf()
visualize.plot_precision_recall(server_raw.flatten(), results["server_pred"].flatten(), "server")
visualize.plot_legend()
visualize.plot_save("results/server_model/windows_prc.pdf")
visualize.plot_clf()
visualize.plot_precision_recall(server_raw.flatten(), results["server_pred"].flatten(), "server")
visualize.plot_legend()
visualize.plot_save("results/server_model/windows_prc.pdf")
visualize.plot_clf()
visualize.plot_roc_curve(server_raw.flatten(), results["server_pred"].flatten(), "server")
visualize.plot_legend()
visualize.plot_save("results/server_model/windows_roc.pdf")
def main():
if "train" == args.mode:
main_train()