diff --git a/Makefile b/Makefile index 587c8ef..1c3842b 100644 --- a/Makefile +++ b/Makefile @@ -1,38 +1,38 @@ run: - python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_1 --epochs 2 --depth small \ - --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ + python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_1 --epochs 2 --depth flat1 \ + --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both - python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_2 --epochs 2 --depth small \ - --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ + python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_2 --epochs 2 --depth flat1 \ + --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both - python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_3 --epochs 2 --depth medium \ - --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ + python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_3 --epochs 2 --depth deep1 \ + --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both - python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_4 --epochs 2 --depth medium \ - --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ + python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_4 --epochs 2 --depth deep1 \ + --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both - python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_5 --epochs 2 --depth small \ - --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ + python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_5 --epochs 2 --depth flat2 \ + --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type staggered --model_output both - python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_1 --epochs 2 --depth small \ - --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ + python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_1 --epochs 2 --depth flat2 \ + --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output client - python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_2 --epochs 2 --depth small \ - --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ - --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output client + python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_2 --epochs 2 --depth flat2 \ + --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ + --dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output client - python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_3 --epochs 2 --depth medium \ - --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ - --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output client + python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_3 --epochs 2 --depth deep1 \ + --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ + --dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client - python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_4 --epochs 2 --depth medium \ - --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ + python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_4 --epochs 2 --depth deep1 \ + --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output client test: diff --git a/arguments.py b/arguments.py index 5d16048..51d0712 100644 --- a/arguments.py +++ b/arguments.py @@ -19,6 +19,12 @@ parser.add_argument("--test", action="store", dest="test_data", parser.add_argument("--model", action="store", dest="model_path", default="results/model_x") +parser.add_argument("--model_src", action="store", dest="model_source", + default="results/model_x") + +parser.add_argument("--model_dest", action="store", dest="model_destination", + default="results/model_x") + parser.add_argument("--models", action="store", dest="model_paths", nargs="+", default=[]) @@ -37,6 +43,9 @@ parser.add_argument("--batch", action="store", dest="batch_size", parser.add_argument("--epochs", action="store", dest="epochs", default=10, type=int) +parser.add_argument("--init_epoch", action="store", dest="initial_epoch", + default=0, type=int) + # parser.add_argument("--samples", action="store", dest="samples", # default=100000, type=int) # @@ -98,7 +107,6 @@ parser.add_argument("--gpu", action="store_true", dest="gpu") parser.add_argument("--new_model", action="store_true", dest="new_model") - def get_model_args(args): return [{ "model_path": model_path, @@ -111,6 +119,7 @@ def get_model_args(args): "future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred") } for model_path in args.model_paths] + def parse(): args = parser.parse_args() args.result_path = os.path.split(os.path.normpath(args.output_prefix))[1] diff --git a/main.py b/main.py index e5588da..fbbbfab 100644 --- a/main.py +++ b/main.py @@ -5,8 +5,8 @@ import os import numpy as np import pandas as pd import tensorflow as tf -from keras.callbacks import CSVLogger, EarlyStopping, LambdaCallback, ModelCheckpoint -from keras.models import Model, load_model +from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint +from keras.models import Model, load_model as load_keras_model import arguments import dataset @@ -86,6 +86,12 @@ def create_model(model, output_type): raise Exception("unknown model output") +def load_model(path, custom_objects=None): + clf = load_keras_model(path, custom_objects) + embd = clf.get_layer("domain_cnn").layer + return embd, clf + + def main_paul_best(): pauls_best_params = models.pauls_networks.best_config main_train(pauls_best_params) @@ -161,10 +167,6 @@ def main_train(param=None): logger.info(f"Generator model with params: {param}") embedding, model, new_model = models.get_models_by_params(param) - callbacks.append(LambdaCallback( - on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model)) - ) - model = create_model(model, args.model_output) new_model = create_model(new_model, args.model_output) @@ -222,6 +224,67 @@ def main_train(param=None): class_weight=custom_class_weights) +def main_retrain(): + source = os.path.join(args.model_source, "clf.h5") + destination = os.path.join(args.model_destination, "clf.h5") + + logger.info(f"Use command line arguments: {args}") + exists_or_make_path(destination) + + domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.train_h5data, + args.train_data, + args.domain_length, + args.window) + logger.info("define callbacks") + callbacks = [] + callbacks.append(ModelCheckpoint(filepath=destination, + monitor='loss', + verbose=False, + save_best_only=True)) + callbacks.append(CSVLogger(args.train_log)) + logger.info(f"Use early stopping: {args.stop_early}") + if args.stop_early: + callbacks.append(EarlyStopping(monitor='val_loss', + patience=5, + verbose=False)) + + server_tr = np.max(server_windows_tr, axis=1) + + if args.class_weights: + logger.info("class weights: compute custom weights") + custom_class_weights = get_custom_class_weights(client_tr.value, server_tr) + logger.info(custom_class_weights) + else: + logger.info("class weights: set default") + custom_class_weights = None + + logger.info(f"Load pretrained model") + embedding, model = load_model(source, custom_objects=models.get_metrics()) + + if args.model_type in ("inter", "staggered"): + server_tr = np.expand_dims(server_windows_tr, 2) + + features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value} + if args.model_output == "both": + labels = {"client": client_tr.value, "server": server_tr} + elif args.model_output == "client": + labels = {"client": client_tr.value} + elif args.model_output == "server": + labels = {"server": server_tr} + else: + raise ValueError("unknown model output") + + logger.info("re-train model") + embedding.summary() + model.summary() + model.fit(features, labels, + batch_size=args.batch_size, + epochs=args.epochs, + callbacks=callbacks, + class_weight=custom_class_weights, + initial_epoch=args.initial_epoch) + + def main_test(): logger.info("start test: load data") domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.test_h5data, @@ -233,7 +296,7 @@ def main_test(): for model_args in get_model_args(args): results = {} logger.info(f"process model {model_args['model_path']}") - clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics()) + embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics()) pred = clf_model.predict([domain_val, flow_val], batch_size=args.batch_size, @@ -248,7 +311,6 @@ def main_test(): else: results["server_pred"] = pred - embd_model = load_model(model_args["embedding_model"], custom_objects=models.get_metrics()) domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1) results["domain_embds"] = domain_embeddings @@ -278,7 +340,7 @@ def main_visualization(): df_paul_user = df_paul.groupby(df_paul.names).max() logger.info("plot model") - model = load_model(args.clf_model, custom_objects=models.get_metrics()) + embd, model = load_model(args.clf_model, custom_objects=models.get_metrics()) visualize.plot_model_as(model, os.path.join(args.model_path, "model.png")) # logger.info("plot training curve") @@ -491,6 +553,16 @@ def main_beta(): visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.png") joblib.dump(results, f"{path}/curves.joblib") + + plot_overall_result() + + +def plot_overall_result(): + path, model_prefix = os.path.split(os.path.normpath(args.output_prefix)) + try: + results = joblib.load(f"{path}/curves.joblib") + except Exception: + results = {} import matplotlib.pyplot as plt x = np.linspace(0, 1, 10000) @@ -500,7 +572,7 @@ def main_beta(): for model_key in results.keys(): ys_mean, ys_std, score = results[model_key][vis] plt.plot(x, ys_mean, label=f"{model_key} - {score:5.4}") - plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1) + plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, alpha=0.2) if vis.endswith("prc"): plt.xlabel('Recall') plt.ylabel('Precision') @@ -516,6 +588,8 @@ def main_beta(): def main(): if "train" == args.mode: main_train() + if "retrain" == args.mode: + main_retrain() if "hyperband" == args.mode: main_hyperband() if "test" == args.mode: @@ -530,6 +604,8 @@ def main(): main_paul_best() if "beta" == args.mode: main_beta() + if "beta_all" == args.mode: + plot_overall_result() if __name__ == "__main__": diff --git a/visualize.py b/visualize.py index 48c833f..e0e140b 100644 --- a/visualize.py +++ b/visualize.py @@ -84,14 +84,17 @@ def calc_pr_mean(y, y_preds): return ys_mean, ys_std, scores_mean +def plot_mean_curve(x, ys, std, score, label): + plt.plot(x, ys, label=f"{label} - {score:5.4}") + plt.fill_between(x, ys - std, ys + std, alpha=0.1) + plt.ylim([0.0, 1.0]) + plt.xlim([0.0, 1.0]) + + def plot_pr_mean(y, y_preds, label=""): x = np.linspace(0, 1, 10000) ys_mean, ys_std, score = calc_pr_mean(y, y_preds) - - plt.plot(x, ys_mean, label=f"{label} - {score:5.4}") - plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1) - plt.ylim([0.0, 1.0]) - plt.xlim([0.0, 1.0]) + plot_mean_curve(x, ys_mean, ys_std, score, label) plt.xlabel('Recall') plt.ylabel('Precision') @@ -142,13 +145,9 @@ def calc_roc_mean(y, y_preds): def plot_roc_mean(y, y_preds, label=""): x = np.linspace(0, 1, 10000) - ys_mean, ys_std, auc_mean = calc_roc_mean(y, y_preds) + ys_mean, ys_std, score = calc_roc_mean(y, y_preds) plt.xscale('log') - plt.ylim([0.0, 1.0]) - plt.xlim([0.0, 1.0]) - - plt.plot(x, ys_mean, label=f"{label} - {auc_mean:5.4}") - plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1) + plot_mean_curve(x, ys_mean, ys_std, score, label) plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate')