From f02e0b7f99980f6f1f7f864d14a12795f50bb003 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Fri, 29 Sep 2017 23:34:39 +0200 Subject: [PATCH] fix hyperband wrong variable names --- hyperband.py | 7 ++++--- main.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/hyperband.py b/hyperband.py index 3a42417..9af243c 100644 --- a/hyperband.py +++ b/hyperband.py @@ -46,8 +46,8 @@ class Hyperband: n_iterations = int(round(n_iterations)) embedding, model, new_model = models.get_models_by_params(params) - model = create_model(model, params["output"]) - new_model = create_model(new_model, params["output"]) + model = create_model(model, params["model_output"]) + new_model = create_model(new_model, params["model_output"]) if params["type"] in ("inter", "staggered"): model = new_model @@ -68,7 +68,8 @@ class Hyperband: shuffle=True, validation_split=0.4) - return {"loss": history.history['val_loss'][-1], "early_stop": True} + return {"loss": history.history['val_loss'][-1], + "early_stop": len(history.history["loss"]) < n_iterations} # can be called multiple times def run(self, skip_last=0, dry_run=False): diff --git a/main.py b/main.py index 5332c79..3beede2 100644 --- a/main.py +++ b/main.py @@ -101,11 +101,12 @@ def main_hyperband(): # static params "type": [args.model_type], "depth": [args.model_depth], - "output": [args.model_output], + "model_output": [args.model_output], "batch_size": [args.batch_size], - "window_size": [10], + "window_size": [args.window], "flow_features": [3], - "input_length": [40], + "domain_length": [args.domain_length], + 'input_length': [40], # model params "embedding_size": [2 ** x for x in range(3, 7)], "filter_embedding": [2 ** x for x in range(1, 10)],