refactor test function according to the new training procedure

This commit is contained in:
René Knaebel 2017-11-09 13:12:50 +01:00
parent 9b8ca8abab
commit 9ce11e4db4

20
main.py
View File

@ -376,10 +376,12 @@ def main_test():
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window) domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length) domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
for model_args in get_model_args(args): results = {}
results = {} for model_path in args.model_paths:
logger.info(f"process model {model_args['model_path']}") file = get_dir(model_path)[1]
embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_custom_objects()) results[file] = {}
logger.info(f"process model {model_path}")
embd_model, clf_model = load_model(model_path, custom_objects=models.get_custom_objects())
pred = clf_model.predict([domain_val, flow_val], pred = clf_model.predict([domain_val, flow_val],
batch_size=args.batch_size, batch_size=args.batch_size,
@ -387,17 +389,17 @@ def main_test():
if args.model_output == "both": if args.model_output == "both":
c_pred, s_pred = pred c_pred, s_pred = pred
results["client_pred"] = c_pred results[file]["client_pred"] = c_pred
results["server_pred"] = s_pred results[file]["server_pred"] = s_pred
elif args.model_output == "client": elif args.model_output == "client":
results["client_pred"] = pred results[file]["client_pred"] = pred
else: else:
results["server_pred"] = pred results[file]["server_pred"] = pred
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1) domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
results["domain_embds"] = domain_embeddings results["domain_embds"] = domain_embeddings
dataset.save_predictions(model_args["model_path"], results) dataset.save_predictions(get_dir(model_path)[0], results)
def main_visualization(): def main_visualization():