169 lines
6.1 KiB
Python
169 lines
6.1 KiB
Python
# -*- coding: utf-8 -*-
|
|
# implementation of hyperband:
|
|
# https://arxiv.org/pdf/1603.06560.pdf
|
|
import logging
|
|
import random
|
|
from math import ceil, log
|
|
from random import random as rng
|
|
from time import ctime, time
|
|
|
|
import joblib
|
|
import keras.backend as K
|
|
import numpy as np
|
|
from keras.callbacks import EarlyStopping
|
|
|
|
import models
|
|
from main import create_model
|
|
|
|
# Named logger used for the progress messages emitted by Hyperband.run;
# logging.getLogger returns the same instance for the same name.
logger = logging.getLogger(name="cisco_logger")
|
|
|
|
|
|
def sample_params(param_distribution: dict):
    """Draw one random configuration from a discrete search space.

    Args:
        param_distribution: maps each hyperparameter name to a sequence of
            candidate values.

    Returns:
        A new dict with the same keys, each mapped to one value picked with
        ``random.choice`` (keys are visited in the mapping's iteration order).
    """
    return {
        name: random.choice(candidates)
        for name, candidates in param_distribution.items()
    }
|
|
|
|
|
|
class Hyperband:
    """Hyperband hyperparameter search (https://arxiv.org/pdf/1603.06560.pdf).

    Candidate configurations are drawn with ``sample_params`` from
    ``param_distribution``. Each bracket then runs successive halving:
    surviving configurations are trained for progressively more iterations
    and only the best ``1 / eta`` fraction is kept after every round.

    Args:
        param_distribution: mapping of hyperparameter name -> sequence of
            candidate values, passed to ``sample_params``.
        X: training inputs handed to ``model.fit``.
        y: training targets; ``y[0]`` is used when a configuration's
            "model_output" is "client", otherwise ``y`` as a whole.
        max_iter: maximum iterations (epochs) per configuration (R in the paper).
        savefile: optional path; ``self.results`` is written there with
            ``joblib.dump`` after every successive-halving round.

    Raises:
        ValueError: if ``max_iter`` is not positive.
    """

    def __init__(self, param_distribution, X, y, max_iter=81, savefile=None):
        self.get_params = lambda: sample_params(param_distribution)

        self.max_iter = max_iter  # maximum iterations per configuration
        self.eta = 3  # defines configuration downsampling rate (default = 3)

        self.logeta = lambda x: log(x) / log(self.eta)
        # floor(log_eta(max_iter)), computed exactly (see _compute_s_max).
        self.s_max = self._compute_s_max(self.max_iter, self.eta)
        self.B = (self.s_max + 1) * self.max_iter  # budget per bracket

        self.results = []  # list of dicts, one per evaluated configuration
        self.counter = 0  # number of configurations evaluated so far
        self.best_loss = np.inf  # lowest loss seen so far (for display only)
        self.best_counter = -1  # counter value of the run that gave best_loss

        self.savefile = savefile

        self.X = X
        self.y = y

    @staticmethod
    def _compute_s_max(max_iter, eta):
        """Return the largest non-negative integer s with ``eta ** s <= max_iter``.

        ``int(log(max_iter) / log(eta))`` can truncate one too low because of
        floating-point rounding (``log(243) / log(3)`` evaluates to
        4.999999999999999), which silently drops a whole bracket. Repeated
        multiplication is exact for integer inputs. Returns 0 when
        ``max_iter < eta``.

        Raises:
            ValueError: if ``max_iter`` is not positive or ``eta`` is not > 1
                (the latter would otherwise loop forever).
        """
        if max_iter <= 0:
            raise ValueError("max_iter must be positive, got {!r}".format(max_iter))
        if eta <= 1:
            raise ValueError("eta must be greater than 1, got {!r}".format(eta))
        s_max = 0
        while eta ** (s_max + 1) <= max_iter:
            s_max += 1
        return s_max

    @staticmethod
    def _add_tied_l1_loss(model, layer_a, layer_b, weight=0.001):
        """Add ``weight * sum(|a - b|)`` for each pair of trainable weights of two layers.

        The pairs are formed by zipping the two layers' ``trainable_weights``
        lists, and the list of penalty tensors is passed to ``model.add_loss``.
        """
        weights_a = model.get_layer(layer_a).trainable_weights
        weights_b = model.get_layer(layer_b).trainable_weights
        model.add_loss([weight * K.sum(K.abs(a - b)) for a, b in zip(weights_a, weights_b)])

    def try_params(self, n_iterations, params):
        """Train one configuration for up to ``n_iterations`` epochs.

        Args:
            n_iterations: epoch budget; rounded to the nearest integer.
            params: sampled configuration. Keys read here: "type",
                "model_output" and "batch_size" (``models.get_models_by_params``
                may read more).

        Returns:
            dict with "loss" (minimum validation loss over the epochs run),
            "early_stop" (True if EarlyStopping ended training before the
            budget) and "stop_after" (number of epochs actually run).
        """
        n_iterations = int(round(n_iterations))
        _embedding, model, new_model, long_model, soft_model = models.get_models_by_params(params)

        # Pick the architecture variant requested by params["type"]; any other
        # value keeps the default `model`.
        model_type = params["type"]
        if model_type in ("inter", "staggered"):
            model = new_model
        elif model_type == "long":
            model = long_model
        elif model_type == "soft":
            model = soft_model

        model = create_model(model, params["model_output"])

        if model_type == "soft":
            # Penalise the L1 distance between corresponding server/client
            # weights (scaled by 0.001) for the conv and dense layer pairs.
            self._add_tied_l1_loss(model, "conv_server", "conv_client")
            self._add_tied_l1_loss(model, "dense_server", "dense_client")

        callbacks = [EarlyStopping(monitor='val_loss',
                                   patience=5,
                                   verbose=False)]

        model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['accuracy'])

        # NOTE(review): presumably self.y holds one target per model output and
        # the "client" model has a single output, hence y[0] -- confirm.
        history = model.fit(self.X,
                            self.y[0] if params["model_output"] == "client" else self.y,
                            batch_size=params["batch_size"],
                            epochs=n_iterations,
                            callbacks=callbacks,
                            shuffle=True,
                            validation_split=0.4)

        # NOTE(review): every trial builds new Keras models in the global
        # backend state; if memory grows over a long search consider
        # K.clear_session() here (confirm models.get_models_by_params does not
        # rely on layers persisting between calls).
        return {"loss": np.min(history.history['val_loss']),
                "early_stop": len(history.history["loss"]) < n_iterations,
                "stop_after": len(history.history["val_loss"])}

    def _evaluate(self, config, n_iterations, dry_run):
        """Evaluate one configuration, update best-so-far tracking and record it.

        Returns the result dict (also appended to ``self.results``), augmented
        with "counter", "seconds", "params" and "iterations".

        Raises:
            TypeError: if the evaluation did not return a dict.
            ValueError: if the returned dict has no 'loss' key.
        """
        self.counter += 1
        logger.info("Config {} | {} | lowest loss so far: {:.4f} (run {})".format(
            self.counter, ctime(), self.best_loss, self.best_counter))

        start_time = time()

        if dry_run:
            result = {'loss': rng(), 'log_loss': rng(), 'auc': rng()}
        else:
            result = self.try_params(n_iterations, config)

        # Explicit checks instead of asserts, which are stripped under -O.
        if not isinstance(result, dict):
            raise TypeError("try_params must return a dict, got {!r}".format(type(result)))
        if 'loss' not in result:
            raise ValueError("try_params result has no 'loss' key: {!r}".format(result))

        seconds = int(round(time() - start_time))
        logger.info("{} seconds.".format(seconds))

        loss = result['loss']
        # keeping track of the best result so far (for display only)
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_counter = self.counter

        result['counter'] = self.counter
        result['seconds'] = seconds
        result['params'] = config
        result['iterations'] = n_iterations

        self.results.append(result)
        return result

    def run(self, skip_last=0, dry_run=False):
        """Run every Hyperband bracket; can be called multiple times.

        Args:
            skip_last: number of final successive-halving rounds to skip in
                every bracket (0 runs all ``s + 1`` rounds).
            dry_run: if True, configurations are scored with random numbers
                instead of being trained via ``try_params``.

        Returns:
            ``self.results``, which accumulates across calls.
        """
        for s in reversed(range(self.s_max + 1)):
            # initial number of configurations
            n = int(ceil(self.B / self.max_iter / (s + 1) * self.eta ** s))
            # initial number of iterations per config
            r = self.max_iter * self.eta ** (-s)
            # n random configurations
            configs = [self.get_params() for _ in range(n)]

            for i in range((s + 1) - int(skip_last)):
                # Run each surviving config for n_iterations and keep the best
                # n / eta**(i + 1) of them for the next round.
                n_iterations = r * self.eta ** i
                # Integer floor division; the float form int(n * eta**-i / eta)
                # is exposed to rounding error.
                n_keep = int(n // self.eta ** (i + 1))

                logger.info("*** {} configurations x {:.1f} iterations each".format(
                    len(configs), n_iterations))

                val_losses = []
                early_stops = []
                for config in configs:
                    result = self._evaluate(config, n_iterations, dry_run)
                    val_losses.append(result['loss'])
                    early_stops.append(result.get('early_stop', False))

                # select the best configurations for the next round,
                # dropping any that stopped early
                ranked = np.argsort(val_losses)
                configs = [configs[k] for k in ranked if not early_stops[k]][:n_keep]

                if self.savefile:
                    joblib.dump(self.results, self.savefile)

        return self.results
|