add retrain mode

This commit is contained in:
René Knaebel 2017-09-28 12:23:22 +02:00
parent b157ca6a19
commit 090c89a127
4 changed files with 126 additions and 42 deletions

View File

@ -1,38 +1,38 @@
run: run:
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_1 --epochs 2 --depth small \ python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_1 --epochs 2 --depth flat1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_2 --epochs 2 --depth small \ python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_2 --epochs 2 --depth flat1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_3 --epochs 2 --depth medium \ python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_3 --epochs 2 --depth deep1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_4 --epochs 2 --depth medium \ python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_4 --epochs 2 --depth deep1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_5 --epochs 2 --depth small \ python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_5 --epochs 2 --depth flat2 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type staggered --model_output both --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type staggered --model_output both
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_1 --epochs 2 --depth small \ python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_1 --epochs 2 --depth flat2 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output client --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output client
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_2 --epochs 2 --depth small \ python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_2 --epochs 2 --depth flat2 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output client --dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output client
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_3 --epochs 2 --depth medium \ python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_3 --epochs 2 --depth deep1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output client --dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_4 --epochs 2 --depth medium \ python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_4 --epochs 2 --depth deep1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \ --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output client --dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output client
test: test:

View File

@ -19,6 +19,12 @@ parser.add_argument("--test", action="store", dest="test_data",
parser.add_argument("--model", action="store", dest="model_path", parser.add_argument("--model", action="store", dest="model_path",
default="results/model_x") default="results/model_x")
# Retrain mode: directory holding the pretrained classifier; main_retrain
# loads "clf.h5" from this directory.
parser.add_argument("--model_src", action="store", dest="model_source",
default="results/model_x")
# Retrain mode: directory that receives the retrained classifier checkpoint
# ("clf.h5").
# NOTE(review): shares its default with --model_src, so when both are left at
# their defaults the checkpoint is written over the file that was loaded —
# confirm this is intended.
parser.add_argument("--model_dest", action="store", dest="model_destination",
default="results/model_x")
parser.add_argument("--models", action="store", dest="model_paths", nargs="+", parser.add_argument("--models", action="store", dest="model_paths", nargs="+",
default=[]) default=[])
@ -37,6 +43,9 @@ parser.add_argument("--batch", action="store", dest="batch_size",
parser.add_argument("--epochs", action="store", dest="epochs", parser.add_argument("--epochs", action="store", dest="epochs",
default=10, type=int) default=10, type=int)
# Retrain mode: forwarded by main_retrain to model.fit as `initial_epoch`
# (defaults to 0).
parser.add_argument("--init_epoch", action="store", dest="initial_epoch",
default=0, type=int)
# parser.add_argument("--samples", action="store", dest="samples", # parser.add_argument("--samples", action="store", dest="samples",
# default=100000, type=int) # default=100000, type=int)
# #
@ -98,7 +107,6 @@ parser.add_argument("--gpu", action="store_true", dest="gpu")
parser.add_argument("--new_model", action="store_true", dest="new_model") parser.add_argument("--new_model", action="store_true", dest="new_model")
def get_model_args(args): def get_model_args(args):
return [{ return [{
"model_path": model_path, "model_path": model_path,
@ -111,6 +119,7 @@ def get_model_args(args):
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred") "future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred")
} for model_path in args.model_paths] } for model_path in args.model_paths]
def parse(): def parse():
args = parser.parse_args() args = parser.parse_args()
args.result_path = os.path.split(os.path.normpath(args.output_prefix))[1] args.result_path = os.path.split(os.path.normpath(args.output_prefix))[1]

96
main.py
View File

@ -5,8 +5,8 @@ import os
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import tensorflow as tf import tensorflow as tf
from keras.callbacks import CSVLogger, EarlyStopping, LambdaCallback, ModelCheckpoint from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
from keras.models import Model, load_model from keras.models import Model, load_model as load_keras_model
import arguments import arguments
import dataset import dataset
@ -86,6 +86,12 @@ def create_model(model, output_type):
raise Exception("unknown model output") raise Exception("unknown model output")
def load_model(path, custom_objects=None):
    """Load a saved classifier and hand back its embedding sub-model with it.

    Args:
        path: Location of the saved Keras model file.
        custom_objects: Optional mapping passed through to the Keras loader
            so custom metrics/layers can be resolved.

    Returns:
        A ``(embedding, classifier)`` tuple. ``embedding`` is the ``layer``
        attribute of the classifier layer named ``"domain_cnn"``.
    """
    classifier = load_keras_model(path, custom_objects)
    # NOTE(review): presumably "domain_cnn" is a wrapper layer (it must expose
    # `.layer`) around the domain embedding network — confirm in the model
    # builders.
    domain_wrapper = classifier.get_layer("domain_cnn")
    return domain_wrapper.layer, classifier
def main_paul_best(): def main_paul_best():
pauls_best_params = models.pauls_networks.best_config pauls_best_params = models.pauls_networks.best_config
main_train(pauls_best_params) main_train(pauls_best_params)
@ -161,10 +167,6 @@ def main_train(param=None):
logger.info(f"Generator model with params: {param}") logger.info(f"Generator model with params: {param}")
embedding, model, new_model = models.get_models_by_params(param) embedding, model, new_model = models.get_models_by_params(param)
callbacks.append(LambdaCallback(
on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
)
model = create_model(model, args.model_output) model = create_model(model, args.model_output)
new_model = create_model(new_model, args.model_output) new_model = create_model(new_model, args.model_output)
@ -222,6 +224,67 @@ def main_train(param=None):
class_weight=custom_class_weights) class_weight=custom_class_weights)
def main_retrain():
    """Continue training a previously saved classifier.

    Loads ``clf.h5`` from ``args.model_source``, fits it again on the training
    data and checkpoints the best model (lowest training loss) to ``clf.h5``
    under ``args.model_destination``. Training starts counting epochs at
    ``args.initial_epoch``.

    Raises:
        ValueError: If ``args.model_output`` is not one of ``"both"``,
            ``"client"`` or ``"server"``.
    """
    source = os.path.join(args.model_source, "clf.h5")
    destination = os.path.join(args.model_destination, "clf.h5")
    logger.info(f"Use command line arguments: {args}")
    # NOTE(review): the helper receives the checkpoint *file* path — confirm
    # it creates the parent directory rather than a directory named clf.h5.
    exists_or_make_path(destination)

    # The third element of the returned tuple (names) is not needed here.
    domain_tr, flow_tr, _name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(
        args.train_h5data,
        args.train_data,
        args.domain_length,
        args.window)

    logger.info("define callbacks")
    callbacks = [
        ModelCheckpoint(filepath=destination,
                        monitor='loss',
                        verbose=False,
                        save_best_only=True),
        CSVLogger(args.train_log),
    ]
    logger.info(f"Use early stopping: {args.stop_early}")
    if args.stop_early:
        # fit() below is called without validation data, so 'val_loss' never
        # shows up in the epoch logs; monitor the training loss instead, the
        # same metric the checkpoint callback tracks.
        callbacks.append(EarlyStopping(monitor='loss',
                                       patience=5,
                                       verbose=False))

    # Reduce the per-window server labels with a max over axis 1.
    server_tr = np.max(server_windows_tr, axis=1)

    if args.class_weights:
        logger.info("class weights: compute custom weights")
        custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
        logger.info(custom_class_weights)
    else:
        logger.info("class weights: set default")
        custom_class_weights = None

    logger.info("Load pretrained model")
    embedding, model = load_model(source, custom_objects=models.get_metrics())

    # These model types are fed the windowed server labels with an extra
    # trailing axis instead of the reduced ones. The class weights above were
    # computed from the reduced labels.
    if args.model_type in ("inter", "staggered"):
        server_tr = np.expand_dims(server_windows_tr, 2)

    features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
    if args.model_output == "both":
        labels = {"client": client_tr.value, "server": server_tr}
    elif args.model_output == "client":
        labels = {"client": client_tr.value}
    elif args.model_output == "server":
        labels = {"server": server_tr}
    else:
        raise ValueError("unknown model output")

    logger.info("re-train model")
    embedding.summary()
    model.summary()
    model.fit(features, labels,
              batch_size=args.batch_size,
              epochs=args.epochs,
              callbacks=callbacks,
              class_weight=custom_class_weights,
              initial_epoch=args.initial_epoch)
def main_test(): def main_test():
logger.info("start test: load data") logger.info("start test: load data")
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.test_h5data, domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.test_h5data,
@ -233,7 +296,7 @@ def main_test():
for model_args in get_model_args(args): for model_args in get_model_args(args):
results = {} results = {}
logger.info(f"process model {model_args['model_path']}") logger.info(f"process model {model_args['model_path']}")
clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics()) embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics())
pred = clf_model.predict([domain_val, flow_val], pred = clf_model.predict([domain_val, flow_val],
batch_size=args.batch_size, batch_size=args.batch_size,
@ -248,7 +311,6 @@ def main_test():
else: else:
results["server_pred"] = pred results["server_pred"] = pred
embd_model = load_model(model_args["embedding_model"], custom_objects=models.get_metrics())
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1) domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
results["domain_embds"] = domain_embeddings results["domain_embds"] = domain_embeddings
@ -278,7 +340,7 @@ def main_visualization():
df_paul_user = df_paul.groupby(df_paul.names).max() df_paul_user = df_paul.groupby(df_paul.names).max()
logger.info("plot model") logger.info("plot model")
model = load_model(args.clf_model, custom_objects=models.get_metrics()) embd, model = load_model(args.clf_model, custom_objects=models.get_metrics())
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png")) visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
# logger.info("plot training curve") # logger.info("plot training curve")
@ -492,6 +554,16 @@ def main_beta():
joblib.dump(results, f"{path}/curves.joblib") joblib.dump(results, f"{path}/curves.joblib")
plot_overall_result()
def plot_overall_result():
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
try:
results = joblib.load(f"{path}/curves.joblib")
except Exception:
results = {}
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
x = np.linspace(0, 1, 10000) x = np.linspace(0, 1, 10000)
for vis in ["window_prc", "window_roc", "user_prc", "user_roc"]: for vis in ["window_prc", "window_roc", "user_prc", "user_roc"]:
@ -500,7 +572,7 @@ def main_beta():
for model_key in results.keys(): for model_key in results.keys():
ys_mean, ys_std, score = results[model_key][vis] ys_mean, ys_std, score = results[model_key][vis]
plt.plot(x, ys_mean, label=f"{model_key} - {score:5.4}") plt.plot(x, ys_mean, label=f"{model_key} - {score:5.4}")
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1) plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, alpha=0.2)
if vis.endswith("prc"): if vis.endswith("prc"):
plt.xlabel('Recall') plt.xlabel('Recall')
plt.ylabel('Precision') plt.ylabel('Precision')
@ -516,6 +588,8 @@ def main_beta():
def main(): def main():
if "train" == args.mode: if "train" == args.mode:
main_train() main_train()
if "retrain" == args.mode:
main_retrain()
if "hyperband" == args.mode: if "hyperband" == args.mode:
main_hyperband() main_hyperband()
if "test" == args.mode: if "test" == args.mode:
@ -530,6 +604,8 @@ def main():
main_paul_best() main_paul_best()
if "beta" == args.mode: if "beta" == args.mode:
main_beta() main_beta()
if "beta_all" == args.mode:
plot_overall_result()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -84,14 +84,17 @@ def calc_pr_mean(y, y_preds):
return ys_mean, ys_std, scores_mean return ys_mean, ys_std, scores_mean
def plot_mean_curve(x, ys, std, score, label):
    """Plot a curve with a shaded band around it via pyplot.

    Args:
        x: X coordinates shared by the curve and the band.
        ys: Y values of the curve.
        std: Offset subtracted from / added to ``ys`` to form the band edges.
        score: Number appended to the legend entry (formatted as ``5.4``).
        label: Legend prefix for the curve.

    Both axes are clamped to the range [0, 1].
    """
    legend_entry = f"{label} - {score:5.4}"
    plt.plot(x, ys, label=legend_entry)

    band_low = ys - std
    band_high = ys + std
    plt.fill_between(x, band_low, band_high, alpha=0.1)

    unit_range = [0.0, 1.0]
    plt.ylim(list(unit_range))
    plt.xlim(list(unit_range))
def plot_pr_mean(y, y_preds, label=""): def plot_pr_mean(y, y_preds, label=""):
x = np.linspace(0, 1, 10000) x = np.linspace(0, 1, 10000)
ys_mean, ys_std, score = calc_pr_mean(y, y_preds) ys_mean, ys_std, score = calc_pr_mean(y, y_preds)
plot_mean_curve(x, ys_mean, ys_std, score, label)
plt.plot(x, ys_mean, label=f"{label} - {score:5.4}")
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
plt.ylim([0.0, 1.0])
plt.xlim([0.0, 1.0])
plt.xlabel('Recall') plt.xlabel('Recall')
plt.ylabel('Precision') plt.ylabel('Precision')
@ -142,13 +145,9 @@ def calc_roc_mean(y, y_preds):
def plot_roc_mean(y, y_preds, label=""): def plot_roc_mean(y, y_preds, label=""):
x = np.linspace(0, 1, 10000) x = np.linspace(0, 1, 10000)
ys_mean, ys_std, auc_mean = calc_roc_mean(y, y_preds) ys_mean, ys_std, score = calc_roc_mean(y, y_preds)
plt.xscale('log') plt.xscale('log')
plt.ylim([0.0, 1.0]) plot_mean_curve(x, ys_mean, ys_std, score, label)
plt.xlim([0.0, 1.0])
plt.plot(x, ys_mean, label=f"{label} - {auc_mean:5.4}")
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
plt.xlabel('False Positive Rate') plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate') plt.ylabel('True Positive Rate')