add first version of model averaging visualization

parent 49ad506a96
commit b157ca6a19
@@ -113,6 +113,7 @@ def get_model_args(args):
 
 def parse():
     args = parser.parse_args()
+    args.result_path = os.path.split(os.path.normpath(args.output_prefix))[1]
     args.model_name = os.path.split(os.path.normpath(args.model_path))[1]
     args.embedding_model = os.path.join(args.model_path, "embd.h5")
     args.clf_model = os.path.join(args.model_path, "clf.h5")

main.py | 107

@@ -408,6 +408,111 @@ def main_visualize_all():
     visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")
 
 
+import joblib
+
+
+def main_beta():
+    _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
+                                                                                             args.test_data,
+                                                                                             args.domain_length,
+                                                                                             args.window)
+    path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
+    try:
+        results = joblib.load(f"{path}/curves.joblib")
+    except Exception:
+        results = {}
+    results[model_prefix] = {}
+
+    def load_df(path):
+        res = dataset.load_predictions(path)
+        res = pd.DataFrame(data={
+            "names": name_val, "client_pred": res["client_pred"].flatten(),
+            "hits_vt": hits_vt, "hits_trusted": hits_trusted
+        })
+        res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
+        return res
+
+    paul = dataset.load_predictions("results/paul/")
+    df_paul = pd.DataFrame(data={
+        "names": paul["testNames"].flatten(), "client_pred": paul["testScores"].flatten(),
+        "hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten()
+    })
+    df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3)
+    df_paul_user = df_paul.groupby(df_paul.names).max()
+
+    logger.info("plot pr curves")
+    visualize.plot_clf()
+    predictions = []
+    for model_args in get_model_args(args):
+        df = load_df(model_args["model_path"])
+        predictions.append(df.client_pred.as_matrix())
+    results[model_prefix]["window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), predictions)
+    visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
+    visualize.plot_pr_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
+    visualize.plot_legend()
+    visualize.plot_save(f"{args.output_prefix}_window_client_prc_all.png")
+
+    logger.info("plot roc curves")
+    visualize.plot_clf()
+    predictions = []
+    for model_args in get_model_args(args):
+        df = load_df(model_args["model_path"])
+        predictions.append(df.client_pred.as_matrix())
+    results[model_prefix]["window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), predictions)
+    visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
+    visualize.plot_roc_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
+    visualize.plot_legend()
+    visualize.plot_save(f"{args.output_prefix}_window_client_roc_all.png")
+
+    logger.info("plot user pr curves")
+    visualize.plot_clf()
+    predictions = []
+    for model_args in get_model_args(args):
+        df = load_df(model_args["model_path"])
+        df = df.groupby(df.names).max()
+        predictions.append(df.client_pred.as_matrix())
+    results[model_prefix]["user_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), predictions)
+    visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
+    visualize.plot_pr_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
+    visualize.plot_legend()
+    visualize.plot_save(f"{args.output_prefix}_user_client_prc_all.png")
+
+    logger.info("plot user roc curves")
+    visualize.plot_clf()
+    predictions = []
+    for model_args in get_model_args(args):
+        df = load_df(model_args["model_path"])
+        df = df.groupby(df.names).max()
+        predictions.append(df.client_pred.as_matrix())
+    results[model_prefix]["user_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), predictions)
+    visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
+    visualize.plot_roc_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
+    visualize.plot_legend()
+    visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.png")
+
+    joblib.dump(results, f"{path}/curves.joblib")
+
+    import matplotlib.pyplot as plt
+    x = np.linspace(0, 1, 10000)
+    for vis in ["window_prc", "window_roc", "user_prc", "user_roc"]:
+        logger.info(f"plot {vis}")
+        visualize.plot_clf()
+        for model_key in results.keys():
+            ys_mean, ys_std, score = results[model_key][vis]
+            plt.plot(x, ys_mean, label=f"{model_key} - {score:5.4}")
+            plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
+        if vis.endswith("prc"):
+            plt.xlabel('Recall')
+            plt.ylabel('Precision')
+        else:
+            plt.xlabel('False Positive Rate')
+            plt.ylabel('True Positive Rate')
+        plt.ylim([0.0, 1.0])
+        plt.xlim([0.0, 1.0])
+        visualize.plot_legend()
+        visualize.plot_save(f"{path}/{vis}_all.png")
+
+
 def main():
     if "train" == args.mode:
         main_train()
@@ -423,6 +528,8 @@ def main():
         plot_embedding()
     if "paul" == args.mode:
         main_paul_best()
+    if "beta" == args.mode:
+        main_beta()
 
 
 if __name__ == "__main__":
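
Aside (not part of the commit): a minimal sketch of how the curves.joblib dump written at the end of main_beta() could be consumed later. The file location, the "user_prc" key, and the output filename are illustrative assumptions taken from the diff above; each stored value is assumed to be the (mean curve, std band, mean score) tuple returned by calc_pr_mean/calc_roc_mean, sampled on the same np.linspace(0, 1, 10000) grid.

import joblib
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical location of the dump written at the end of main_beta().
curves = joblib.load("results/curves.joblib")
# Grid assumed to match the one used by calc_pr_mean / calc_roc_mean.
x = np.linspace(0, 1, 10000)

plt.clf()
for model_key, entry in curves.items():
    ys_mean, ys_std, score = entry["user_prc"]  # (mean curve, std band, mean F1 score)
    plt.plot(x, ys_mean, label=f"{model_key} - {score:5.4}")
    plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color="grey", alpha=0.1)
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.legend()
plt.savefig("user_prc_comparison.png")  # hypothetical output name
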
run.sh | 2

@@ -5,7 +5,7 @@ RESDIR=$1
 mkdir -p /tmp/rk/${RESDIR}
 DATADIR=$2
 
-EPOCHS=100
+EPOCHS=10
 
 for output in client both
 do
@@ -10,7 +10,7 @@ RESDIR=$6
 mkdir -p /tmp/rk/${RESDIR}
 DATADIR=$7
 
-EPOCHS=100
+EPOCHS=10
 
 for ((i = ${N1}; i <= ${N2}; i++))
 do
@@ -25,5 +25,6 @@ do
         --batch 128 \
         --model_output ${OUTPUT} \
         --type ${TYPE} \
-        --depth ${DEPTH}
+        --depth ${DEPTH} \
+        --gpu
 done
run_model_rene.sh | 30 (new file)

@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+
+
+N1=$1
+N2=$2
+OUTPUT=$3
+DEPTH=$4
+TYPE=$5
+RESDIR=$6
+mkdir -p /tmp/rk/${RESDIR}
+DATADIR=$7
+
+EPOCHS=10
+
+for ((i = ${N1}; i <= ${N2}; i++))
+do
+    python main.py --mode train \
+        --train ${DATADIR} \
+        --model ${RESDIR}/${OUTPUT}_${TYPE}_${i} \
+        --epochs ${EPOCHS} \
+        --embd 64 \
+        --filter_embd 128 --kernel_embd 5 --dense_embd 64 \
+        --domain_embd 16 \
+        --filter_main 32 --kernel_main 5 --dense_main 256 \
+        --batch 128 \
+        --model_output ${OUTPUT} \
+        --type ${TYPE} \
+        --depth ${DEPTH} \
+        --gpu
+done
visualize.py | 59

@@ -2,6 +2,7 @@ import os
 
 import matplotlib.pyplot as plt
 import numpy as np
+from scipy import interpolate
 from sklearn.decomposition import TruncatedSVD
 from sklearn.manifold import TSNE
 from sklearn.metrics import (
@@ -65,13 +66,32 @@ def plot_precision_recall(y, y_pred, label=""):
     plt.xlim([0.0, 1.0])
 
 
-def plot_pr_curves(y, y_preds, label=""):
-    for idx, y in enumerate(y_preds):
-        y = y.flatten()
+def calc_pr_mean(y, y_preds):
+    appr = []
+    scores = []
+    y = y.flatten()
+
+    for idx, y_pred in enumerate(y_preds):
         y_pred = y_pred.flatten()
         precision, recall, thresholds = precision_recall_curve(y, y_pred)
-        score = fbeta_score(y, y_pred.round(), 1)
-        plt.plot(recall, precision, '--', label=f"{idx}{label} - {score:5.4}")
+        appr.append(interpolate.interp1d(recall, precision))
+        scores.append(fbeta_score(y, y_pred.round(), 1))
+    x = np.linspace(0, 1, 10000)
+    ys = np.vstack([f(x) for f in appr])
+    ys_mean = ys.mean(axis=0)
+    ys_std = ys.std(axis=0)
+    scores_mean = np.mean(scores)
+    return ys_mean, ys_std, scores_mean
+
+
+def plot_pr_mean(y, y_preds, label=""):
+    x = np.linspace(0, 1, 10000)
+    ys_mean, ys_std, score = calc_pr_mean(y, y_preds)
+
+    plt.plot(x, ys_mean, label=f"{label} - {score:5.4}")
+    plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
+    plt.ylim([0.0, 1.0])
+    plt.xlim([0.0, 1.0])
     plt.xlabel('Recall')
     plt.ylabel('Precision')
 
@@ -102,6 +122,37 @@ def plot_roc_curve(mask, prediction, label=""):
     plt.ylabel('True Positive Rate')
 
 
+def calc_roc_mean(y, y_preds):
+    appr = []
+    aucs = []
+    y = y.flatten()
+
+    for idx, y_pred in enumerate(y_preds):
+        y_pred = y_pred.flatten()
+        fpr, tpr, thresholds = roc_curve(y, y_pred)
+        appr.append(interpolate.interp1d(fpr, tpr))
+        aucs.append(auc(fpr, tpr))
+    x = np.linspace(0, 1, 10000)
+    ys = np.vstack([f(x) for f in appr])
+    ys_mean = ys.mean(axis=0)
+    ys_std = ys.std(axis=0)
+    auc_mean = np.mean(aucs)
+    return ys_mean, ys_std, auc_mean
+
+
+def plot_roc_mean(y, y_preds, label=""):
+    x = np.linspace(0, 1, 10000)
+    ys_mean, ys_std, auc_mean = calc_roc_mean(y, y_preds)
+    plt.xscale('log')
+    plt.ylim([0.0, 1.0])
+    plt.xlim([0.0, 1.0])
+
+    plt.plot(x, ys_mean, label=f"{label} - {auc_mean:5.4}")
+    plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
+    plt.xlabel('False Positive Rate')
+    plt.ylabel('True Positive Rate')
+
+
 def plot_confusion_matrix(y_true, y_pred, path,
                           normalize=False,
                           classes=("benign", "malicious"),
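
Aside (not part of the commit): a self-contained sketch of the curve-averaging idea behind calc_pr_mean and calc_roc_mean. Each per-model ROC curve is resampled onto a shared FPR grid with scipy's interp1d, and the resampled curves are averaged pointwise into a mean curve with a std band. All data and names below are synthetic and purely illustrative.

import numpy as np
from scipy import interpolate
from sklearn.metrics import auc, roc_curve

rng = np.random.RandomState(0)
y = rng.randint(0, 2, size=1000)                      # toy binary ground truth
runs = [y + rng.normal(scale=0.8, size=y.shape)       # noisy scores standing in for different models
        for _ in range(5)]

x = np.linspace(0, 1, 10000)                          # shared FPR grid (same grid as calc_roc_mean)
curves, aucs = [], []
for scores in runs:
    fpr, tpr, _ = roc_curve(y, scores)
    curves.append(interpolate.interp1d(fpr, tpr)(x))  # resample this run's TPR onto the grid
    aucs.append(auc(fpr, tpr))

ys = np.vstack(curves)
ys_mean, ys_std = ys.mean(axis=0), ys.std(axis=0)     # mean curve and pointwise std band
i = np.searchsorted(x, 0.1)
print(f"mean AUC: {np.mean(aucs):.4f}")
print(f"TPR at FPR=0.1: {ys_mean[i]:.4f} +/- {ys_std[i]:.4f}")

Resampling onto one common grid is what makes the pointwise mean and the grey std band in plot_pr_mean/plot_roc_mean well defined, since the raw curves from different models have different numbers of thresholds.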