refactor test function working on full unfiltered data

René Knaebel 2017-09-08 19:10:23 +02:00
parent edc75f4f44
commit 9a51b6ea34
4 changed files with 114 additions and 89 deletions

Makefile

@@ -66,4 +66,5 @@ hyper:
 
 clean:
 	rm -r results/test/test*
+	rm data/rk_mini.csv.gz_raw.h5
 	rm data/rk_mini.csv.gz.h5

arguments.py

@@ -105,9 +105,9 @@ def get_model_args(args):
         "embedding_model": os.path.join(model_path, "embd.h5"),
         "clf_model": os.path.join(model_path, "clf.h5"),
         "train_log": os.path.join(model_path, "train.log.csv"),
-        "train_h5data": args.train_data + ".h5",
-        "test_h5data": args.test_data + ".h5",
-        "future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred.h5")
+        "train_h5data": args.train_data,
+        "test_h5data": args.test_data,
+        "future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred")
     } for model_path in args.model_paths]
 
 def parse():
@@ -115,7 +115,7 @@ def parse():
     args.embedding_model = os.path.join(args.model_path, "embd.h5")
     args.clf_model = os.path.join(args.model_path, "clf.h5")
     args.train_log = os.path.join(args.model_path, "train.log.csv")
-    args.train_h5data = args.train_data + ".h5"
-    args.test_h5data = args.test_data + ".h5"
-    args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred.h5")
+    args.train_h5data = args.train_data
+    args.test_h5data = args.test_data
+    args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred")
     return args
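In effect, train_h5data and test_h5data now hold the bare dataset path; the ".h5" suffix is appended inside the dataset helpers instead. A minimal sketch of the resulting paths, assuming the rk_mini.csv.gz fixture from the Makefile and a hypothetical model directory results/test/test0:

    import os

    train_data = "data/rk_mini.csv.gz"    # test fixture from the Makefile
    model_path = "results/test/test0"     # hypothetical model directory

    train_h5data = train_data             # dataset helpers append ".h5" on open
    future_prediction = os.path.join(model_path, f"{os.path.basename(train_data)}_pred")
    print(train_h5data)                   # data/rk_mini.csv.gz
    print(future_prediction)              # results/test/test0/rk_mini.csv.gz_pred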

dataset.py

@@ -4,6 +4,7 @@ import string
 from multiprocessing import Pool
 
 import h5py
+import joblib
 import numpy as np
 import pandas as pd
 from tqdm import tqdm
@@ -139,14 +140,18 @@ def create_raw_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=
 def store_h5dataset(path, data: dict):
-    f = h5py.File(path, "w")
+    f = h5py.File(path + ".h5", "w")
     for key, val in data.items():
         f.create_dataset(key, data=val)
     f.close()
 
 
+def check_h5dataset(path):
+    return open(path + ".h5", "r")
+
+
 def load_h5dataset(path):
-    f = h5py.File(path, "r")
+    f = h5py.File(path + ".h5", "r")
     data = {}
     for k in f.keys():
         data[k] = f[k]
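Callers now pass the suffix-free path and the helpers add ".h5" themselves; check_h5dataset only probes for the file with the builtin open(), so a missing cache still raises FileNotFoundError in the callers' try blocks (the returned handle is never closed). A small round-trip sketch under that convention, with data/example as a placeholder path:

    import numpy as np

    from dataset import load_h5dataset, store_h5dataset

    store_h5dataset("data/example", {"domain": np.zeros((3, 40), dtype=np.int8)})
    data = load_h5dataset("data/example")  # opens data/example.h5
    print(data["domain"].shape)            # (3, 40)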
@@ -225,17 +230,17 @@ def get_flow_per_user(df):
 
 def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
-    char_dict = get_character_dict()
     logger.info(f"check for h5data {h5data}")
     try:
-        open(h5data, "r")
+        check_h5dataset(h5data)
     except FileNotFoundError:
-        logger.info("h5 data not found - load csv file")
-        user_flow_df = get_user_flow_data(train_data)
-        logger.info("create training dataset")
-        domain, flow, name, client, server = create_dataset_from_flows(user_flow_df, char_dict,
-                                                                       max_len=domain_length,
-                                                                       window_size=window_size)
+        logger.info("load raw training dataset")
+        domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(h5data + "_raw", train_data,
+                                                                                     domain_length, window_size)
+        logger.info("filter training dataset")
+        domain, flow, name, client, server = filter_window_dataset_by_hits(domain.value, flow.value,
+                                                                           name.value, hits.value,
+                                                                           trusted_hits.value, server.value)
         logger.info("store training dataset as h5 file")
         data = {
             "domain": domain.astype(np.int8),
@@ -250,6 +255,32 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
     return data["domain"], data["flow"], data["name"], data["client"], data["server"]
+
+
+def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size):
+    char_dict = get_character_dict()
+    logger.info(f"check for h5data {h5data}")
+    try:
+        check_h5dataset(h5data)
+    except FileNotFoundError:
+        logger.info("h5 data not found - load csv file")
+        user_flow_df = get_user_flow_data(train_data)
+        logger.info("create raw training dataset")
+        domain, flow, name, hits, trusted_hits, server = create_raw_dataset_from_flows(user_flow_df, char_dict,
+                                                                                       domain_length, window_size)
+        logger.info("store raw training dataset as h5 file")
+        data = {
+            "domain": domain.astype(np.int8),
+            "flow": flow,
+            "name": name,
+            "hits_vt": hits.astype(np.int8),
+            "hits_trusted": hits.astype(np.int8),
+            "server": server.astype(np.bool)
+        }
+        store_h5dataset(h5data, data)
+    logger.info("load h5 dataset")
+    data = load_h5dataset(h5data)
+    return data["domain"], data["flow"], data["name"], data["hits_vt"], data["hits_trusted"], data["server"]
 
 
 def generate_names(train_data, window_size):
     user_flow_df = get_user_flow_data(train_data)
     with Pool() as pool:
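load_or_generate_h5data above now builds its filtered dataset on top of this raw one, so a single CSV yields two caches on disk (data/rk_mini.csv.gz_raw.h5 and data/rk_mini.csv.gz.h5), which is why the Makefile's clean target gains an extra rm line. A hedged sketch of the raw call, with domain_length and window_size chosen arbitrarily:

    from dataset import load_or_generate_raw_h5data

    # generates and caches data/rk_mini.csv.gz_raw.h5 on the first call,
    # then simply reopens it on subsequent calls
    domain, flow, name, hits_vt, hits_trusted, server = load_or_generate_raw_h5data(
        "data/rk_mini.csv.gz_raw", "data/rk_mini.csv.gz",
        domain_length=40, window_size=10)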
@@ -291,13 +322,9 @@ def load_or_generate_domains(train_data, domain_length):
     return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix().astype(bool)
 
 
-def save_predictions(path, c_pred, s_pred):
-    f = h5py.File(path, "w")
-    f.create_dataset("client", data=c_pred)
-    f.create_dataset("server", data=s_pred)
-    f.close()
+def save_predictions(path, results):
+    joblib.dump(results, path + "/results.joblib", compress=3)
 
 
 def load_predictions(path):
-    f = h5py.File(path, "r")
-    return f["client"], f["server"]
+    return joblib.load(path + "/results.joblib")
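Predictions move from one HDF5 file with fixed client/server arrays to a single compressed joblib dump per model directory, which lets main_test stash arbitrary keys (window predictions, domain embeddings) in one results dict. A minimal round-trip sketch; the directory and array shapes are made up:

    import os

    import joblib
    import numpy as np

    model_path = "results/test/test0"      # hypothetical model directory
    os.makedirs(model_path, exist_ok=True)

    results = {"client_pred": np.random.rand(100, 1),
               "domain_embds": np.random.rand(500, 128)}
    joblib.dump(results, model_path + "/results.joblib", compress=3)
    loaded = joblib.load(model_path + "/results.joblib")
    print(loaded["client_pred"].shape)     # (100, 1)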

main.py

@@ -1,13 +1,12 @@
 import json
 import logging
 import os
 
-import joblib
 import numpy as np
 import pandas as pd
 import tensorflow as tf
-from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
-from keras.models import load_model, Model
+from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
+from keras.models import Model, load_model
 
 import arguments
 import dataset
@@ -15,9 +14,8 @@ import hyperband
 import models
 # create logger
 import visualize
-from dataset import load_or_generate_h5data
-from utils import exists_or_make_path, get_custom_class_weights
 from arguments import get_model_args
+from utils import exists_or_make_path, get_custom_class_weights
 
 logger = logging.getLogger('logger')
 logger.setLevel(logging.DEBUG)
@@ -115,8 +113,9 @@ def main_hyperband():
     }
 
     logger.info("create training dataset")
-    domain_tr, flow_tr, name_tr, client_tr, server_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
-                                                                                args.domain_length, args.window)
+    domain_tr, flow_tr, name_tr, client_tr, server_tr = dataset.load_or_generate_h5data(args.train_h5data,
+                                                                                        args.train_data,
+                                                                                        args.domain_length, args.window)
     hp = hyperband.Hyperband(params,
                              [domain_tr, flow_tr],
                              [client_tr, server_tr])
@@ -129,10 +128,10 @@ def main_train(param=None):
     exists_or_make_path(args.model_path)
     logger.info(f"Use command line arguments: {args}")
 
-    domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = load_or_generate_h5data(args.train_h5data,
-                                                                                        args.train_data,
-                                                                                        args.domain_length,
-                                                                                        args.window)
+    domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.train_h5data,
+                                                                                                args.train_data,
+                                                                                                args.domain_length,
+                                                                                                args.window)
     logger.info("define callbacks")
     callbacks = []
     callbacks.append(ModelCheckpoint(filepath=args.clf_model,
@@ -245,12 +244,12 @@ def main_train(param=None):
 
 def main_test():
     logger.info("start test: load data")
-    domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data,
-                                                                                     args.test_data,
-                                                                                     args.domain_length,
-                                                                                     args.window)
-    domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
+    domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.test_h5data,
+                                                                           args.test_data,
+                                                                           args.domain_length,
+                                                                           args.window)
+    domain_encs, _ = dataset.load_or_generate_domains(args.test_data, args.domain_length)
 
     for model_args in get_model_args(args):
         results = {}
         logger.info(f"process model {model_args['model_path']}")
@@ -268,73 +267,67 @@ def main_test():
             results["client_pred"] = pred
         else:
             results["server_pred"] = pred
-        # dataset.save_predictions(model_args["future_prediction"], c_pred, s_pred)
 
         embd_model = load_model(model_args["embedding_model"])
         domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
-        # np.save(model_args["model_path"] + "/domain_embds.npy", domain_embeddings)
         results["domain_embds"] = domain_embeddings
-        joblib.dump(results, model_args["model_path"] + "/results.joblib", compress=3)
+        dataset.save_predictions(model_args["model_path"], results)
 
 
 def main_visualization():
-    domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data,
-                                                                                     args.test_data,
-                                                                                     args.domain_length,
-                                                                                     args.window)
+    domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
+                                                                                                 args.test_data,
+                                                                                                 args.domain_length,
+                                                                                                 args.window)
+    # client_val, server_val = client_val.value, server_val.value
     client_val = client_val.value
 
     logger.info("plot model")
     model = load_model(args.clf_model, custom_objects=models.get_metrics())
     visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
 
-    try:
-        logger.info("plot training curve")
-        logs = pd.read_csv(args.train_log)
-        if args.model_output == "client":
-            visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
-        else:
-            visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
-            visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
-    except Exception as e:
-        logger.warning(f"could not generate training curves: {e}")
+    logger.info("plot training curve")
+    logs = pd.read_csv(args.train_log)
+    if "acc" in logs.keys():
+        visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
+    elif "client_acc" in logs.keys() and "server_acc" in logs.keys():
+        visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
+        visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
+    else:
+        logger.warning("Error while plotting training curves")
 
-    client_pred, server_pred = dataset.load_predictions(args.future_prediction)
-    client_pred, server_pred = client_pred.value.flatten(), server_pred.value.flatten()
+    results = dataset.load_predictions(args.future_prediction)
+    client_pred = results["client_pred"].flatten()
 
     logger.info("plot pr curve")
     visualize.plot_clf()
     visualize.plot_precision_recall(client_val, client_pred, args.model_path)
     visualize.plot_legend()
     visualize.plot_save("{}/window_client_prc.png".format(args.model_path))
-    # visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path))
-    # visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path))
-    # visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path))
 
     logger.info("plot roc curve")
     visualize.plot_clf()
     visualize.plot_roc_curve(client_val, client_pred, args.model_path)
     visualize.plot_legend()
     visualize.plot_save("{}/window_client_roc.png".format(args.model_path))
-    # visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path))
 
     print(f"names {name_val.shape} vals {client_val.shape} preds {client_pred.shape}")
 
     df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
     user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
     df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
     user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
 
     visualize.plot_clf()
     visualize.plot_precision_recall(user_vals, user_preds, args.model_path)
     visualize.plot_legend()
     visualize.plot_save("{}/user_client_prc.png".format(args.model_path))
 
     visualize.plot_clf()
     visualize.plot_roc_curve(user_vals, user_preds, args.model_path)
     visualize.plot_legend()
     visualize.plot_save("{}/user_client_roc.png".format(args.model_path))
 
     visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(),
                                     "{}/client_cov.png".format(args.model_path),
                                     normalize=False, title="Client Confusion Matrix")
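The per-user aggregation here reduces window-level values to one score per user by taking the group maximum. A self-contained sketch of that step with made-up names and scores (.values stands in for the deprecated .as_matrix() used in the file):

    import numpy as np
    import pandas as pd

    name_val = np.array(["u1", "u1", "u2", "u2"])      # one row per window
    client_pred = np.array([0.1, 0.9, 0.2, 0.4])

    df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
    user_preds = df_pred.groupby(df_pred.names).max().client_val.values.astype(float)
    print(user_preds)                                  # [0.9 0.4]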
@@ -348,44 +341,48 @@ def main_visualization():
 
 def main_visualize_all():
-    domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data,
-                                                                                     args.test_data,
-                                                                                     args.domain_length,
-                                                                                     args.window)
+    domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
+                                                                                                 args.test_data,
+                                                                                                 args.domain_length,
+                                                                                                 args.window)
     logger.info("plot pr curves")
     visualize.plot_clf()
     for model_args in get_model_args(args):
-        client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
-        visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"])
+        results = dataset.load_predictions(model_args["future_prediction"])
+        client_pred = results["client_pred"].flatten()
+        visualize.plot_precision_recall(client_val, client_pred, model_args["model_path"])
     visualize.plot_legend()
     visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")
 
     logger.info("plot roc curves")
     visualize.plot_clf()
     for model_args in get_model_args(args):
-        client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
-        visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"])
+        results = dataset.load_predictions(model_args["future_prediction"])
+        client_pred = results["client_pred"].flatten()
+        visualize.plot_roc_curve(client_val, client_pred, model_args["model_path"])
     visualize.plot_legend()
     visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")
 
     df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
     user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
 
     logger.info("plot user pr curves")
     visualize.plot_clf()
     for model_args in get_model_args(args):
-        client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
-        df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()})
+        results = dataset.load_predictions(model_args["future_prediction"])
+        client_pred = results["client_pred"].flatten()
+        df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
         user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
         visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"])
     visualize.plot_legend()
     visualize.plot_save(f"{args.output_prefix}_user_client_prc.png")
 
     logger.info("plot user roc curves")
     visualize.plot_clf()
     for model_args in get_model_args(args):
-        client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
-        df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()})
+        results = dataset.load_predictions(model_args["future_prediction"])
+        client_pred = results["client_pred"].flatten()
+        df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
         user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
         visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"])
     visualize.plot_legend()