Add matplotlib Agg backend selection; update the beta visualization function (main_beta) according to test results
This commit is contained in:
parent
4fc2f0c925
commit
27f4d086eb
60
main.py
60
main.py
@ -610,19 +610,21 @@ def main_beta():
|
|||||||
domain_val, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
domain_val, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
path, model_prefix = os.path.split(os.path.normpath(args.model_path))
|
||||||
|
print(path, model_prefix)
|
||||||
try:
|
try:
|
||||||
results = joblib.load(f"{path}/curves.joblib")
|
curves = joblib.load(f"{path}/curves.joblib")
|
||||||
|
logger.info(f"load file {path}/curves.joblib successfully")
|
||||||
except Exception:
|
except Exception:
|
||||||
results = {}
|
curves = {}
|
||||||
results[model_prefix] = {"all": {}}
|
logger.info(f"currently {len(curves)} models in file: {curves.keys()}")
|
||||||
|
curves[model_prefix] = {"all": {}}
|
||||||
|
|
||||||
domains = domain_val.value.reshape(-1, 40)
|
domains = domain_val.value.reshape(-1, 40)
|
||||||
domains = np.apply_along_axis(lambda d: "".join(map(dataset.decode_char, d)), 1, domains)
|
domains = np.apply_along_axis(lambda d: dataset.decode_domain(d), 1, domains)
|
||||||
|
|
||||||
def load_df(path):
|
def load_df(res):
|
||||||
df_server = None
|
df_server = None
|
||||||
res = dataset.load_predictions(path)
|
|
||||||
data = {
|
data = {
|
||||||
"names": name_val, "client_pred": res["client_pred"].flatten(),
|
"names": name_val, "client_pred": res["client_pred"].flatten(),
|
||||||
"hits_vt": hits_vt, "hits_trusted": hits_trusted,
|
"hits_vt": hits_vt, "hits_trusted": hits_trusted,
|
||||||
@ -646,6 +648,9 @@ def main_beta():
|
|||||||
|
|
||||||
return res, df_server
|
return res, df_server
|
||||||
|
|
||||||
|
res = dataset.load_predictions(path)
|
||||||
|
model_keys = sorted(filter(lambda x: x.startswith("clf"), res.keys()), key=lambda x: int(x[4:-3]))
|
||||||
|
|
||||||
client_preds = []
|
client_preds = []
|
||||||
server_preds = []
|
server_preds = []
|
||||||
server_flow_preds = []
|
server_flow_preds = []
|
||||||
@ -653,8 +658,8 @@ def main_beta():
|
|||||||
server_user_preds = []
|
server_user_preds = []
|
||||||
server_domain_preds = []
|
server_domain_preds = []
|
||||||
server_domain_avg_preds = []
|
server_domain_avg_preds = []
|
||||||
for model_args in get_model_args(args):
|
for model_name in model_keys:
|
||||||
df, df_server = load_df(model_args["model_path"])
|
df, df_server = load_df(res[model_name])
|
||||||
client_preds.append(df.client_pred.as_matrix())
|
client_preds.append(df.client_pred.as_matrix())
|
||||||
if "server_val" in df.columns:
|
if "server_val" in df.columns:
|
||||||
server_preds.append(df.server_pred.as_matrix())
|
server_preds.append(df.server_pred.as_matrix())
|
||||||
@ -665,7 +670,7 @@ def main_beta():
|
|||||||
df_domain_avg = df_server.groupby(df_server.domain).rolling(10).mean()
|
df_domain_avg = df_server.groupby(df_server.domain).rolling(10).mean()
|
||||||
server_domain_avg_preds.append(df_domain_avg.server_pred.as_matrix())
|
server_domain_avg_preds.append(df_domain_avg.server_pred.as_matrix())
|
||||||
|
|
||||||
results[model_prefix][model_args["model_name"]] = confusion_matrix(df.client_val.as_matrix(),
|
curves[model_prefix][model_name] = confusion_matrix(df.client_val.as_matrix(),
|
||||||
df.client_pred.as_matrix().round())
|
df.client_pred.as_matrix().round())
|
||||||
df_user = df.groupby(df.names).max()
|
df_user = df.groupby(df.names).max()
|
||||||
client_user_preds.append(df_user.client_pred.as_matrix())
|
client_user_preds.append(df_user.client_pred.as_matrix())
|
||||||
@ -673,47 +678,48 @@ def main_beta():
|
|||||||
server_user_preds.append(df_user.server_pred.as_matrix())
|
server_user_preds.append(df_user.server_pred.as_matrix())
|
||||||
|
|
||||||
logger.info("plot client curves")
|
logger.info("plot client curves")
|
||||||
results[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds)
|
curves[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds)
|
||||||
results[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds)
|
curves[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds)
|
||||||
results[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(),
|
curves[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(),
|
||||||
client_user_preds)
|
client_user_preds)
|
||||||
results[model_prefix]["all"]["client_user_roc"] = visualize.calc_roc_mean(df_user.client_val.as_matrix(),
|
curves[model_prefix]["all"]["client_user_roc"] = visualize.calc_roc_mean(df_user.client_val.as_matrix(),
|
||||||
client_user_preds)
|
client_user_preds)
|
||||||
|
|
||||||
if "server_val" in df.columns:
|
if "server_val" in df.columns:
|
||||||
logger.info("plot server curves")
|
logger.info("plot server curves")
|
||||||
results[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(),
|
curves[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(),
|
||||||
server_preds)
|
server_preds)
|
||||||
results[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(),
|
curves[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(),
|
||||||
server_preds)
|
server_preds)
|
||||||
results[model_prefix]["all"]["server_user_prc"] = visualize.calc_pr_mean(df_user.server_val.as_matrix(),
|
curves[model_prefix]["all"]["server_user_prc"] = visualize.calc_pr_mean(df_user.server_val.as_matrix(),
|
||||||
server_user_preds)
|
server_user_preds)
|
||||||
|
|
||||||
results[model_prefix]["all"]["server_user_roc"] = visualize.calc_roc_mean(df_user.server_val.as_matrix(),
|
curves[model_prefix]["all"]["server_user_roc"] = visualize.calc_roc_mean(df_user.server_val.as_matrix(),
|
||||||
server_user_preds)
|
server_user_preds)
|
||||||
|
|
||||||
if df_server is not None:
|
if df_server is not None:
|
||||||
logger.info("plot server flow curves")
|
logger.info("plot server flow curves")
|
||||||
results[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(),
|
curves[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(),
|
||||||
server_flow_preds)
|
server_flow_preds)
|
||||||
results[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(),
|
curves[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(),
|
||||||
server_flow_preds)
|
server_flow_preds)
|
||||||
results[model_prefix]["all"]["server_domain_prc"] = visualize.calc_pr_mean(df_domain.server_val.as_matrix(),
|
curves[model_prefix]["all"]["server_domain_prc"] = visualize.calc_pr_mean(df_domain.server_val.as_matrix(),
|
||||||
server_domain_preds)
|
server_domain_preds)
|
||||||
results[model_prefix]["all"]["server_domain_roc"] = visualize.calc_roc_mean(df_domain.server_val.as_matrix(),
|
curves[model_prefix]["all"]["server_domain_roc"] = visualize.calc_roc_mean(df_domain.server_val.as_matrix(),
|
||||||
server_domain_preds)
|
server_domain_preds)
|
||||||
results[model_prefix]["all"]["server_domain_avg_prc"] = visualize.calc_pr_mean(
|
curves[model_prefix]["all"]["server_domain_avg_prc"] = visualize.calc_pr_mean(
|
||||||
df_domain_avg.server_val.as_matrix(),
|
df_domain_avg.server_val.as_matrix(),
|
||||||
server_domain_avg_preds)
|
server_domain_avg_preds)
|
||||||
results[model_prefix]["all"]["server_domain_avg_roc"] = visualize.calc_roc_mean(
|
curves[model_prefix]["all"]["server_domain_avg_roc"] = visualize.calc_roc_mean(
|
||||||
df_domain_avg.server_val.as_matrix(),
|
df_domain_avg.server_val.as_matrix(),
|
||||||
server_domain_avg_preds)
|
server_domain_avg_preds)
|
||||||
|
|
||||||
joblib.dump(results, f"{path}/curves.joblib")
|
joblib.dump(curves, f"{path}/curves.joblib")
|
||||||
|
|
||||||
# plot_overall_result()
|
|
||||||
|
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
|
|
||||||
|
matplotlib.use("agg")
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
|
|
||||||
|
matplotlib.use("agg")
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
Loading…
Reference in New Issue
Block a user