77 lines
2.6 KiB
Python
77 lines
2.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
# implementation of hyperband:
|
|
# https://arxiv.org/pdf/1603.06560.pdf
|
|
import numpy as np
|
|
|
|
|
|
def get_hyperparameter_configuration(configGenerator, n):
    """Draw ``n`` fresh configurations from ``configGenerator``.

    Parameters
    ----------
    configGenerator : callable
        Zero-argument callable; each call returns one configuration.
    n : int or float
        Number of draws. It is iterated via ``np.arange(n)``, so a
        fractional value yields ``ceil(n)`` draws and ``n <= 0`` yields none.

    Returns
    -------
    list
        The generated configurations, in call order.
    """
    return [configGenerator() for _ in np.arange(n)]
|
|
|
|
|
|
def run_then_return_val_loss(config, r_i, modelGenerator, trainData, trainLabel,
                             testData, testLabel, batch_size=128):
    """Train a model built from ``config`` for ``int(r_i)`` epochs and return its loss.

    Parameters
    ----------
    config : object
        Hyperparameter configuration passed to ``modelGenerator``.
    r_i : float
        Resource budget; truncated with ``int()`` to get the epoch count.
    modelGenerator : callable
        Maps ``config`` to a model exposing Keras-style ``fit``/``evaluate``
        methods, or returns ``None`` if no model can be built.
    trainData, trainLabel : array-like
        Training inputs and targets passed to ``model.fit``.
    testData, testLabel : array-like
        Held-out inputs and targets passed to ``model.evaluate``.
    batch_size : int, optional
        Batch size for both training and evaluation (default 128).

    Returns
    -------
    float
        The loss reported by ``model.evaluate``. If ``evaluate`` returns a
        sequence (loss followed by metrics), only its first element is
        returned. ``np.inf`` is returned when ``modelGenerator`` gives ``None``.
    """
    model = modelGenerator(config)
    if model is None:
        # No model for this configuration: rank it behind every real result.
        return np.inf

    model.fit(x=trainData, y=trainLabel,
              epochs=int(r_i), shuffle=True, initial_epoch=0,
              batch_size=batch_size)
    score = model.evaluate(testData, testLabel,
                           batch_size=batch_size)
    # evaluate() may return the loss alone (scalar) or [loss, *metrics].
    if np.ndim(score) > 0:
        score = score[0]
    return score
|
|
|
|
|
|
def top_k(configurations, L, k):
    """Return the ``k`` configurations with the smallest losses.

    Parameters
    ----------
    configurations : sequence
        Candidate configurations; ``configurations[j]`` was scored ``L[j]``.
    L : sequence of float
        Loss for each configuration (lower is better).
    k : int or float
        Number of configurations to keep. A fractional value rounds up, as
        iterating ``np.arange(0, k, 1)`` did previously. Values larger than
        ``len(configurations)`` are clamped, and values <= 0 give ``[]``.

    Returns
    -------
    list
        Up to ``k`` configurations ordered from lowest to highest loss.
        NaN losses sort last, following ``np.argsort``.
    """
    k = max(0, min(int(np.ceil(k)), len(configurations)))
    # Stable sort so that, on tied losses, the earlier configuration wins.
    order = np.argsort(np.asarray(L), kind='stable')
    return [configurations[j] for j in order[:k]]
|
|
|
|
|
|
def hyperband(R, nu, modelGenerator,
              configGenerator,
              trainData, trainLabel,
              testData, testLabel,
              outputFile=''):
    """Run the Hyperband search (Li et al., https://arxiv.org/pdf/1603.06560.pdf).

    Parameters
    ----------
    R : int or float
        Maximum resource for a single configuration; it is handed to
        ``run_then_return_val_loss`` as ``r_i``, which trains ``int(r_i)``
        epochs. Must be >= 1.
    nu : int or float
        Downsampling rate (eta in the paper). Must be > 1.
    modelGenerator : callable
        Forwarded to ``run_then_return_val_loss``.
    configGenerator : callable
        Zero-argument callable returning a random configuration.
    trainData, trainLabel, testData, testLabel : array-like
        Forwarded to ``run_then_return_val_loss``.
    outputFile : str, optional
        If non-empty, each evaluation is appended to this file as a line
        ``"<config>\\t<loss>\\t<r_i>"``.

    Returns
    -------
    tuple
        ``(bestConfig, allConfigs, allLosses)``. ``allConfigs`` and
        ``allLosses`` list every evaluation in order. A configuration appears
        once per budget at which it was evaluated. ``bestConfig`` is the
        entry with the lowest loss.

    Raises
    ------
    ValueError
        If ``nu <= 1`` or ``R < 1``; neither gives a usable schedule.
    """
    if nu <= 1:
        raise ValueError('nu must be greater than 1')
    if R < 1:
        raise ValueError('R must be at least 1')

    allLosses = []
    allConfigs = []

    # s_max = floor(log_nu(R)), computed exactly. np.log(R) / np.log(nu)
    # can land just below an integer (e.g. R=243, nu=3 -> 4.999...).
    s_max = 0
    while nu ** (s_max + 1) <= R:
        s_max += 1
    B = (s_max + 1) * R  # budget per bracket

    for s in range(s_max, -1, -1):
        # Number of configurations sampled for this bracket.
        n = int(np.ceil(B / R * nu ** s / (s + 1)))
        configurations = get_hyperparameter_configuration(configGenerator, n)

        # Successive halving inside the bracket.
        for i in range(s + 1):
            n_i = int(np.floor(n / nu ** i))
            # Same value as r * nu**i with r = R * nu**-s. Dividing R by an
            # exact power avoids results such as 0.9999999999999999, which
            # int() would turn into 0 epochs (e.g. R=49, nu=7).
            r_i = float(R) / nu ** (s - i)

            L = []
            for config in configurations:
                curLoss = run_then_return_val_loss(config, r_i, modelGenerator,
                                                   trainData, trainLabel,
                                                   testData, testLabel)
                L.append(curLoss)
                allLosses.append(curLoss)
                allConfigs.append(config)

                if outputFile:
                    # Append-and-close per result so progress survives a crash.
                    with open(outputFile, 'a') as myfile:
                        myfile.write(str(config) + '\t' + str(curLoss) +
                                     '\t' + str(r_i) + '\n')

            # Keep the best floor(n_i / nu). Clamp to the number of survivors,
            # which can be smaller than n_i when nu is not an integer.
            k = int(np.floor(n_i / nu))
            configurations = top_k(configurations, L, min(k, len(configurations)))

    bestConfig = top_k(allConfigs, allLosses, 1)
    return (bestConfig[0], allConfigs, allLosses)
|