add TSNE embedding; server evaluation visualization
This commit is contained in:
parent
a860f0da34
commit
8b17bd0701
50
dataset.py
50
dataset.py
@ -9,25 +9,28 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
logger = logging.getLogger('logger')
|
logger = logging.getLogger('cisco_logger')
|
||||||
|
|
||||||
chars = dict((char, idx + 1) for (idx, char) in
|
char2idx = dict((char, idx + 1) for (idx, char) in
|
||||||
enumerate(string.ascii_lowercase + string.punctuation + string.digits))
|
enumerate(string.ascii_lowercase + string.punctuation + string.digits))
|
||||||
|
|
||||||
|
idx2char = {v: k for k, v in char2idx.items()}
|
||||||
|
|
||||||
|
|
||||||
def get_character_dict():
|
def get_character_dict():
|
||||||
return chars
|
return char2idx
|
||||||
|
|
||||||
|
|
||||||
def get_vocab_size():
|
def get_vocab_size():
|
||||||
return len(chars) + 1
|
return len(char2idx) + 1
|
||||||
|
|
||||||
|
|
||||||
def encode_char(c):
|
def encode_char(c):
|
||||||
if c in chars:
|
return char2idx.get(c, 0)
|
||||||
return chars[c]
|
|
||||||
else:
|
|
||||||
return 0
|
def decode_char(i):
|
||||||
|
return idx2char.get(i, "")
|
||||||
|
|
||||||
|
|
||||||
encode_char = np.vectorize(encode_char)
|
encode_char = np.vectorize(encode_char)
|
||||||
@ -84,11 +87,12 @@ def get_user_chunks(user_flow, window=10):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def get_domain_features(domain, vocab: dict, max_length=40):
|
# TODO: DATA CORRUPTION; reverse, 0! to n
|
||||||
|
def get_domain_features(domain, max_length=40):
|
||||||
encoding = np.zeros((max_length,))
|
encoding = np.zeros((max_length,))
|
||||||
for j in range(min(len(domain), max_length)):
|
for j in range(min(len(domain), max_length)):
|
||||||
char = domain[-j] # TODO: why -j -> order reversed for domain url?
|
char = domain[-j] # TODO: why -j -> order reversed for domain url?
|
||||||
encoding[j] = vocab.get(char, 0)
|
encoding[j] = encode_char(char)
|
||||||
return encoding
|
return encoding
|
||||||
|
|
||||||
|
|
||||||
@ -99,13 +103,6 @@ def get_all_flow_features(features):
|
|||||||
return np.log1p(flows)
|
return np.log1p(flows)
|
||||||
|
|
||||||
|
|
||||||
def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
|
|
||||||
domain, flow, name, hits, trusted_hits, server = create_raw_dataset_from_flows(user_flow_df, char_dict,
|
|
||||||
max_len, window_size)
|
|
||||||
domain, flow, name, client, server = filter_window_dataset_by_hits(domain, flow, name, hits, trusted_hits, server)
|
|
||||||
return domain, flow, name, client, server
|
|
||||||
|
|
||||||
|
|
||||||
def filter_window_dataset_by_hits(domain, flow, name, hits, trusted_hits, server):
|
def filter_window_dataset_by_hits(domain, flow, name, hits, trusted_hits, server):
|
||||||
# select only 1.0 and 0.0 from training data
|
# select only 1.0 and 0.0 from training data
|
||||||
pos_idx = np.where(np.logical_or(hits == 1.0, trusted_hits >= 1.0))[0]
|
pos_idx = np.where(np.logical_or(hits == 1.0, trusted_hits >= 1.0))[0]
|
||||||
@ -122,7 +119,7 @@ def filter_window_dataset_by_hits(domain, flow, name, hits, trusted_hits, server
|
|||||||
return domain, flow, name, client, server
|
return domain, flow, name, client, server
|
||||||
|
|
||||||
|
|
||||||
def create_raw_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
|
def create_raw_dataset_from_flows(user_flow_df, max_len, window_size=10):
|
||||||
logger.info("get chunks from user data frames")
|
logger.info("get chunks from user data frames")
|
||||||
with Pool() as pool:
|
with Pool() as pool:
|
||||||
results = []
|
results = []
|
||||||
@ -131,7 +128,6 @@ def create_raw_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=
|
|||||||
windows = [window for res in results for window in res.get()]
|
windows = [window for res in results for window in res.get()]
|
||||||
logger.info("create training dataset")
|
logger.info("create training dataset")
|
||||||
domain, flow, hits, name, server, trusted_hits = create_dataset_from_windows(chunks=windows,
|
domain, flow, hits, name, server, trusted_hits = create_dataset_from_windows(chunks=windows,
|
||||||
vocab=char_dict,
|
|
||||||
max_len=max_len)
|
max_len=max_len)
|
||||||
# make client labels discrete with 4 different values
|
# make client labels discrete with 4 different values
|
||||||
hits = np.apply_along_axis(lambda x: make_label_discrete(x, 3), 0, np.atleast_2d(hits))
|
hits = np.apply_along_axis(lambda x: make_label_discrete(x, 3), 0, np.atleast_2d(hits))
|
||||||
@ -158,7 +154,7 @@ def load_h5dataset(path):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def create_dataset_from_windows(chunks, vocab, max_len):
|
def create_dataset_from_windows(chunks, max_len):
|
||||||
"""
|
"""
|
||||||
combines domain and feature windows to sequential training data
|
combines domain and feature windows to sequential training data
|
||||||
:param chunks: list of flow feature windows
|
:param chunks: list of flow feature windows
|
||||||
@ -168,7 +164,7 @@ def create_dataset_from_windows(chunks, vocab, max_len):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def get_domain_features_reduced(d):
|
def get_domain_features_reduced(d):
|
||||||
return get_domain_features(d[0], vocab, max_len)
|
return get_domain_features(d[0], max_len)
|
||||||
|
|
||||||
logger.info(" compute domain features")
|
logger.info(" compute domain features")
|
||||||
domain_features = []
|
domain_features = []
|
||||||
@ -257,7 +253,6 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
|||||||
|
|
||||||
def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size):
|
def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size):
|
||||||
h5data = h5data + "_raw"
|
h5data = h5data + "_raw"
|
||||||
char_dict = get_character_dict()
|
|
||||||
logger.info(f"check for h5data {h5data}")
|
logger.info(f"check for h5data {h5data}")
|
||||||
try:
|
try:
|
||||||
check_h5dataset(h5data)
|
check_h5dataset(h5data)
|
||||||
@ -265,8 +260,8 @@ def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size):
|
|||||||
logger.info("h5 data not found - load csv file")
|
logger.info("h5 data not found - load csv file")
|
||||||
user_flow_df = get_user_flow_data(train_data)
|
user_flow_df = get_user_flow_data(train_data)
|
||||||
logger.info("create raw training dataset")
|
logger.info("create raw training dataset")
|
||||||
domain, flow, name, hits, trusted_hits, server = create_raw_dataset_from_flows(user_flow_df, char_dict,
|
domain, flow, name, hits, trusted_hits, server = create_raw_dataset_from_flows(user_flow_df, domain_length,
|
||||||
domain_length, window_size)
|
window_size)
|
||||||
logger.info("store raw training dataset as h5 file")
|
logger.info("store raw training dataset as h5 file")
|
||||||
data = {
|
data = {
|
||||||
"domain": domain.astype(np.int8),
|
"domain": domain.astype(np.int8),
|
||||||
@ -298,7 +293,6 @@ def generate_names(train_data, window_size):
|
|||||||
|
|
||||||
def load_or_generate_domains(train_data, domain_length):
|
def load_or_generate_domains(train_data, domain_length):
|
||||||
fn = f"{train_data}_domains.gz"
|
fn = f"{train_data}_domains.gz"
|
||||||
char_dict = get_character_dict()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_flow_df = pd.read_csv(fn)
|
user_flow_df = pd.read_csv(fn)
|
||||||
@ -317,10 +311,10 @@ def load_or_generate_domains(train_data, domain_length):
|
|||||||
|
|
||||||
user_flow_df.to_csv(fn, compression="gzip")
|
user_flow_df.to_csv(fn, compression="gzip")
|
||||||
|
|
||||||
domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, char_dict, domain_length))
|
domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, domain_length))
|
||||||
domain_encs = np.stack(domain_encs)
|
domain_encs = np.stack(domain_encs)
|
||||||
|
|
||||||
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix().astype(bool)
|
return domain_encs, user_flow_df[["clientLabel", "serverLabel"]].as_matrix().astype(bool)
|
||||||
|
|
||||||
|
|
||||||
def save_predictions(path, results):
|
def save_predictions(path, results):
|
||||||
|
41
fancy.sh
41
fancy.sh
@ -5,25 +5,26 @@ N2=$2
|
|||||||
RESDIR=$3
|
RESDIR=$3
|
||||||
DATADIR=$4
|
DATADIR=$4
|
||||||
|
|
||||||
for ((i = ${N1}; i <= ${N2}; i++))
|
#for ((i = ${N1}; i <= ${N2}; i++))
|
||||||
do
|
#do
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_final_${i} --data ${DATADIR} --model_output client
|
# python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_final_${i} --data ${DATADIR} --model_output client
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_final_${i} --data ${DATADIR} --model_output both
|
# python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_final_${i} --data ${DATADIR} --model_output both
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_inter_${i} --data ${DATADIR} --model_output both
|
# python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_inter_${i} --data ${DATADIR} --model_output both
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_staggered_${i} --data ${DATADIR} --model_output both
|
# python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_staggered_${i} --data ${DATADIR} --model_output both
|
||||||
done
|
#done
|
||||||
|
#
|
||||||
|
#python3 main.py --mode all_fancy --batch 1024 --models ${RESDIR}/client_final_{1..20}/ --data ${DATADIR} --model_output client --out-prefix ${RESDIR}/client_final
|
||||||
|
#python3 main.py --mode all_fancy --batch 1024 --models ${RESDIR}/both_final_{1..20}/ --data ${DATADIR} --model_output both --out-prefix ${RESDIR}/both_final
|
||||||
|
#python3 main.py --mode all_fancy --batch 1024 --models ${RESDIR}/both_inter_{1..20}/ --data ${DATADIR} --model_output both --out-prefix ${RESDIR}/both_inter
|
||||||
|
#python3 main.py --mode all_fancy --batch 1024 --models ${RESDIR}/both_staggered_{1..20}/ --data ${DATADIR} --model_output both --out-prefix ${RESDIR}/both_staggered
|
||||||
|
|
||||||
python3 main.py --mode all_fancy --batch 1024 --models ${RESDIR}/client_final_*/ --data ${DATADIR} --model_output client --out-prefix ${RESDIR}/client_final
|
#python3 main.py --mode beta --batch 1024 --models ${RESDIR}/client_final_{1..20}/ --data ${DATADIR} --model_output client --out-prefix ${RESDIR}/client_final
|
||||||
python3 main.py --mode all_fancy --batch 1024 --models ${RESDIR}/both_final_*/ --data ${DATADIR} --model_output both --out-prefix ${RESDIR}/both_final
|
#python3 main.py --mode beta --batch 1024 --models ${RESDIR}/both_final_{1..20}/ --data ${DATADIR} --model_output both --out-prefix ${RESDIR}/both_final
|
||||||
python3 main.py --mode all_fancy --batch 1024 --models ${RESDIR}/both_inter_*/ --data ${DATADIR} --model_output both --out-prefix ${RESDIR}/both_inter
|
#python3 main.py --mode beta --batch 1024 --models ${RESDIR}/both_inter_{1..20}/ --data ${DATADIR} --model_output both --out-prefix ${RESDIR}/both_inter
|
||||||
python3 main.py --mode all_fancy --batch 1024 --models ${RESDIR}/both_staggered_*/ --data ${DATADIR} --model_output both --out-prefix ${RESDIR}/both_staggered
|
#python3 main.py --mode beta --batch 1024 --models ${RESDIR}/both_staggered_{1..20}/ --data ${DATADIR} --model_output both --out-prefix ${RESDIR}/both_staggered
|
||||||
|
#python3 main.py --mode all_beta --out-prefix ${RESDIR}/both_staggered
|
||||||
|
|
||||||
python3 main.py --mode beta --batch 1024 --models ${RESDIR}/client_final_*/ --data ${DATADIR} --model_output client --out-prefix ${RESDIR}/client_final
|
python3 main.py --mode embedding --batch 1024 --models ${RESDIR}/client_final_{1..20}/ ${RESDIR}/both_final_{1..20}/ \
|
||||||
python3 main.py --mode beta --batch 1024 --models ${RESDIR}/both_final_*/ --data ${DATADIR} --model_output both --out-prefix ${RESDIR}/both_final
|
${RESDIR}/both_inter_{1..20}/ ${RESDIR}/both_staggered_{1..20}/ \
|
||||||
python3 main.py --mode beta --batch 1024 --models ${RESDIR}/both_inter_*/ --data ${DATADIR} --model_output both --out-prefix ${RESDIR}/both_inter
|
--data ${DATADIR} \
|
||||||
python3 main.py --mode beta --batch 1024 --models ${RESDIR}/both_staggered_*/ --data ${DATADIR} --model_output both --out-prefix ${RESDIR}/both_staggered
|
--out-prefix ${RESDIR}/figs/tsne/tsne
|
||||||
|
|
||||||
python3 main.py --mode embedding --batch 1024 --model ${RESDIR}/client_final_*/ --data ${DATADIR} --model_output client --out-prefix --model ${RESDIR}/client_final
|
|
||||||
python3 main.py --mode embedding --batch 1024 --model ${RESDIR}/both_final_*/ --data ${DATADIR} --model_output both --out-prefix --model ${RESDIR}/both_final
|
|
||||||
python3 main.py --mode embedding --batch 1024 --model ${RESDIR}/both_inter_*/ --data ${DATADIR} --model_output both --out-prefix --model ${RESDIR}/both_inter
|
|
||||||
python3 main.py --mode embedding --batch 1024 --model ${RESDIR}/both_staggered_*/ --data ${DATADIR} --model_output both --out-prefix --model ${RESDIR}/both_staggered
|
|
||||||
|
416
main.py
416
main.py
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import joblib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@ -15,35 +16,37 @@ import models
|
|||||||
# create logger
|
# create logger
|
||||||
import visualize
|
import visualize
|
||||||
from arguments import get_model_args
|
from arguments import get_model_args
|
||||||
|
from server import test_server_only, train_server_only
|
||||||
from utils import exists_or_make_path, get_custom_class_weights, get_custom_sample_weights, load_model
|
from utils import exists_or_make_path, get_custom_class_weights, get_custom_sample_weights, load_model
|
||||||
|
|
||||||
logger = logging.getLogger('logger')
|
logger = logging.getLogger('cisco_logger')
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
# create console handler and set level to debug
|
# create console handler and set level to debug
|
||||||
ch = logging.StreamHandler()
|
ch = logging.StreamHandler()
|
||||||
ch.setLevel(logging.DEBUG)
|
ch.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
# create formatter
|
# create formatter
|
||||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
formatter1 = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
|
||||||
# add formatter to ch
|
# add formatter to ch
|
||||||
ch.setFormatter(formatter)
|
ch.setFormatter(formatter1)
|
||||||
|
|
||||||
# add ch to logger
|
# add ch to logger
|
||||||
logger.addHandler(ch)
|
logger.addHandler(ch)
|
||||||
|
|
||||||
ch = logging.FileHandler("info.log")
|
# ch = logging.FileHandler("info.log")
|
||||||
ch.setLevel(logging.DEBUG)
|
# ch.setLevel(logging.DEBUG)
|
||||||
|
#
|
||||||
# create formatter
|
# # create formatter
|
||||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
# formatter2 = logging.Formatter('!! %(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
#
|
||||||
# add formatter to ch
|
# # add formatter to ch
|
||||||
ch.setFormatter(formatter)
|
# ch.setFormatter(formatter2)
|
||||||
|
#
|
||||||
# add ch to logger
|
# # add ch to logger
|
||||||
logger.addHandler(ch)
|
# logger.addHandler(ch)
|
||||||
|
|
||||||
args = arguments.parse()
|
args = arguments.parse()
|
||||||
|
|
||||||
@ -100,7 +103,7 @@ def main_hyperband():
|
|||||||
"flow_features": [3],
|
"flow_features": [3],
|
||||||
"domain_length": [args.domain_length],
|
"domain_length": [args.domain_length],
|
||||||
# model params
|
# model params
|
||||||
"embedding_size": [2 ** x for x in range(3, 7)],
|
"embedding": [2 ** x for x in range(3, 7)],
|
||||||
"filter_embedding": [2 ** x for x in range(1, 10)],
|
"filter_embedding": [2 ** x for x in range(1, 10)],
|
||||||
"kernel_embedding": [1, 3, 5, 7, 9],
|
"kernel_embedding": [1, 3, 5, 7, 9],
|
||||||
"dense_embedding": [2 ** x for x in range(4, 10)],
|
"dense_embedding": [2 ** x for x in range(4, 10)],
|
||||||
@ -119,6 +122,12 @@ def main_hyperband():
|
|||||||
|
|
||||||
if args.model_type in ("inter", "staggered"):
|
if args.model_type in ("inter", "staggered"):
|
||||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||||
|
|
||||||
|
idx = np.random.permutation(len(domain_tr))
|
||||||
|
domain_tr = domain_tr[idx]
|
||||||
|
flow_tr = flow_tr[idx]
|
||||||
|
client_tr = client_tr[idx]
|
||||||
|
server_tr = server_tr[idx]
|
||||||
|
|
||||||
hp = hyperband.Hyperband(param_dist,
|
hp = hyperband.Hyperband(param_dist,
|
||||||
[domain_tr, flow_tr],
|
[domain_tr, flow_tr],
|
||||||
@ -337,8 +346,8 @@ def main_test():
|
|||||||
def main_visualization():
|
def main_visualization():
|
||||||
def plot_model(clf_model, path):
|
def plot_model(clf_model, path):
|
||||||
embd, model = load_model(clf_model, custom_objects=models.get_custom_objects())
|
embd, model = load_model(clf_model, custom_objects=models.get_custom_objects())
|
||||||
visualize.plot_model_as(embd, os.path.join(path, "model_embd.pdf"))
|
visualize.plot_model_as(embd, os.path.join(path, "model_embd.pdf"), shapes=False)
|
||||||
visualize.plot_model_as(model, os.path.join(path, "model_clf.pdf"))
|
visualize.plot_model_as(model, os.path.join(path, "model_clf.pdf"), shapes=False)
|
||||||
|
|
||||||
def vis(model_name, model_path, df, df_paul, aggregation, curve):
|
def vis(model_name, model_path, df, df_paul, aggregation, curve):
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
@ -411,13 +420,15 @@ def main_visualization():
|
|||||||
visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(),
|
visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(),
|
||||||
"{}/user_cov_norm.pdf".format(args.model_path),
|
"{}/user_cov_norm.pdf".format(args.model_path),
|
||||||
normalize=True, title="User Confusion Matrix")
|
normalize=True, title="User Confusion Matrix")
|
||||||
plot_embedding(args.model_path, results["domain_embds"], args.data, args.domain_length)
|
|
||||||
|
|
||||||
|
|
||||||
def plot_embedding(model_path, domain_embedding, data, domain_length):
|
# plot_embedding(args.model_path, results["domain_embds"], args.data, args.domain_length)
|
||||||
logger.info("visualize embedding")
|
|
||||||
domain_encs, labels = dataset.load_or_generate_domains(data, domain_length)
|
|
||||||
visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(model_path), method="svd")
|
# def plot_embedding(model_path, domain_embedding, data, domain_length):
|
||||||
|
# logger.info("visualize embedding")
|
||||||
|
# domain_encs, labels = dataset.load_or_generate_domains(data, domain_length)
|
||||||
|
# visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(model_path), method="svd")
|
||||||
|
|
||||||
|
|
||||||
def main_visualize_all():
|
def main_visualize_all():
|
||||||
@ -477,11 +488,7 @@ def main_visualize_all():
|
|||||||
|
|
||||||
|
|
||||||
def main_visualize_all_embds():
|
def main_visualize_all_embds():
|
||||||
import matplotlib.pyplot as plt
|
import seaborn as sns
|
||||||
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
|
||||||
args.data,
|
|
||||||
args.domain_length,
|
|
||||||
args.window)
|
|
||||||
|
|
||||||
def load_df(path):
|
def load_df(path):
|
||||||
res = dataset.load_predictions(path)
|
res = dataset.load_predictions(path)
|
||||||
@ -489,42 +496,50 @@ def main_visualize_all_embds():
|
|||||||
|
|
||||||
dfs = [(model_args["model_name"], load_df(model_args["model_path"])) for model_args in get_model_args(args)]
|
dfs = [(model_args["model_name"], load_df(model_args["model_path"])) for model_args in get_model_args(args)]
|
||||||
|
|
||||||
plt.clf()
|
from sklearn.manifold import TSNE
|
||||||
|
|
||||||
from sklearn.decomposition import TruncatedSVD
|
def vis2(domain_embedding, labels):
|
||||||
|
n_levels = 7
|
||||||
def vis(ax, domain_embedding, labels):
|
logger.info(f"reduction for {sub_sample} of {len(domain_embedding)} points")
|
||||||
red = TruncatedSVD(n_components=2)
|
red = TSNE(n_components=2)
|
||||||
# use if draw subset of predictions
|
domains = red.fit_transform(domain_embedding)
|
||||||
idx = np.random.choice(np.arange(len(domain_embedding)), 5000)
|
logger.info("plot kde")
|
||||||
domain_embedding = domain_embedding[idx]
|
sns.kdeplot(domains[labels.sum(axis=1) == 0, 0], domains[labels.sum(axis=1) == 0, 1],
|
||||||
labels = labels[idx]
|
cmap="Blues", label="benign", n_levels=9, alpha=0.45, shade=True, shade_lowest=False)
|
||||||
domain_reduced = red.fit_transform(domain_embedding)
|
sns.kdeplot(domains[labels[:, 1], 0], domains[labels[:, 1], 1],
|
||||||
ax.scatter(domain_reduced[:, 0],
|
cmap="Greens", label="server", n_levels=5, alpha=0.45, shade=True, shade_lowest=False)
|
||||||
domain_reduced[:, 1],
|
sns.kdeplot(domains[labels[:, 0], 0], domains[labels[:, 0], 1],
|
||||||
c=(labels * (1, 2)).sum(1).astype(int),
|
cmap="Reds", label="client", n_levels=5, alpha=0.45, shade=True, shade_lowest=False)
|
||||||
cmap=plt.cm.plasma,
|
|
||||||
s=3,
|
|
||||||
alpha=0.1)
|
|
||||||
|
|
||||||
domain_encs, labels = dataset.load_or_generate_domains(args.data, args.domain_length)
|
domain_encs, labels = dataset.load_or_generate_domains(args.data, args.domain_length)
|
||||||
|
|
||||||
fig, axes = plt.subplots(nrows=5, ncols=4)
|
idx = np.arange(len(labels))
|
||||||
|
client = labels[:, 0]
|
||||||
|
server = labels[:, 1]
|
||||||
|
benign = np.logical_not(np.logical_and(client, server))
|
||||||
|
print(client.sum(), server.sum(), benign.sum())
|
||||||
|
|
||||||
for (model_name, embd), ax in zip(dfs, axes.flat):
|
idx = np.concatenate((
|
||||||
|
np.random.choice(idx[client], 1000),
|
||||||
|
np.random.choice(idx[server], 1000),
|
||||||
|
np.random.choice(idx[benign], 6000)), axis=0)
|
||||||
|
|
||||||
|
print(idx.shape)
|
||||||
|
lls = labels[idx]
|
||||||
|
|
||||||
|
for model_name, embd in dfs:
|
||||||
logger.info(f"plot embedding for {model_name}")
|
logger.info(f"plot embedding for {model_name}")
|
||||||
vis(ax, embd, labels)
|
visualize.plot_clf()
|
||||||
|
embd = embd[idx]
|
||||||
visualize.plot_save("{}_svd.png".format(args.output_prefix, 600))
|
vis2(embd, lls)
|
||||||
|
visualize.plot_save("{}_{}.pdf".format(args.output_prefix, model_name))
|
||||||
import joblib
|
|
||||||
|
|
||||||
|
|
||||||
def main_beta():
|
def main_beta():
|
||||||
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
domain_val, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
||||||
args.data,
|
args.data,
|
||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
||||||
try:
|
try:
|
||||||
results = joblib.load(f"{path}/curves.joblib")
|
results = joblib.load(f"{path}/curves.joblib")
|
||||||
@ -532,82 +547,101 @@ def main_beta():
|
|||||||
results = {}
|
results = {}
|
||||||
results[model_prefix] = {"all": {}}
|
results[model_prefix] = {"all": {}}
|
||||||
|
|
||||||
|
domains = domain_val.value.reshape(-1, 40)
|
||||||
|
domains = np.apply_along_axis(lambda d: "".join(map(dataset.decode_char, d)), 1, domains)
|
||||||
|
|
||||||
def load_df(path):
|
def load_df(path):
|
||||||
|
df_server = None
|
||||||
res = dataset.load_predictions(path)
|
res = dataset.load_predictions(path)
|
||||||
res = pd.DataFrame(data={
|
data = {
|
||||||
"names": name_val, "client_pred": res["client_pred"].flatten(),
|
"names": name_val, "client_pred": res["client_pred"].flatten(),
|
||||||
"hits_vt": hits_vt, "hits_trusted": hits_trusted
|
"hits_vt": hits_vt, "hits_trusted": hits_trusted,
|
||||||
})
|
}
|
||||||
|
if "server_pred" in res:
|
||||||
|
print(res["server_pred"].shape, server_val.value.shape)
|
||||||
|
server = res["server_pred"] if len(res["server_pred"].shape) == 2 else res["server_pred"].max(axis=1)
|
||||||
|
val = server_val.value.max(axis=1)
|
||||||
|
data["server_pred"] = server.flatten()
|
||||||
|
data["server_val"] = val.flatten()
|
||||||
|
|
||||||
|
if res["server_pred"].flatten().shape == server_val.value.flatten().shape:
|
||||||
|
df_server = pd.DataFrame(data={
|
||||||
|
"server_pred": res["server_pred"].flatten(),
|
||||||
|
"domain": domains,
|
||||||
|
"server_val": server_val.value.flatten()
|
||||||
|
})
|
||||||
|
|
||||||
|
res = pd.DataFrame(data=data)
|
||||||
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
|
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
|
||||||
return res
|
|
||||||
|
|
||||||
paul = dataset.load_predictions("results/paul/")
|
return res, df_server
|
||||||
df_paul = pd.DataFrame(data={
|
|
||||||
"names": paul["testNames"].flatten(), "client_pred": paul["testScores"].flatten(),
|
|
||||||
"hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten()
|
|
||||||
})
|
|
||||||
df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3)
|
|
||||||
df_paul_user = df_paul.groupby(df_paul.names).max()
|
|
||||||
|
|
||||||
logger.info("plot pr curves")
|
client_preds = []
|
||||||
visualize.plot_clf()
|
server_preds = []
|
||||||
predictions = []
|
server_flow_preds = []
|
||||||
|
client_user_preds = []
|
||||||
|
server_user_preds = []
|
||||||
|
server_domain_preds = []
|
||||||
|
server_domain_avg_preds = []
|
||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
df = load_df(model_args["model_path"])
|
df, df_server = load_df(model_args["model_path"])
|
||||||
predictions.append(df.client_pred.as_matrix())
|
client_preds.append(df.client_pred.as_matrix())
|
||||||
|
if "server_val" in df.columns:
|
||||||
|
server_preds.append(df.server_pred.as_matrix())
|
||||||
|
if df_server is not None:
|
||||||
|
server_flow_preds.append(df_server.server_pred.as_matrix())
|
||||||
|
df_domain = df_server.groupby(df_server.domain).max()
|
||||||
|
server_domain_preds.append(df_domain.server_pred.as_matrix())
|
||||||
|
df_domain_avg = df_server.groupby(df_server.domain).rolling(10).mean()
|
||||||
|
server_domain_avg_preds.append(df_domain_avg.server_pred.as_matrix())
|
||||||
|
|
||||||
results[model_prefix][model_args["model_name"]] = confusion_matrix(df.client_val.as_matrix(),
|
results[model_prefix][model_args["model_name"]] = confusion_matrix(df.client_val.as_matrix(),
|
||||||
df.client_pred.as_matrix().round())
|
df.client_pred.as_matrix().round())
|
||||||
results[model_prefix]["all"]["window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), predictions)
|
df_user = df.groupby(df.names).max()
|
||||||
visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
|
client_user_preds.append(df_user.client_pred.as_matrix())
|
||||||
visualize.plot_pr_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
|
if "server_val" in df.columns:
|
||||||
visualize.plot_legend()
|
server_user_preds.append(df_user.server_pred.as_matrix())
|
||||||
visualize.plot_save(f"{args.output_prefix}_window_client_prc_all.pdf")
|
|
||||||
|
|
||||||
logger.info("plot roc curves")
|
logger.info("plot client curves")
|
||||||
visualize.plot_clf()
|
results[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds)
|
||||||
predictions = []
|
results[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds)
|
||||||
for model_args in get_model_args(args):
|
results[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(),
|
||||||
df = load_df(model_args["model_path"])
|
client_user_preds)
|
||||||
predictions.append(df.client_pred.as_matrix())
|
results[model_prefix]["all"]["client_user_roc"] = visualize.calc_roc_mean(df_user.client_val.as_matrix(),
|
||||||
results[model_prefix][model_args["model_name"]] = confusion_matrix(df.client_val.as_matrix(),
|
client_user_preds)
|
||||||
df.client_pred.as_matrix().round())
|
|
||||||
results[model_prefix]["all"]["window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), predictions)
|
|
||||||
visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
|
|
||||||
visualize.plot_roc_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
|
|
||||||
visualize.plot_legend()
|
|
||||||
visualize.plot_save(f"{args.output_prefix}_window_client_roc_all.pdf")
|
|
||||||
|
|
||||||
logger.info("plot user pr curves")
|
if "server_val" in df.columns:
|
||||||
visualize.plot_clf()
|
logger.info("plot server curves")
|
||||||
predictions = []
|
results[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(),
|
||||||
for model_args in get_model_args(args):
|
server_preds)
|
||||||
df = load_df(model_args["model_path"])
|
results[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(),
|
||||||
df = df.groupby(df.names).max()
|
server_preds)
|
||||||
predictions.append(df.client_pred.as_matrix())
|
results[model_prefix]["all"]["server_user_prc"] = visualize.calc_pr_mean(df_user.server_val.as_matrix(),
|
||||||
results[model_prefix][model_args["model_name"]] = confusion_matrix(df.client_val.as_matrix(),
|
server_user_preds)
|
||||||
df.client_pred.as_matrix().round())
|
|
||||||
results[model_prefix]["all"]["user_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), predictions)
|
results[model_prefix]["all"]["server_user_roc"] = visualize.calc_roc_mean(df_user.server_val.as_matrix(),
|
||||||
visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
|
server_user_preds)
|
||||||
visualize.plot_pr_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
|
|
||||||
visualize.plot_legend()
|
|
||||||
visualize.plot_save(f"{args.output_prefix}_user_client_prc_all.pdf")
|
|
||||||
|
|
||||||
logger.info("plot user roc curves")
|
if df_server is not None:
|
||||||
visualize.plot_clf()
|
logger.info("plot server flow curves")
|
||||||
predictions = []
|
results[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(),
|
||||||
for model_args in get_model_args(args):
|
server_flow_preds)
|
||||||
df = load_df(model_args["model_path"])
|
results[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(),
|
||||||
df = df.groupby(df.names).max()
|
server_flow_preds)
|
||||||
predictions.append(df.client_pred.as_matrix())
|
results[model_prefix]["all"]["server_domain_prc"] = visualize.calc_pr_mean(df_domain.server_val.as_matrix(),
|
||||||
results[model_prefix]["all"]["user_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), predictions)
|
server_domain_preds)
|
||||||
visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
|
results[model_prefix]["all"]["server_domain_roc"] = visualize.calc_roc_mean(df_domain.server_val.as_matrix(),
|
||||||
visualize.plot_roc_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
|
server_domain_preds)
|
||||||
visualize.plot_legend()
|
results[model_prefix]["all"]["server_domain_avg_prc"] = visualize.calc_pr_mean(
|
||||||
visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.pdf")
|
df_domain_avg.server_val.as_matrix(),
|
||||||
|
server_domain_avg_preds)
|
||||||
|
results[model_prefix]["all"]["server_domain_avg_roc"] = visualize.calc_roc_mean(
|
||||||
|
df_domain_avg.server_val.as_matrix(),
|
||||||
|
server_domain_avg_preds)
|
||||||
|
|
||||||
joblib.dump(results, f"{path}/curves.joblib")
|
joblib.dump(results, f"{path}/curves.joblib")
|
||||||
|
|
||||||
plot_overall_result()
|
# plot_overall_result()
|
||||||
|
|
||||||
|
|
||||||
def plot_overall_result():
|
def plot_overall_result():
|
||||||
@ -619,12 +653,19 @@ def plot_overall_result():
|
|||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
x = np.linspace(0, 1, 10000)
|
x = np.linspace(0, 1, 10000)
|
||||||
for vis in ["window_prc", "window_roc", "user_prc", "user_roc"]:
|
for vis in ["client_window_prc", "client_window_roc", "client_user_prc", "client_user_roc",
|
||||||
|
"server_window_prc", "server_window_roc", "server_user_prc", "server_user_roc",
|
||||||
|
"server_flow_prc", "server_flow_roc", "server_domain_prc", "server_domain_roc",
|
||||||
|
"server_domain_avg_prc", "server_domain_avg_roc"]:
|
||||||
logger.info(f"plot {vis}")
|
logger.info(f"plot {vis}")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
for model_key in results.keys():
|
for model_key in results.keys():
|
||||||
ys_mean, ys_std, score = results[model_key]["all"][vis]
|
if vis not in results[model_key]["all"]:
|
||||||
plt.plot(x, ys_mean, label=f"{model_key} - {score:5.4}")
|
continue
|
||||||
|
if "final" in model_key and vis.startswith("server_flow"):
|
||||||
|
continue
|
||||||
|
ys_mean, ys_std, ys = results[model_key]["all"][vis]
|
||||||
|
plt.plot(x, ys_mean, label=f"{model_key} - {np.mean(ys_mean):5.4} ({np.mean(ys_std):4.3})")
|
||||||
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, alpha=0.2)
|
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, alpha=0.2)
|
||||||
if vis.endswith("prc"):
|
if vis.endswith("prc"):
|
||||||
plt.xlabel('Recall')
|
plt.xlabel('Recall')
|
||||||
@ -632,124 +673,37 @@ def plot_overall_result():
|
|||||||
else:
|
else:
|
||||||
plt.xlabel('False Positive Rate')
|
plt.xlabel('False Positive Rate')
|
||||||
plt.ylabel('True Positive Rate')
|
plt.ylabel('True Positive Rate')
|
||||||
|
plt.xscale('log')
|
||||||
plt.ylim([0.0, 1.0])
|
plt.ylim([0.0, 1.0])
|
||||||
plt.xlim([0.0, 1.0])
|
plt.xlim([0.0, 1.0])
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{path}/{vis}_all.pdf")
|
visualize.plot_save(f"{path}/figs/curves/{vis}_all.pdf")
|
||||||
|
|
||||||
for cat, models in results.items():
|
for vis in ["client_window_prc", "client_window_roc", "client_user_prc", "client_user_roc",
|
||||||
|
"server_window_prc", "server_window_roc", "server_user_prc", "server_user_roc",
|
||||||
|
"server_flow_prc", "server_flow_roc", "server_domain_prc", "server_domain_roc",
|
||||||
|
"server_domain_avg_prc", "server_domain_avg_roc"]:
|
||||||
|
logger.info(f"plot {vis}")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_error_bars(models)
|
for model_key in results.keys():
|
||||||
visualize.plot_legend()
|
if vis not in results[model_key]["all"]:
|
||||||
visualize.plot_save(f"{path}/error_bars_{cat}.pdf")
|
continue
|
||||||
|
if "final" in model_key and vis.startswith("server_flow"):
|
||||||
|
continue
|
||||||
def train_server_only():
|
_, _, ys = results[model_key]["all"][vis]
|
||||||
logger.info(f"Create model path {args.model_path}")
|
for y in ys:
|
||||||
exists_or_make_path(args.model_path)
|
plt.plot(x, y, label=f"{model_key} - {np.mean(y):5.4}")
|
||||||
logger.info(f"Use command line arguments: {args}")
|
if vis.endswith("prc"):
|
||||||
|
plt.xlabel('Recall')
|
||||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
|
plt.ylabel('Precision')
|
||||||
args.data,
|
else:
|
||||||
args.domain_length,
|
plt.xlabel('False Positive Rate')
|
||||||
args.window)
|
plt.ylabel('True Positive Rate')
|
||||||
domain_tr = domain_tr.value.reshape(-1, 40)
|
plt.xscale('log')
|
||||||
flow_tr = flow_tr.value.reshape(-1, 3)
|
plt.ylim([0.0, 1.0])
|
||||||
server_tr = server_windows_tr.value.reshape(-1)
|
plt.xlim([0.0, 1.0])
|
||||||
|
visualize.plot_legend()
|
||||||
logger.info("define callbacks")
|
visualize.plot_save(f"{path}/figs/appendix/{model_key}_{vis}.pdf")
|
||||||
callbacks = []
|
|
||||||
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
|
|
||||||
monitor='loss',
|
|
||||||
verbose=False,
|
|
||||||
save_best_only=True))
|
|
||||||
callbacks.append(CSVLogger(args.train_log))
|
|
||||||
logger.info(f"Use early stopping: {args.stop_early}")
|
|
||||||
if args.stop_early:
|
|
||||||
callbacks.append(EarlyStopping(monitor='val_loss',
|
|
||||||
patience=5,
|
|
||||||
verbose=False))
|
|
||||||
custom_metrics = models.get_metric_functions()
|
|
||||||
|
|
||||||
model = models.get_server_model_by_params(params=PARAMS)
|
|
||||||
|
|
||||||
features = {"ipt_domains": domain_tr, "ipt_flows": flow_tr}
|
|
||||||
if args.model_output == "both":
|
|
||||||
labels = {"client": client_tr, "server": server_tr}
|
|
||||||
elif args.model_output == "client":
|
|
||||||
labels = {"client": client_tr}
|
|
||||||
elif args.model_output == "server":
|
|
||||||
labels = {"server": server_tr}
|
|
||||||
else:
|
|
||||||
raise ValueError("unknown model output")
|
|
||||||
|
|
||||||
logger.info("compile and train model")
|
|
||||||
logger.info(model.get_config())
|
|
||||||
model.compile(optimizer='adam',
|
|
||||||
loss='binary_crossentropy',
|
|
||||||
metrics=['accuracy'] + custom_metrics)
|
|
||||||
|
|
||||||
model.summary()
|
|
||||||
model.fit(features, labels,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
epochs=args.epochs,
|
|
||||||
callbacks=callbacks)
|
|
||||||
|
|
||||||
|
|
||||||
def test_server_only():
|
|
||||||
logger.info("start test: load data")
|
|
||||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data,
|
|
||||||
args.data,
|
|
||||||
args.domain_length,
|
|
||||||
args.window)
|
|
||||||
domain_val = domain_val.value.reshape(-1, 40)
|
|
||||||
flow_val = flow_val.value.reshape(-1, 3)
|
|
||||||
domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
|
||||||
|
|
||||||
for model_args in get_model_args(args):
|
|
||||||
results = {}
|
|
||||||
logger.info(f"process model {model_args['model_path']}")
|
|
||||||
embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_custom_objects())
|
|
||||||
|
|
||||||
pred = clf_model.predict([domain_val, flow_val],
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
verbose=1)
|
|
||||||
|
|
||||||
results["server_pred"] = pred
|
|
||||||
|
|
||||||
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
|
||||||
results["domain_embds"] = domain_embeddings
|
|
||||||
|
|
||||||
dataset.save_predictions(model_args["model_path"], results)
|
|
||||||
|
|
||||||
|
|
||||||
def vis_server():
|
|
||||||
def load_model(m, c):
|
|
||||||
from keras.models import load_model
|
|
||||||
clf = load_model(m, custom_objects=c)
|
|
||||||
emdb = clf.layers[1]
|
|
||||||
return emdb, clf
|
|
||||||
|
|
||||||
domain_raw, flow_raw, name_raw, hits_vt_raw, hits_trusted_raw, server_raw = dataset.load_or_generate_raw_h5data(
|
|
||||||
args.data,
|
|
||||||
args.data,
|
|
||||||
args.domain_length,
|
|
||||||
args.window)
|
|
||||||
|
|
||||||
results = dataset.load_predictions(args.clf_model)
|
|
||||||
|
|
||||||
visualize.plot_clf()
|
|
||||||
visualize.plot_precision_recall(server_raw.flatten(), results["server_pred"].flatten(), "server")
|
|
||||||
visualize.plot_legend()
|
|
||||||
visualize.plot_save("results/server_model/windows_prc.pdf")
|
|
||||||
visualize.plot_clf()
|
|
||||||
visualize.plot_precision_recall(server_raw.flatten(), results["server_pred"].flatten(), "server")
|
|
||||||
visualize.plot_legend()
|
|
||||||
visualize.plot_save("results/server_model/windows_prc.pdf")
|
|
||||||
visualize.plot_clf()
|
|
||||||
visualize.plot_roc_curve(server_raw.flatten(), results["server_pred"].flatten(), "server")
|
|
||||||
visualize.plot_legend()
|
|
||||||
visualize.plot_save("results/server_model/windows_roc.pdf")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
53
visualize.py
53
visualize.py
@ -3,6 +3,7 @@ import os
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import seaborn as sns
|
||||||
from scipy import interpolate
|
from scipy import interpolate
|
||||||
from sklearn.decomposition import TruncatedSVD
|
from sklearn.decomposition import TruncatedSVD
|
||||||
from sklearn.manifold import TSNE
|
from sklearn.manifold import TSNE
|
||||||
@ -36,13 +37,17 @@ def scores(y_true):
|
|||||||
|
|
||||||
def plot_clf():
|
def plot_clf():
|
||||||
plt.clf()
|
plt.clf()
|
||||||
|
sns.set_context("paper")
|
||||||
|
sns.set_style("white")
|
||||||
|
|
||||||
|
|
||||||
def plot_save(path, dpi=300):
|
def plot_save(path, dpi=600, set_size=True):
|
||||||
plt.title(path)
|
# plt.title(path)
|
||||||
fig = plt.gcf()
|
fig = plt.gcf()
|
||||||
fig.set_size_inches(18.5, 10.5)
|
# fig.suptitle(path)
|
||||||
fig.savefig(path, dpi=dpi)
|
if set_size:
|
||||||
|
fig.set_size_inches(8, 4.5)
|
||||||
|
fig.savefig(path, dpi=dpi, bbox_inches='tight')
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
@ -73,21 +78,7 @@ def plot_precision_recall(y, y_pred, label=""):
|
|||||||
|
|
||||||
|
|
||||||
def calc_pr_mean(y, y_preds):
|
def calc_pr_mean(y, y_preds):
|
||||||
appr = []
|
return calc_metrics_mean(y, y_preds, "prc")
|
||||||
scores = []
|
|
||||||
y = y.flatten()
|
|
||||||
|
|
||||||
for idx, y_pred in enumerate(y_preds):
|
|
||||||
y_pred = y_pred.flatten()
|
|
||||||
precision, recall, thresholds = precision_recall_curve(y, y_pred)
|
|
||||||
appr.append(interpolate.interp1d(recall, precision))
|
|
||||||
scores.append(fbeta_score(y, y_pred.round(), 1))
|
|
||||||
x = np.linspace(0, 1, 10000)
|
|
||||||
ys = np.vstack([f(x) for f in appr])
|
|
||||||
ys_mean = ys.mean(axis=0)
|
|
||||||
ys_std = ys.std(axis=0)
|
|
||||||
scores_mean = np.mean(scores)
|
|
||||||
return ys_mean, ys_std, scores_mean
|
|
||||||
|
|
||||||
|
|
||||||
def plot_mean_curve(x, ys, std, score, label):
|
def plot_mean_curve(x, ys, std, score, label):
|
||||||
@ -131,22 +122,26 @@ def plot_roc_curve(mask, prediction, label=""):
|
|||||||
plt.ylabel('True Positive Rate')
|
plt.ylabel('True Positive Rate')
|
||||||
|
|
||||||
|
|
||||||
def calc_roc_mean(y, y_preds):
|
def calc_metrics_mean(y, y_preds, metric):
|
||||||
appr = []
|
appr = []
|
||||||
aucs = []
|
|
||||||
y = y.flatten()
|
y = y.flatten()
|
||||||
|
|
||||||
for idx, y_pred in enumerate(y_preds):
|
for idx, y_pred in enumerate(y_preds):
|
||||||
y_pred = y_pred.flatten()
|
y_pred = y_pred.flatten()
|
||||||
fpr, tpr, thresholds = roc_curve(y, y_pred)
|
if metric == "prc":
|
||||||
appr.append(interpolate.interp1d(fpr, tpr))
|
precision, recall, thresholds = precision_recall_curve(y, y_pred)
|
||||||
aucs.append(auc(fpr, tpr))
|
appr.append(interpolate.interp1d(recall, precision))
|
||||||
|
elif metric == "roc":
|
||||||
|
fpr, tpr, thresholds = roc_curve(y, y_pred)
|
||||||
|
appr.append(interpolate.interp1d(fpr, tpr))
|
||||||
x = np.linspace(0, 1, 10000)
|
x = np.linspace(0, 1, 10000)
|
||||||
ys = np.vstack([f(x) for f in appr])
|
ys = np.vstack([f(x) for f in appr])
|
||||||
ys_mean = ys.mean(axis=0)
|
ys_mean = ys.mean(axis=0)
|
||||||
ys_std = ys.std(axis=0)
|
ys_std = ys.std(axis=0)
|
||||||
auc_mean = np.mean(aucs)
|
return ys_mean, ys_std, ys
|
||||||
return ys_mean, ys_std, auc_mean
|
|
||||||
|
|
||||||
|
def calc_roc_mean(y, y_preds):
|
||||||
|
return calc_metrics_mean(y, y_preds, "roc")
|
||||||
|
|
||||||
|
|
||||||
def plot_roc_mean(y, y_preds, label=""):
|
def plot_roc_mean(y, y_preds, label=""):
|
||||||
@ -243,6 +238,6 @@ def plot_embedding(domain_embedding, labels, path, dpi=600, method="svd"):
|
|||||||
plt.savefig(path, dpi=dpi)
|
plt.savefig(path, dpi=dpi)
|
||||||
|
|
||||||
|
|
||||||
def plot_model_as(model, path):
|
def plot_model_as(model, path, shapes=True, layer_names=True):
|
||||||
from keras.utils.vis_utils import plot_model
|
from keras.utils.vis_utils import plot_model
|
||||||
plot_model(model, to_file=path, show_shapes=True, show_layer_names=True)
|
plot_model(model, to_file=path, show_shapes=shapes, show_layer_names=layer_names)
|
||||||
|
Loading…
Reference in New Issue
Block a user