refactor test function according to the new training procedure
This commit is contained in:
parent
9b8ca8abab
commit
9ce11e4db4
18
main.py
18
main.py
@ -376,10 +376,12 @@ def main_test():
|
|||||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
|
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
|
||||||
domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
||||||
|
|
||||||
for model_args in get_model_args(args):
|
|
||||||
results = {}
|
results = {}
|
||||||
logger.info(f"process model {model_args['model_path']}")
|
for model_path in args.model_paths:
|
||||||
embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_custom_objects())
|
file = get_dir(model_path)[1]
|
||||||
|
results[file] = {}
|
||||||
|
logger.info(f"process model {model_path}")
|
||||||
|
embd_model, clf_model = load_model(model_path, custom_objects=models.get_custom_objects())
|
||||||
|
|
||||||
pred = clf_model.predict([domain_val, flow_val],
|
pred = clf_model.predict([domain_val, flow_val],
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
@ -387,17 +389,17 @@ def main_test():
|
|||||||
|
|
||||||
if args.model_output == "both":
|
if args.model_output == "both":
|
||||||
c_pred, s_pred = pred
|
c_pred, s_pred = pred
|
||||||
results["client_pred"] = c_pred
|
results[file]["client_pred"] = c_pred
|
||||||
results["server_pred"] = s_pred
|
results[file]["server_pred"] = s_pred
|
||||||
elif args.model_output == "client":
|
elif args.model_output == "client":
|
||||||
results["client_pred"] = pred
|
results[file]["client_pred"] = pred
|
||||||
else:
|
else:
|
||||||
results["server_pred"] = pred
|
results[file]["server_pred"] = pred
|
||||||
|
|
||||||
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||||
results["domain_embds"] = domain_embeddings
|
results["domain_embds"] = domain_embeddings
|
||||||
|
|
||||||
dataset.save_predictions(model_args["model_path"], results)
|
dataset.save_predictions(get_dir(model_path)[0], results)
|
||||||
|
|
||||||
|
|
||||||
def main_visualization():
|
def main_visualization():
|
||||||
|
Loading…
Reference in New Issue
Block a user