fix hyperband wrong variable names
This commit is contained in:
parent
a1e553f8f1
commit
f02e0b7f99
@ -46,8 +46,8 @@ class Hyperband:
|
||||
n_iterations = int(round(n_iterations))
|
||||
embedding, model, new_model = models.get_models_by_params(params)
|
||||
|
||||
model = create_model(model, params["output"])
|
||||
new_model = create_model(new_model, params["output"])
|
||||
model = create_model(model, params["model_output"])
|
||||
new_model = create_model(new_model, params["model_output"])
|
||||
|
||||
if params["type"] in ("inter", "staggered"):
|
||||
model = new_model
|
||||
@ -68,7 +68,8 @@ class Hyperband:
|
||||
shuffle=True,
|
||||
validation_split=0.4)
|
||||
|
||||
return {"loss": history.history['val_loss'][-1], "early_stop": True}
|
||||
return {"loss": history.history['val_loss'][-1],
|
||||
"early_stop": len(history.history["loss"]) < n_iterations}
|
||||
|
||||
# can be called multiple times
|
||||
def run(self, skip_last=0, dry_run=False):
|
||||
|
7
main.py
7
main.py
@ -101,11 +101,12 @@ def main_hyperband():
|
||||
# static params
|
||||
"type": [args.model_type],
|
||||
"depth": [args.model_depth],
|
||||
"output": [args.model_output],
|
||||
"model_output": [args.model_output],
|
||||
"batch_size": [args.batch_size],
|
||||
"window_size": [10],
|
||||
"window_size": [args.window],
|
||||
"flow_features": [3],
|
||||
"input_length": [40],
|
||||
"domain_length": [args.domain_length],
|
||||
'input_length': [40],
|
||||
# model params
|
||||
"embedding_size": [2 ** x for x in range(3, 7)],
|
||||
"filter_embedding": [2 ** x for x in range(1, 10)],
|
||||
|
Loading…
Reference in New Issue
Block a user