add hyperband savefile config, minor change of parameter name
This commit is contained in:
parent
68254d6629
commit
371a1dad05
@ -15,6 +15,9 @@ parser.add_argument("--data", action="store", dest="train_data",
|
||||
parser.add_argument("--test", action="store", dest="test_data",
|
||||
default="data/full_future_dataset.csv.tar.gz")
|
||||
|
||||
parser.add_argument("--hyper_result", action="store", dest="hyperband_results",
|
||||
default="")
|
||||
|
||||
|
||||
parser.add_argument("--model", action="store", dest="model_path",
|
||||
default="results/model_x")
|
||||
|
@ -7,6 +7,7 @@ from math import ceil, log
|
||||
from random import random as rng
|
||||
from time import ctime, time
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
from keras.callbacks import EarlyStopping
|
||||
|
||||
@ -24,7 +25,7 @@ def sample_params(param_distribution: dict):
|
||||
|
||||
|
||||
class Hyperband:
|
||||
def __init__(self, param_distribution, X, y, max_iter=81):
|
||||
def __init__(self, param_distribution, X, y, max_iter=81, savefile=None):
|
||||
self.get_params = lambda: sample_params(param_distribution)
|
||||
|
||||
self.max_iter = max_iter # maximum iterations per configuration
|
||||
@ -39,6 +40,8 @@ class Hyperband:
|
||||
self.best_loss = np.inf
|
||||
self.best_counter = -1
|
||||
|
||||
self.savefile = savefile
|
||||
|
||||
self.X = X
|
||||
self.y = y
|
||||
|
||||
@ -143,4 +146,7 @@ class Hyperband:
|
||||
random_configs = [random_configs[i] for i in indices if not early_stops[i]]
|
||||
random_configs = random_configs[0:int(n_configs / self.eta)]
|
||||
|
||||
if self.savefile:
|
||||
joblib.dump(self.results, self.savefile)
|
||||
|
||||
return self.results
|
||||
|
4
main.py
4
main.py
@ -63,7 +63,7 @@ PARAMS = {
|
||||
#
|
||||
'dropout': 0.5, # currently fix
|
||||
'domain_features': args.domain_embedding,
|
||||
'embedding_size': args.embedding,
|
||||
'embedding': args.embedding,
|
||||
'flow_features': 3,
|
||||
'filter_embedding': args.filter_embedding,
|
||||
'dense_embedding': args.dense_embedding,
|
||||
@ -132,7 +132,7 @@ def main_hyperband():
|
||||
[domain_tr, flow_tr],
|
||||
[client_tr, server_tr])
|
||||
results = hp.run()
|
||||
joblib.dump(results, "hyperband.joblib")
|
||||
joblib.dump(results, args.hyperband_results)
|
||||
|
||||
|
||||
def main_train(param=None):
|
||||
|
@ -9,7 +9,7 @@ def get_models_by_params(params: dict):
|
||||
# mainly embedding model
|
||||
network_type = params.get("type")
|
||||
network_depth = params.get("depth")
|
||||
embedding_size = params.get("embedding_size")
|
||||
embedding_size = params.get("embedding")
|
||||
input_length = params.get("input_length")
|
||||
filter_embedding = params.get("filter_embedding")
|
||||
kernel_embedding = params.get("kernel_embedding")
|
||||
|
Loading…
Reference in New Issue
Block a user