diff --git a/main.py b/main.py index dd7b95a..430bf95 100644 --- a/main.py +++ b/main.py @@ -375,11 +375,13 @@ def main_test(): logger.info("start test: load data") domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window) domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length) - - for model_args in get_model_args(args): - results = {} - logger.info(f"process model {model_args['model_path']}") - embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_custom_objects()) + + results = {} + for model_path in args.model_paths: + file = get_dir(model_path)[1] + results[file] = {} + logger.info(f"process model {model_path}") + embd_model, clf_model = load_model(model_path, custom_objects=models.get_custom_objects()) pred = clf_model.predict([domain_val, flow_val], batch_size=args.batch_size, @@ -387,17 +389,17 @@ def main_test(): if args.model_output == "both": c_pred, s_pred = pred - results["client_pred"] = c_pred - results["server_pred"] = s_pred + results[file]["client_pred"] = c_pred + results[file]["server_pred"] = s_pred elif args.model_output == "client": - results["client_pred"] = pred + results[file]["client_pred"] = pred else: - results["server_pred"] = pred + results[file]["server_pred"] = pred domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1) results["domain_embds"] = domain_embeddings - dataset.save_predictions(model_args["model_path"], results) + dataset.save_predictions(get_dir(model_path)[0], results) def main_visualization():