add hyperband savefile config, minor change of parameter name

This commit is contained in:
René Knaebel 2017-10-03 18:58:54 +02:00
parent 68254d6629
commit 371a1dad05
4 changed files with 13 additions and 4 deletions

View File

@ -15,6 +15,9 @@ parser.add_argument("--data", action="store", dest="train_data",
parser.add_argument("--test", action="store", dest="test_data",
default="data/full_future_dataset.csv.tar.gz")
parser.add_argument("--hyper_result", action="store", dest="hyperband_results",
default="")
parser.add_argument("--model", action="store", dest="model_path",
default="results/model_x")

View File

@ -7,6 +7,7 @@ from math import ceil, log
from random import random as rng
from time import ctime, time
import joblib
import numpy as np
from keras.callbacks import EarlyStopping
@ -24,7 +25,7 @@ def sample_params(param_distribution: dict):
class Hyperband:
def __init__(self, param_distribution, X, y, max_iter=81):
def __init__(self, param_distribution, X, y, max_iter=81, savefile=None):
self.get_params = lambda: sample_params(param_distribution)
self.max_iter = max_iter # maximum iterations per configuration
@ -39,6 +40,8 @@ class Hyperband:
self.best_loss = np.inf
self.best_counter = -1
self.savefile = savefile
self.X = X
self.y = y
@ -143,4 +146,7 @@ class Hyperband:
random_configs = [random_configs[i] for i in indices if not early_stops[i]]
random_configs = random_configs[0:int(n_configs / self.eta)]
if self.savefile:
joblib.dump(self.results, self.savefile)
return self.results

View File

@ -63,7 +63,7 @@ PARAMS = {
#
'dropout': 0.5, # currently fix
'domain_features': args.domain_embedding,
'embedding_size': args.embedding,
'embedding': args.embedding,
'flow_features': 3,
'filter_embedding': args.filter_embedding,
'dense_embedding': args.dense_embedding,
@ -132,7 +132,7 @@ def main_hyperband():
[domain_tr, flow_tr],
[client_tr, server_tr])
results = hp.run()
joblib.dump(results, "hyperband.joblib")
joblib.dump(results, args.hyperband_results)
def main_train(param=None):

View File

@ -9,7 +9,7 @@ def get_models_by_params(params: dict):
# mainly embedding model
network_type = params.get("type")
network_depth = params.get("depth")
embedding_size = params.get("embedding_size")
embedding_size = params.get("embedding")
input_length = params.get("input_length")
filter_embedding = params.get("filter_embedding")
kernel_embedding = params.get("kernel_embedding")