From e24f596f40c8adeb801d7680c87e6dc40a4675a0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20Knaebel?=
Date: Sun, 30 Jul 2017 14:07:39 +0200
Subject: [PATCH] add argument for using the new model architecture

---
 arguments.py |  1 +
 main.py      | 20 ++++++++++----------
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/arguments.py b/arguments.py
index b32432a..6559fab 100644
--- a/arguments.py
+++ b/arguments.py
@@ -63,6 +63,7 @@ parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
 parser.add_argument("--stop_early", action="store_true", dest="stop_early")
 parser.add_argument("--balanced_weights", action="store_true", dest="class_weights")
 parser.add_argument("--gpu", action="store_true", dest="gpu")
+parser.add_argument("--new_model", action="store_true", dest="new_model")
 
 
 
diff --git a/main.py b/main.py
index 37568ed..8ccb6e8 100644
--- a/main.py
+++ b/main.py
@@ -112,18 +112,18 @@ def main_hyperband():
     json.dump(results, open("hyperband.json"))
 
 
-def main_train(param=None, train_new_model=False):
+def main_train(param=None):
+    logger.info(f"Create model path {args.model_path}")
     exists_or_make_path(args.model_path)
+    logger.info(f"Use command line arguments: {args}")
 
     domain_tr, flow_tr, client_tr, server_windows_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
                                                                                args.domain_length, args.window)
 
     if not param:
         param = PARAMS
-
+    logger.info(f"Generator model with params: {param}")
     embedding, model, new_model = models.get_models_by_params(param)
-    embedding.summary()
-    model.summary()
     logger.info("define callbacks")
     callbacks = []
     callbacks.append(ModelCheckpoint(filepath=args.clf_model,
@@ -131,11 +131,11 @@ def main_train(param=None, train_new_model=False):
                                      verbose=False,
                                      save_best_only=True))
     callbacks.append(CSVLogger(args.train_log))
+    logger.info(f"Use early stopping: {args.stop_early}")
     if args.stop_early:
         callbacks.append(EarlyStopping(monitor='val_loss',
                                        patience=5,
                                        verbose=False))
-    logger.info("compile model")
     custom_metrics = models.get_metric_functions()
 
     server_tr = np.max(server_windows_tr, axis=1)
@@ -147,12 +147,14 @@ def main_train(param=None, train_new_model=False):
     else:
         logger.info("class weights: set default")
         custom_class_weights = None
-    logger.info("start training")
 
-    if train_new_model:
+    logger.info(f"select model: {'new' if args.new_model else 'old'}")
+    if args.new_model:
         server_tr = np.expand_dims(server_windows_tr, 2)
         model = new_model
-
+    logger.info("compile and train model")
+    embedding.summary()
+    model.summary()
     model.compile(optimizer='adam',
                   loss='binary_crossentropy',
                   metrics=['accuracy'] + custom_metrics)
@@ -271,8 +273,6 @@ def main():
         main_paul_best()
     if "data" in args.modes:
         main_data()
-    if "train_new" in args.modes:
-        main_train(train_new_model=True)
 
 
 if __name__ == "__main__":
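
Usage note (not part of the patch): after this change the new architecture is selected with the boolean --new_model flag instead of the removed "train_new" mode. The minimal, self-contained sketch below only illustrates the store_true/selection pattern; build_old_model and build_new_model are hypothetical stand-ins for the models returned by models.get_models_by_params, and only the flag name is taken from arguments.py.

    # Sketch of the flag-based model selection introduced above.
    import argparse

    def build_old_model():
        return "old architecture"

    def build_new_model():
        return "new architecture"

    parser = argparse.ArgumentParser()
    parser.add_argument("--new_model", action="store_true", dest="new_model")
    args = parser.parse_args(["--new_model"])  # example invocation with the flag set

    # pick the architecture exactly once, based on the flag
    model = build_new_model() if args.new_model else build_old_model()
    print(model)  # -> new architecture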