ma_cisco_malware/hyperband.py

153 lines
5.3 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
# implementation of hyperband:
# https://arxiv.org/pdf/1603.06560.pdf
import logging
2017-07-07 16:48:10 +02:00
import random
2017-09-29 22:59:57 +02:00
from math import ceil, log
2017-07-07 16:48:10 +02:00
from random import random as rng
2017-09-29 22:59:57 +02:00
from time import ctime, time
2017-07-07 16:48:10 +02:00
import joblib
import numpy as np
2017-09-29 22:59:57 +02:00
from keras.callbacks import EarlyStopping
2017-07-07 16:48:10 +02:00
import models
2017-09-29 22:59:57 +02:00
from main import create_model
2017-07-07 16:48:10 +02:00
# Module-level logger. It uses the fixed name 'logger' rather than __name__,
# presumably so it shares handlers configured elsewhere under that name —
# verify against the entry-point script before renaming.
logger = logging.getLogger('logger')
2017-07-07 16:48:10 +02:00
def sample_params(param_distribution: dict):
    """Draw one random configuration from a discrete search space.

    Args:
        param_distribution: mapping of parameter name -> sequence of
            candidate values.

    Returns:
        A new dict with the same keys, each mapped to one value picked with
        ``random.choice`` from its candidates (keys visited in the mapping's
        iteration order).
    """
    return {name: random.choice(candidates)
            for name, candidates in param_distribution.items()}
class Hyperband:
    """Hyperband hyper-parameter search over a discrete parameter grid.

    Follows the bracket / successive-halving schedule of Li et al.
    (https://arxiv.org/abs/1603.06560): each bracket samples ``n``
    configurations, trains them for ``r`` epochs, keeps the best
    ``1 / eta`` fraction (dropping early-stopped runs), and repeats with
    ``eta`` times more epochs per configuration.
    """

    def __init__(self, param_distribution, X, y, max_iter=81, savefile=None):
        """Set up the search.

        Args:
            param_distribution: mapping of parameter name -> sequence of
                candidate values; configurations are drawn with
                ``sample_params``.
            X: training inputs passed to ``model.fit``.
            y: training targets passed to ``model.fit``.
            max_iter: maximum number of epochs any single configuration is
                trained for. Must be positive.
            savefile: optional path; when set, ``self.results`` is written
                with ``joblib.dump`` after every successive-halving round.

        Raises:
            ValueError: if ``max_iter`` is not positive.
        """
        if max_iter <= 0:
            raise ValueError(
                "max_iter must be positive, got {!r}".format(max_iter))
        self.get_params = lambda: sample_params(param_distribution)

        self.max_iter = max_iter  # maximum iterations per configuration
        self.eta = 3  # defines configuration downsampling rate (default = 3)
        # Kept as a public helper; no longer used to derive s_max.
        self.logeta = lambda x: log(x) / log(self.eta)
        # s_max is the largest s with eta ** s <= max_iter. Deriving it as
        # int(log(max_iter) / log(eta)) truncates wrongly whenever the float
        # ratio lands just below an integer (log(243) / log(3) evaluates to
        # 4.999..., giving 4 instead of 5), so count the powers exactly.
        s_max = 0
        while self.eta ** (s_max + 1) <= self.max_iter:
            s_max += 1
        self.s_max = s_max
        self.B = (self.s_max + 1) * self.max_iter  # epoch budget per bracket
        self.results = []  # list of dicts, one per evaluated configuration
        self.counter = 0  # number of configurations evaluated so far
        self.best_loss = np.inf  # lowest loss seen (for display only)
        self.best_counter = -1  # counter value of the run with best_loss
        self.savefile = savefile

        self.X = X
        self.y = y

    def try_params(self, n_iterations, params):
        """Build, train and score one configuration.

        Args:
            n_iterations: epoch budget; rounded to the nearest integer.
            params: one sampled configuration. This method reads the keys
                "model_output", "type" and "batch_size"; the whole dict is
                also passed to ``models.get_models_by_params``.

        Returns:
            dict with "loss" (validation loss after the last trained epoch)
            and "early_stop" (True if EarlyStopping ended training before
            ``n_iterations`` epochs).
        """
        n_iterations = int(round(n_iterations))

        _embedding, model, new_model = models.get_models_by_params(params)

        # Both variants are built (in this order) as before; only one trains.
        model = create_model(model, params["model_output"])
        new_model = create_model(new_model, params["model_output"])

        if params["type"] in ("inter", "staggered"):
            model = new_model
        callbacks = [EarlyStopping(monitor='val_loss',
                                   patience=5,
                                   verbose=False)]
        model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['accuracy'])
        # validation_split holds out part of X/y so val_loss is available
        # to EarlyStopping and to the returned score.
        history = model.fit(self.X,
                            self.y,
                            batch_size=params["batch_size"],
                            epochs=n_iterations,
                            callbacks=callbacks,
                            shuffle=True,
                            validation_split=0.4)

        return {"loss": history.history['val_loss'][-1],
                "early_stop": len(history.history["loss"]) < n_iterations}

    # can be called multiple times
    def run(self, skip_last=0, dry_run=False):
        """Run every Hyperband bracket once and return the accumulated results.

        Can be called multiple times; ``results``, ``counter`` and the
        best-loss bookkeeping carry over between calls.

        Args:
            skip_last: number of final successive-halving rounds to skip in
                each bracket (converted with ``int``; 0 runs them all).
            dry_run: if True, nothing is trained; each configuration gets
                random "loss"/"log_loss"/"auc" values instead.

        Returns:
            ``self.results``: one dict per evaluated configuration, holding
            the scoring dict plus "counter", "seconds", "params" and
            "iterations".

        Raises:
            TypeError: if ``try_params`` does not return a dict.
            ValueError: if that dict has no "loss" key.
        """
        for s in reversed(range(self.s_max + 1)):

            # initial number of configurations:
            # ceil(B / max_iter / (s + 1) * eta ** s). Because
            # B / max_iter == s_max + 1, use exact integer ceiling division
            # instead of float ceil().
            numerator = (self.s_max + 1) * self.eta ** s
            n = -(-numerator // (s + 1))

            # initial number of iterations per config
            r = self.max_iter * self.eta ** (-s)

            # n random configurations
            random_configs = [self.get_params() for _ in range(n)]

            for i in range((s + 1) - int(skip_last)):  # changed from s + 1

                # Run each of the n configs for <iterations>
                # and keep best (n_configs / eta) configurations
                n_configs = n * self.eta ** (-i)
                n_iterations = r * self.eta ** i

                logger.info("\n*** {} configurations x {:.1f} iterations each".format(
                    n_configs, n_iterations))

                val_losses = []
                early_stops = []

                for t in random_configs:
                    self.counter += 1
                    logger.info("\n{} | {} | lowest loss so far: {:.4f} (run {})\n".format(
                        self.counter, ctime(), self.best_loss, self.best_counter))

                    start_time = time()
                    if dry_run:
                        result = {'loss': rng(), 'log_loss': rng(), 'auc': rng()}
                    else:
                        result = self.try_params(n_iterations, t)  # <---
                    # Explicit checks rather than assert, which -O strips.
                    if not isinstance(result, dict):
                        raise TypeError(
                            "try_params must return a dict, got {!r}".format(
                                type(result)))
                    if 'loss' not in result:
                        raise ValueError("try_params result has no 'loss' key")
                    seconds = int(round(time() - start_time))
                    logger.info("\n{} seconds.".format(seconds))

                    loss = result['loss']
                    val_losses.append(loss)
                    early_stops.append(result.get('early_stop', False))

                    # keeping track of the best result so far (for display only)
                    if loss < self.best_loss:
                        self.best_loss = loss
                        self.best_counter = self.counter

                    result['counter'] = self.counter
                    result['seconds'] = seconds
                    result['params'] = t
                    result['iterations'] = n_iterations
                    self.results.append(result)

                # select a number of best configurations for the next loop
                # filter out early stops, if any
                order = np.argsort(val_losses)
                survivors = [random_configs[j] for j in order
                             if not early_stops[j]]
                # keep floor(n / eta ** (i + 1)) == int(n_configs / eta),
                # computed with integers so float error cannot drop a config
                random_configs = survivors[:n // self.eta ** (i + 1)]

                # checkpoint after every round so a crash loses little work
                if self.savefile:
                    joblib.dump(self.results, self.savefile)

        return self.results