fix hyperband wrong variable names
This commit is contained in:
parent
a1e553f8f1
commit
f02e0b7f99
@ -46,8 +46,8 @@ class Hyperband:
|
|||||||
n_iterations = int(round(n_iterations))
|
n_iterations = int(round(n_iterations))
|
||||||
embedding, model, new_model = models.get_models_by_params(params)
|
embedding, model, new_model = models.get_models_by_params(params)
|
||||||
|
|
||||||
model = create_model(model, params["output"])
|
model = create_model(model, params["model_output"])
|
||||||
new_model = create_model(new_model, params["output"])
|
new_model = create_model(new_model, params["model_output"])
|
||||||
|
|
||||||
if params["type"] in ("inter", "staggered"):
|
if params["type"] in ("inter", "staggered"):
|
||||||
model = new_model
|
model = new_model
|
||||||
@ -68,7 +68,8 @@ class Hyperband:
|
|||||||
shuffle=True,
|
shuffle=True,
|
||||||
validation_split=0.4)
|
validation_split=0.4)
|
||||||
|
|
||||||
return {"loss": history.history['val_loss'][-1], "early_stop": True}
|
return {"loss": history.history['val_loss'][-1],
|
||||||
|
"early_stop": len(history.history["loss"]) < n_iterations}
|
||||||
|
|
||||||
# can be called multiple times
|
# can be called multiple times
|
||||||
def run(self, skip_last=0, dry_run=False):
|
def run(self, skip_last=0, dry_run=False):
|
||||||
|
7
main.py
7
main.py
@ -101,11 +101,12 @@ def main_hyperband():
|
|||||||
# static params
|
# static params
|
||||||
"type": [args.model_type],
|
"type": [args.model_type],
|
||||||
"depth": [args.model_depth],
|
"depth": [args.model_depth],
|
||||||
"output": [args.model_output],
|
"model_output": [args.model_output],
|
||||||
"batch_size": [args.batch_size],
|
"batch_size": [args.batch_size],
|
||||||
"window_size": [10],
|
"window_size": [args.window],
|
||||||
"flow_features": [3],
|
"flow_features": [3],
|
||||||
"input_length": [40],
|
"domain_length": [args.domain_length],
|
||||||
|
'input_length': [40],
|
||||||
# model params
|
# model params
|
||||||
"embedding_size": [2 ** x for x in range(3, 7)],
|
"embedding_size": [2 ** x for x in range(3, 7)],
|
||||||
"filter_embedding": [2 ** x for x in range(1, 10)],
|
"filter_embedding": [2 ** x for x in range(1, 10)],
|
||||||
|
Loading…
Reference in New Issue
Block a user