From 9b8ca8abab864352c3d2b7a2511c7e76dbd2e9b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Wed, 8 Nov 2017 11:09:56 +0100 Subject: [PATCH] add parameter for hyper band iteration, use hyperband results in new runs --- arguments.py | 2 ++ main.py | 18 ++++++++++++------ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/arguments.py b/arguments.py index d24b37e..800f9ad 100644 --- a/arguments.py +++ b/arguments.py @@ -52,6 +52,8 @@ parser.add_argument("--init_epoch", action="store", dest="initial_epoch", parser.add_argument("--runs", action="store", dest="runs", default=20, type=int) +parser.add_argument("--hyper_max_iter", action="store", dest="hyper_max_iter", + default=81, type=int) # parser.add_argument("--samples", action="store", dest="samples", # default=100000, type=int) diff --git a/main.py b/main.py index e95c7cb..dd7b95a 100644 --- a/main.py +++ b/main.py @@ -155,7 +155,7 @@ def main_paul_best(): main_train(pauls_best_params) -def main_hyperband(data, domain_length, window_size, model_type, result_file, dist_size="small"): +def main_hyperband(data, domain_length, window_size, model_type, result_file, max_iter, dist_size="small"): param_dist = get_param_dist(dist_size) logger.info("create training dataset") @@ -167,8 +167,8 @@ def main_hyperband(data, domain_length, window_size, model_type, result_file, di server_tr = np.expand_dims(server_windows_tr, 2) domain_tr, flow_tr, client_tr, server_tr = shuffle_training_data(domain_tr, flow_tr, client_tr, server_tr) - - return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, 81, result_file) + + return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, max_iter, result_file) def run_hyperband(dist_size, domain, flow, client, server, max_iter, savefile): @@ -208,9 +208,14 @@ def main_train(param=None): # call hyperband if used if args.hyperband_results: - logger.info("start hyperband parameter search") - hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, 81, args.hyperband_results) + try: + hyper_results = joblib.load(args.hyperband_results) + except Exception: + logger.info("start hyperband parameter search") + hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, args.hyper_max_iter, + args.hyperband_results) param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]["params"] + param["type"] = args.model_type logger.info(f"select params from result: {param}") if not param: param = PARAMS @@ -815,7 +820,8 @@ def main(): if "retrain" == args.mode: main_retrain() if "hyperband" == args.mode: - main_hyperband(args.data, args.domain_length, args.window, args.model_type, args.hyperband_results) + main_hyperband(args.data, args.domain_length, args.window, args.model_type, args.hyperband_results, + arg.hyper_max_iter) if "test" == args.mode: main_test() if "fancy" == args.mode: