add matplotlib agg mode; update beta vis function according to test results

This commit is contained in:
René Knaebel 2017-11-10 11:04:27 +01:00
parent 4fc2f0c925
commit 27f4d086eb
2 changed files with 46 additions and 36 deletions

74
main.py
View File

@ -610,19 +610,21 @@ def main_beta():
domain_val, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data, domain_val, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
args.domain_length, args.domain_length,
args.window) args.window)
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix)) path, model_prefix = os.path.split(os.path.normpath(args.model_path))
print(path, model_prefix)
try: try:
results = joblib.load(f"{path}/curves.joblib") curves = joblib.load(f"{path}/curves.joblib")
logger.info(f"load file {path}/curves.joblib successfully")
except Exception: except Exception:
results = {} curves = {}
results[model_prefix] = {"all": {}} logger.info(f"currently {len(curves)} models in file: {curves.keys()}")
curves[model_prefix] = {"all": {}}
domains = domain_val.value.reshape(-1, 40) domains = domain_val.value.reshape(-1, 40)
domains = np.apply_along_axis(lambda d: "".join(map(dataset.decode_char, d)), 1, domains) domains = np.apply_along_axis(lambda d: dataset.decode_domain(d), 1, domains)
def load_df(path): def load_df(res):
df_server = None df_server = None
res = dataset.load_predictions(path)
data = { data = {
"names": name_val, "client_pred": res["client_pred"].flatten(), "names": name_val, "client_pred": res["client_pred"].flatten(),
"hits_vt": hits_vt, "hits_trusted": hits_trusted, "hits_vt": hits_vt, "hits_trusted": hits_trusted,
@ -646,6 +648,9 @@ def main_beta():
return res, df_server return res, df_server
res = dataset.load_predictions(path)
model_keys = sorted(filter(lambda x: x.startswith("clf"), res.keys()), key=lambda x: int(x[4:-3]))
client_preds = [] client_preds = []
server_preds = [] server_preds = []
server_flow_preds = [] server_flow_preds = []
@ -653,8 +658,8 @@ def main_beta():
server_user_preds = [] server_user_preds = []
server_domain_preds = [] server_domain_preds = []
server_domain_avg_preds = [] server_domain_avg_preds = []
for model_args in get_model_args(args): for model_name in model_keys:
df, df_server = load_df(model_args["model_path"]) df, df_server = load_df(res[model_name])
client_preds.append(df.client_pred.as_matrix()) client_preds.append(df.client_pred.as_matrix())
if "server_val" in df.columns: if "server_val" in df.columns:
server_preds.append(df.server_pred.as_matrix()) server_preds.append(df.server_pred.as_matrix())
@ -665,55 +670,56 @@ def main_beta():
df_domain_avg = df_server.groupby(df_server.domain).rolling(10).mean() df_domain_avg = df_server.groupby(df_server.domain).rolling(10).mean()
server_domain_avg_preds.append(df_domain_avg.server_pred.as_matrix()) server_domain_avg_preds.append(df_domain_avg.server_pred.as_matrix())
results[model_prefix][model_args["model_name"]] = confusion_matrix(df.client_val.as_matrix(), curves[model_prefix][model_name] = confusion_matrix(df.client_val.as_matrix(),
df.client_pred.as_matrix().round()) df.client_pred.as_matrix().round())
df_user = df.groupby(df.names).max() df_user = df.groupby(df.names).max()
client_user_preds.append(df_user.client_pred.as_matrix()) client_user_preds.append(df_user.client_pred.as_matrix())
if "server_val" in df.columns: if "server_val" in df.columns:
server_user_preds.append(df_user.server_pred.as_matrix()) server_user_preds.append(df_user.server_pred.as_matrix())
logger.info("plot client curves") logger.info("plot client curves")
results[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds) curves[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds)
results[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds) curves[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds)
results[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(), curves[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(),
client_user_preds)
curves[model_prefix]["all"]["client_user_roc"] = visualize.calc_roc_mean(df_user.client_val.as_matrix(),
client_user_preds) client_user_preds)
results[model_prefix]["all"]["client_user_roc"] = visualize.calc_roc_mean(df_user.client_val.as_matrix(),
client_user_preds)
if "server_val" in df.columns: if "server_val" in df.columns:
logger.info("plot server curves") logger.info("plot server curves")
results[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(), curves[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(),
server_preds)
curves[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(),
server_preds) server_preds)
results[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(), curves[model_prefix]["all"]["server_user_prc"] = visualize.calc_pr_mean(df_user.server_val.as_matrix(),
server_preds) server_user_preds)
results[model_prefix]["all"]["server_user_prc"] = visualize.calc_pr_mean(df_user.server_val.as_matrix(),
server_user_preds)
results[model_prefix]["all"]["server_user_roc"] = visualize.calc_roc_mean(df_user.server_val.as_matrix(), curves[model_prefix]["all"]["server_user_roc"] = visualize.calc_roc_mean(df_user.server_val.as_matrix(),
server_user_preds) server_user_preds)
if df_server is not None: if df_server is not None:
logger.info("plot server flow curves") logger.info("plot server flow curves")
results[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(), curves[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(),
server_flow_preds)
curves[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(),
server_flow_preds) server_flow_preds)
results[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(), curves[model_prefix]["all"]["server_domain_prc"] = visualize.calc_pr_mean(df_domain.server_val.as_matrix(),
server_flow_preds) server_domain_preds)
results[model_prefix]["all"]["server_domain_prc"] = visualize.calc_pr_mean(df_domain.server_val.as_matrix(), curves[model_prefix]["all"]["server_domain_roc"] = visualize.calc_roc_mean(df_domain.server_val.as_matrix(),
server_domain_preds) server_domain_preds)
results[model_prefix]["all"]["server_domain_roc"] = visualize.calc_roc_mean(df_domain.server_val.as_matrix(), curves[model_prefix]["all"]["server_domain_avg_prc"] = visualize.calc_pr_mean(
server_domain_preds)
results[model_prefix]["all"]["server_domain_avg_prc"] = visualize.calc_pr_mean(
df_domain_avg.server_val.as_matrix(), df_domain_avg.server_val.as_matrix(),
server_domain_avg_preds) server_domain_avg_preds)
results[model_prefix]["all"]["server_domain_avg_roc"] = visualize.calc_roc_mean( curves[model_prefix]["all"]["server_domain_avg_roc"] = visualize.calc_roc_mean(
df_domain_avg.server_val.as_matrix(), df_domain_avg.server_val.as_matrix(),
server_domain_avg_preds) server_domain_avg_preds)
joblib.dump(results, f"{path}/curves.joblib") joblib.dump(curves, f"{path}/curves.joblib")
# plot_overall_result()
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt

View File

@ -1,6 +1,10 @@
import os import os
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import seaborn as sns import seaborn as sns