From 371a1dad052b77a5e31b46e29a832fc2c3009c59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Tue, 3 Oct 2017 18:58:54 +0200 Subject: [PATCH] add hyperband savefile config, minor change of parameter name --- arguments.py | 3 +++ hyperband.py | 8 +++++++- main.py | 4 ++-- models/__init__.py | 2 +- 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/arguments.py b/arguments.py index 51d0712..cf01292 100644 --- a/arguments.py +++ b/arguments.py @@ -15,6 +15,9 @@ parser.add_argument("--data", action="store", dest="train_data", parser.add_argument("--test", action="store", dest="test_data", default="data/full_future_dataset.csv.tar.gz") +parser.add_argument("--hyper_result", action="store", dest="hyperband_results", + default="") + parser.add_argument("--model", action="store", dest="model_path", default="results/model_x") diff --git a/hyperband.py b/hyperband.py index 9af243c..835e844 100644 --- a/hyperband.py +++ b/hyperband.py @@ -7,6 +7,7 @@ from math import ceil, log from random import random as rng from time import ctime, time +import joblib import numpy as np from keras.callbacks import EarlyStopping @@ -24,7 +25,7 @@ def sample_params(param_distribution: dict): class Hyperband: - def __init__(self, param_distribution, X, y, max_iter=81): + def __init__(self, param_distribution, X, y, max_iter=81, savefile=None): self.get_params = lambda: sample_params(param_distribution) self.max_iter = max_iter # maximum iterations per configuration @@ -39,6 +40,8 @@ class Hyperband: self.best_loss = np.inf self.best_counter = -1 + self.savefile = savefile + self.X = X self.y = y @@ -143,4 +146,7 @@ class Hyperband: random_configs = [random_configs[i] for i in indices if not early_stops[i]] random_configs = random_configs[0:int(n_configs / self.eta)] + if self.savefile: + joblib.dump(self.results, self.savefile) + return self.results diff --git a/main.py b/main.py index 3beede2..06b6b4d 100644 --- a/main.py +++ b/main.py @@ -63,7 +63,7 @@ PARAMS = { # 'dropout': 0.5, # currently fix 'domain_features': args.domain_embedding, - 'embedding_size': args.embedding, + 'embedding': args.embedding, 'flow_features': 3, 'filter_embedding': args.filter_embedding, 'dense_embedding': args.dense_embedding, @@ -132,7 +132,7 @@ def main_hyperband(): [domain_tr, flow_tr], [client_tr, server_tr]) results = hp.run() - joblib.dump(results, "hyperband.joblib") + joblib.dump(results, args.hyperband_results) def main_train(param=None): diff --git a/models/__init__.py b/models/__init__.py index 40e3f52..75595c2 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -9,7 +9,7 @@ def get_models_by_params(params: dict): # mainly embedding model network_type = params.get("type") network_depth = params.get("depth") - embedding_size = params.get("embedding_size") + embedding_size = params.get("embedding") input_length = params.get("input_length") filter_embedding = params.get("filter_embedding") kernel_embedding = params.get("kernel_embedding")