add retrain mode

This commit is contained in:
René Knaebel 2017-09-28 12:23:22 +02:00
parent b157ca6a19
commit 090c89a127
4 changed files with 126 additions and 42 deletions

View File

@ -1,38 +1,38 @@
run:
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_1 --epochs 2 --depth small \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_1 --epochs 2 --depth flat1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_2 --epochs 2 --depth small \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_2 --epochs 2 --depth flat1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_3 --epochs 2 --depth medium \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_3 --epochs 2 --depth deep1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_4 --epochs 2 --depth medium \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_4 --epochs 2 --depth deep1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_5 --epochs 2 --depth small \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_5 --epochs 2 --depth flat2 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type staggered --model_output both
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_1 --epochs 2 --depth small \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_1 --epochs 2 --depth flat2 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output client
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_2 --epochs 2 --depth small \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output client
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_2 --epochs 2 --depth flat2 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output client
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_3 --epochs 2 --depth medium \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output client
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_3 --epochs 2 --depth deep1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_4 --epochs 2 --depth medium \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_4 --epochs 2 --depth deep1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output client
test:

View File

@ -19,6 +19,12 @@ parser.add_argument("--test", action="store", dest="test_data",
parser.add_argument("--model", action="store", dest="model_path",
default="results/model_x")
parser.add_argument("--model_src", action="store", dest="model_source",
default="results/model_x")
parser.add_argument("--model_dest", action="store", dest="model_destination",
default="results/model_x")
parser.add_argument("--models", action="store", dest="model_paths", nargs="+",
default=[])
@ -37,6 +43,9 @@ parser.add_argument("--batch", action="store", dest="batch_size",
parser.add_argument("--epochs", action="store", dest="epochs",
default=10, type=int)
parser.add_argument("--init_epoch", action="store", dest="initial_epoch",
default=0, type=int)
# parser.add_argument("--samples", action="store", dest="samples",
# default=100000, type=int)
#
@ -98,7 +107,6 @@ parser.add_argument("--gpu", action="store_true", dest="gpu")
parser.add_argument("--new_model", action="store_true", dest="new_model")
def get_model_args(args):
return [{
"model_path": model_path,
@ -111,6 +119,7 @@ def get_model_args(args):
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred")
} for model_path in args.model_paths]
def parse():
args = parser.parse_args()
args.result_path = os.path.split(os.path.normpath(args.output_prefix))[1]

96
main.py
View File

@ -5,8 +5,8 @@ import os
import numpy as np
import pandas as pd
import tensorflow as tf
from keras.callbacks import CSVLogger, EarlyStopping, LambdaCallback, ModelCheckpoint
from keras.models import Model, load_model
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
from keras.models import Model, load_model as load_keras_model
import arguments
import dataset
@ -86,6 +86,12 @@ def create_model(model, output_type):
raise Exception("unknown model output")
def load_model(path, custom_objects=None):
clf = load_keras_model(path, custom_objects)
embd = clf.get_layer("domain_cnn").layer
return embd, clf
def main_paul_best():
pauls_best_params = models.pauls_networks.best_config
main_train(pauls_best_params)
@ -161,10 +167,6 @@ def main_train(param=None):
logger.info(f"Generator model with params: {param}")
embedding, model, new_model = models.get_models_by_params(param)
callbacks.append(LambdaCallback(
on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
)
model = create_model(model, args.model_output)
new_model = create_model(new_model, args.model_output)
@ -222,6 +224,67 @@ def main_train(param=None):
class_weight=custom_class_weights)
def main_retrain():
source = os.path.join(args.model_source, "clf.h5")
destination = os.path.join(args.model_destination, "clf.h5")
logger.info(f"Use command line arguments: {args}")
exists_or_make_path(destination)
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.train_h5data,
args.train_data,
args.domain_length,
args.window)
logger.info("define callbacks")
callbacks = []
callbacks.append(ModelCheckpoint(filepath=destination,
monitor='loss',
verbose=False,
save_best_only=True))
callbacks.append(CSVLogger(args.train_log))
logger.info(f"Use early stopping: {args.stop_early}")
if args.stop_early:
callbacks.append(EarlyStopping(monitor='val_loss',
patience=5,
verbose=False))
server_tr = np.max(server_windows_tr, axis=1)
if args.class_weights:
logger.info("class weights: compute custom weights")
custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
logger.info(custom_class_weights)
else:
logger.info("class weights: set default")
custom_class_weights = None
logger.info(f"Load pretrained model")
embedding, model = load_model(source, custom_objects=models.get_metrics())
if args.model_type in ("inter", "staggered"):
server_tr = np.expand_dims(server_windows_tr, 2)
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
if args.model_output == "both":
labels = {"client": client_tr.value, "server": server_tr}
elif args.model_output == "client":
labels = {"client": client_tr.value}
elif args.model_output == "server":
labels = {"server": server_tr}
else:
raise ValueError("unknown model output")
logger.info("re-train model")
embedding.summary()
model.summary()
model.fit(features, labels,
batch_size=args.batch_size,
epochs=args.epochs,
callbacks=callbacks,
class_weight=custom_class_weights,
initial_epoch=args.initial_epoch)
def main_test():
logger.info("start test: load data")
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.test_h5data,
@ -233,7 +296,7 @@ def main_test():
for model_args in get_model_args(args):
results = {}
logger.info(f"process model {model_args['model_path']}")
clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics())
embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics())
pred = clf_model.predict([domain_val, flow_val],
batch_size=args.batch_size,
@ -248,7 +311,6 @@ def main_test():
else:
results["server_pred"] = pred
embd_model = load_model(model_args["embedding_model"], custom_objects=models.get_metrics())
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
results["domain_embds"] = domain_embeddings
@ -278,7 +340,7 @@ def main_visualization():
df_paul_user = df_paul.groupby(df_paul.names).max()
logger.info("plot model")
model = load_model(args.clf_model, custom_objects=models.get_metrics())
embd, model = load_model(args.clf_model, custom_objects=models.get_metrics())
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
# logger.info("plot training curve")
@ -491,6 +553,16 @@ def main_beta():
visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.png")
joblib.dump(results, f"{path}/curves.joblib")
plot_overall_result()
def plot_overall_result():
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
try:
results = joblib.load(f"{path}/curves.joblib")
except Exception:
results = {}
import matplotlib.pyplot as plt
x = np.linspace(0, 1, 10000)
@ -500,7 +572,7 @@ def main_beta():
for model_key in results.keys():
ys_mean, ys_std, score = results[model_key][vis]
plt.plot(x, ys_mean, label=f"{model_key} - {score:5.4}")
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, alpha=0.2)
if vis.endswith("prc"):
plt.xlabel('Recall')
plt.ylabel('Precision')
@ -516,6 +588,8 @@ def main_beta():
def main():
if "train" == args.mode:
main_train()
if "retrain" == args.mode:
main_retrain()
if "hyperband" == args.mode:
main_hyperband()
if "test" == args.mode:
@ -530,6 +604,8 @@ def main():
main_paul_best()
if "beta" == args.mode:
main_beta()
if "beta_all" == args.mode:
plot_overall_result()
if __name__ == "__main__":

View File

@ -84,14 +84,17 @@ def calc_pr_mean(y, y_preds):
return ys_mean, ys_std, scores_mean
def plot_mean_curve(x, ys, std, score, label):
plt.plot(x, ys, label=f"{label} - {score:5.4}")
plt.fill_between(x, ys - std, ys + std, alpha=0.1)
plt.ylim([0.0, 1.0])
plt.xlim([0.0, 1.0])
def plot_pr_mean(y, y_preds, label=""):
x = np.linspace(0, 1, 10000)
ys_mean, ys_std, score = calc_pr_mean(y, y_preds)
plt.plot(x, ys_mean, label=f"{label} - {score:5.4}")
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
plt.ylim([0.0, 1.0])
plt.xlim([0.0, 1.0])
plot_mean_curve(x, ys_mean, ys_std, score, label)
plt.xlabel('Recall')
plt.ylabel('Precision')
@ -142,13 +145,9 @@ def calc_roc_mean(y, y_preds):
def plot_roc_mean(y, y_preds, label=""):
x = np.linspace(0, 1, 10000)
ys_mean, ys_std, auc_mean = calc_roc_mean(y, y_preds)
ys_mean, ys_std, score = calc_roc_mean(y, y_preds)
plt.xscale('log')
plt.ylim([0.0, 1.0])
plt.xlim([0.0, 1.0])
plt.plot(x, ys_mean, label=f"{label} - {auc_mean:5.4}")
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
plot_mean_curve(x, ys_mean, ys_std, score, label)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')