From 9ce11e4db460b87a14e3372f524be25ce67898bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Thu, 9 Nov 2017 13:12:50 +0100 Subject: [PATCH] refactor test function according to the new training procedure --- main.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index dd7b95a..430bf95 100644 --- a/main.py +++ b/main.py @@ -375,11 +375,13 @@ def main_test(): logger.info("start test: load data") domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window) domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length) - - for model_args in get_model_args(args): - results = {} - logger.info(f"process model {model_args['model_path']}") - embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_custom_objects()) + + results = {} + for model_path in args.model_paths: + file = get_dir(model_path)[1] + results[file] = {} + logger.info(f"process model {model_path}") + embd_model, clf_model = load_model(model_path, custom_objects=models.get_custom_objects()) pred = clf_model.predict([domain_val, flow_val], batch_size=args.batch_size, @@ -387,17 +389,17 @@ def main_test(): if args.model_output == "both": c_pred, s_pred = pred - results["client_pred"] = c_pred - results["server_pred"] = s_pred + results[file]["client_pred"] = c_pred + results[file]["server_pred"] = s_pred elif args.model_output == "client": - results["client_pred"] = pred + results[file]["client_pred"] = pred else: - results["server_pred"] = pred + results[file]["server_pred"] = pred domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1) results["domain_embds"] = domain_embeddings - dataset.save_predictions(model_args["model_path"], results) + dataset.save_predictions(get_dir(model_path)[0], results) def main_visualization():