fix hyperband wrong variable names

This commit is contained in:
René Knaebel 2017-09-29 23:34:39 +02:00
parent a1e553f8f1
commit f02e0b7f99
2 changed files with 8 additions and 6 deletions

View File

@ -46,8 +46,8 @@ class Hyperband:
n_iterations = int(round(n_iterations))
embedding, model, new_model = models.get_models_by_params(params)
model = create_model(model, params["output"])
new_model = create_model(new_model, params["output"])
model = create_model(model, params["model_output"])
new_model = create_model(new_model, params["model_output"])
if params["type"] in ("inter", "staggered"):
model = new_model
@ -68,7 +68,8 @@ class Hyperband:
shuffle=True,
validation_split=0.4)
return {"loss": history.history['val_loss'][-1], "early_stop": True}
return {"loss": history.history['val_loss'][-1],
"early_stop": len(history.history["loss"]) < n_iterations}
# can be called multiple times
def run(self, skip_last=0, dry_run=False):

View File

@ -101,11 +101,12 @@ def main_hyperband():
# static params
"type": [args.model_type],
"depth": [args.model_depth],
"output": [args.model_output],
"model_output": [args.model_output],
"batch_size": [args.batch_size],
"window_size": [10],
"window_size": [args.window],
"flow_features": [3],
"input_length": [40],
"domain_length": [args.domain_length],
'input_length': [40],
# model params
"embedding_size": [2 ** x for x in range(3, 7)],
"filter_embedding": [2 ** x for x in range(1, 10)],