refactor test function according to the new training procedure
This commit is contained in:
parent
9b8ca8abab
commit
9ce11e4db4
22
main.py
22
main.py
@ -375,11 +375,13 @@ def main_test():
|
||||
logger.info("start test: load data")
|
||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
|
||||
domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
||||
|
||||
for model_args in get_model_args(args):
|
||||
results = {}
|
||||
logger.info(f"process model {model_args['model_path']}")
|
||||
embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_custom_objects())
|
||||
|
||||
results = {}
|
||||
for model_path in args.model_paths:
|
||||
file = get_dir(model_path)[1]
|
||||
results[file] = {}
|
||||
logger.info(f"process model {model_path}")
|
||||
embd_model, clf_model = load_model(model_path, custom_objects=models.get_custom_objects())
|
||||
|
||||
pred = clf_model.predict([domain_val, flow_val],
|
||||
batch_size=args.batch_size,
|
||||
@ -387,17 +389,17 @@ def main_test():
|
||||
|
||||
if args.model_output == "both":
|
||||
c_pred, s_pred = pred
|
||||
results["client_pred"] = c_pred
|
||||
results["server_pred"] = s_pred
|
||||
results[file]["client_pred"] = c_pred
|
||||
results[file]["server_pred"] = s_pred
|
||||
elif args.model_output == "client":
|
||||
results["client_pred"] = pred
|
||||
results[file]["client_pred"] = pred
|
||||
else:
|
||||
results["server_pred"] = pred
|
||||
results[file]["server_pred"] = pred
|
||||
|
||||
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||
results["domain_embds"] = domain_embeddings
|
||||
|
||||
dataset.save_predictions(model_args["model_path"], results)
|
||||
dataset.save_predictions(get_dir(model_path)[0], results)
|
||||
|
||||
|
||||
def main_visualization():
|
||||
|
Loading…
Reference in New Issue
Block a user