add parameter for hyper band iteration, use hyperband results in new runs

This commit is contained in:
René Knaebel 2017-11-08 11:09:56 +01:00
parent 903e81c931
commit 9b8ca8abab
2 changed files with 14 additions and 6 deletions

View File

@ -52,6 +52,8 @@ parser.add_argument("--init_epoch", action="store", dest="initial_epoch",
parser.add_argument("--runs", action="store", dest="runs",
default=20, type=int)
parser.add_argument("--hyper_max_iter", action="store", dest="hyper_max_iter",
default=81, type=int)
# parser.add_argument("--samples", action="store", dest="samples",
# default=100000, type=int)

18
main.py
View File

@ -155,7 +155,7 @@ def main_paul_best():
main_train(pauls_best_params)
def main_hyperband(data, domain_length, window_size, model_type, result_file, dist_size="small"):
def main_hyperband(data, domain_length, window_size, model_type, result_file, max_iter, dist_size="small"):
param_dist = get_param_dist(dist_size)
logger.info("create training dataset")
@ -167,8 +167,8 @@ def main_hyperband(data, domain_length, window_size, model_type, result_file, di
server_tr = np.expand_dims(server_windows_tr, 2)
domain_tr, flow_tr, client_tr, server_tr = shuffle_training_data(domain_tr, flow_tr, client_tr, server_tr)
return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, 81, result_file)
return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, max_iter, result_file)
def run_hyperband(dist_size, domain, flow, client, server, max_iter, savefile):
@ -208,9 +208,14 @@ def main_train(param=None):
# call hyperband if used
if args.hyperband_results:
logger.info("start hyperband parameter search")
hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, 81, args.hyperband_results)
try:
hyper_results = joblib.load(args.hyperband_results)
except Exception:
logger.info("start hyperband parameter search")
hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, args.hyper_max_iter,
args.hyperband_results)
param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]["params"]
param["type"] = args.model_type
logger.info(f"select params from result: {param}")
if not param:
param = PARAMS
@ -815,7 +820,8 @@ def main():
if "retrain" == args.mode:
main_retrain()
if "hyperband" == args.mode:
main_hyperband(args.data, args.domain_length, args.window, args.model_type, args.hyperband_results)
main_hyperband(args.data, args.domain_length, args.window, args.model_type, args.hyperband_results,
arg.hyper_max_iter)
if "test" == args.mode:
main_test()
if "fancy" == args.mode: