refactor test function according to the new training procedure

This commit is contained in:
René Knaebel 2017-11-09 13:12:50 +01:00
parent 9b8ca8abab
commit 9ce11e4db4
1 changed files with 12 additions and 10 deletions

22
main.py
View File

@ -375,11 +375,13 @@ def main_test():
logger.info("start test: load data")
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
for model_args in get_model_args(args):
results = {}
logger.info(f"process model {model_args['model_path']}")
embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_custom_objects())
results = {}
for model_path in args.model_paths:
file = get_dir(model_path)[1]
results[file] = {}
logger.info(f"process model {model_path}")
embd_model, clf_model = load_model(model_path, custom_objects=models.get_custom_objects())
pred = clf_model.predict([domain_val, flow_val],
batch_size=args.batch_size,
@ -387,17 +389,17 @@ def main_test():
if args.model_output == "both":
c_pred, s_pred = pred
results["client_pred"] = c_pred
results["server_pred"] = s_pred
results[file]["client_pred"] = c_pred
results[file]["server_pred"] = s_pred
elif args.model_output == "client":
results["client_pred"] = pred
results[file]["client_pred"] = pred
else:
results["server_pred"] = pred
results[file]["server_pred"] = pred
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
results["domain_embds"] = domain_embeddings
dataset.save_predictions(model_args["model_path"], results)
dataset.save_predictions(get_dir(model_path)[0], results)
def main_visualization():