refactor server training into separate file; add additional info to hyperband log

This commit is contained in:
René Knaebel 2017-10-19 17:37:29 +02:00
parent d1da3d6ca3
commit a860f0da34
4 changed files with 137 additions and 14 deletions

View File

@ -14,7 +14,7 @@ from keras.callbacks import EarlyStopping
import models
from main import create_model
logger = logging.getLogger('logger')
logger = logging.getLogger('cisco_logger')
def sample_params(param_distribution: dict):
@ -72,7 +72,8 @@ class Hyperband:
validation_split=0.4)
return {"loss": np.min(history.history['val_loss']),
"early_stop": len(history.history["loss"]) < n_iterations}
"early_stop": len(history.history["loss"]) < n_iterations,
"stop_after": len(history.history["val_loss"])}
# can be called multiple times
def run(self, skip_last=0, dry_run=False):
@ -96,7 +97,7 @@ class Hyperband:
n_configs = n * self.eta ** (-i)
n_iterations = r * self.eta ** (i)
logger.info("\n*** {} configurations x {:.1f} iterations each".format(
logger.info("*** {} configurations x {:.1f} iterations each".format(
n_configs, n_iterations))
val_losses = []
@ -105,7 +106,7 @@ class Hyperband:
for t in random_configs:
self.counter += 1
logger.info("\n{} | {} | lowest loss so far: {:.4f} (run {})\n".format(
logger.info("Config {} | {} | lowest loss so far: {:.4f} (run {})".format(
self.counter, ctime(), self.best_loss, self.best_counter))
start_time = time()
@ -119,7 +120,7 @@ class Hyperband:
assert ('loss' in result)
seconds = int(round(time() - start_time))
logger.info("\n{} seconds.".format(seconds))
logger.info("{} seconds.".format(seconds))
loss = result['loss']
val_losses.append(loss)

View File

@ -10,11 +10,11 @@ Model = namedtuple("Model", ["in_domains", "in_flows", "out_client", "out_server
def get_embedding(embedding_size, input_length, filter_size, kernel_size, hidden_dims, drop_out=0.5):
x = y = Input(shape=(input_length,))
y = Embedding(input_dim=dataset.get_vocab_size(), output_dim=embedding_size)(y)
y = Conv1D(filter_size, kernel_size=kernel_size, activation="relu")(y)
y = Conv1D(filter_size, kernel_size=3, activation="relu")(y)
y = Conv1D(filter_size, kernel_size=3, activation="relu")(y)
x = Input(shape=(input_length,))
y = Embedding(input_dim=dataset.get_vocab_size(), output_dim=embedding_size)(x)
y = Conv1D(filter_size, kernel_size=kernel_size, activation="relu", padding="same")(y)
y = Conv1D(filter_size, kernel_size=3, activation="relu", padding="same")(y)
y = Conv1D(filter_size, kernel_size=3, activation="relu", padding="same")(y)
y = GlobalAveragePooling1D()(y)
y = Dense(hidden_dims, activation="relu")(y)
return KerasModel(x, y)

View File

@ -2,7 +2,7 @@ from collections import namedtuple
import keras
from keras.engine import Input, Model as KerasModel
from keras.layers import Activation, Conv1D, Dense, Dropout, Embedding, GlobalMaxPooling1D, TimeDistributed
from keras.layers import Conv1D, Dense, Dropout, Embedding, GlobalMaxPooling1D, TimeDistributed
import dataset
@ -38,8 +38,7 @@ def get_embedding(embedding_size, input_length, filter_size, kernel_size, hidden
activation='relu')(y)
y = GlobalMaxPooling1D()(y)
y = Dropout(drop_out)(y)
y = Dense(hidden_dims)(y)
y = Activation('relu')(y)
y = Dense(hidden_dims, activation="relu")(y)
return KerasModel(x, y)

123
server.py Normal file
View File

@ -0,0 +1,123 @@
import logging
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
import arguments
import dataset
import models
# create logger
import visualize
from arguments import get_model_args
from utils import exists_or_make_path, load_model
logger = logging.getLogger('cisco_logger')
args = arguments.parse()
def train_server_only(params):
logger.info(f"Create model path {args.model_path}")
exists_or_make_path(args.model_path)
logger.info(f"Use command line arguments: {args}")
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
args.data,
args.domain_length,
args.window)
domain_tr = domain_tr.value.reshape(-1, 40)
flow_tr = flow_tr.value.reshape(-1, 3)
server_tr = server_windows_tr.value.reshape(-1)
logger.info("define callbacks")
callbacks = []
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
monitor='loss',
verbose=False,
save_best_only=True))
callbacks.append(CSVLogger(args.train_log))
logger.info(f"Use early stopping: {args.stop_early}")
if args.stop_early:
callbacks.append(EarlyStopping(monitor='val_loss',
patience=5,
verbose=False))
custom_metrics = models.get_metric_functions()
model = models.get_server_model_by_params(params=params)
features = {"ipt_domains": domain_tr, "ipt_flows": flow_tr}
if args.model_output == "both":
labels = {"client": client_tr, "server": server_tr}
elif args.model_output == "client":
labels = {"client": client_tr}
elif args.model_output == "server":
labels = {"server": server_tr}
else:
raise ValueError("unknown model output")
logger.info("compile and train model")
logger.info(model.get_config())
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'] + custom_metrics)
model.summary()
model.fit(features, labels,
batch_size=args.batch_size,
epochs=args.epochs,
callbacks=callbacks)
def test_server_only():
logger.info("start test: load data")
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data,
args.data,
args.domain_length,
args.window)
domain_val = domain_val.value.reshape(-1, 40)
flow_val = flow_val.value.reshape(-1, 3)
domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
for model_args in get_model_args(args):
results = {}
logger.info(f"process model {model_args['model_path']}")
embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_custom_objects())
pred = clf_model.predict([domain_val, flow_val],
batch_size=args.batch_size,
verbose=1)
results["server_pred"] = pred
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
results["domain_embds"] = domain_embeddings
dataset.save_predictions(model_args["model_path"], results)
def vis_server():
def load_model(m, c):
from keras.models import load_model
clf = load_model(m, custom_objects=c)
emdb = clf.layers[1]
return emdb, clf
domain_raw, flow_raw, name_raw, hits_vt_raw, hits_trusted_raw, server_raw = dataset.load_or_generate_raw_h5data(
args.data,
args.data,
args.domain_length,
args.window)
results = dataset.load_predictions(args.clf_model)
visualize.plot_clf()
visualize.plot_precision_recall(server_raw.flatten(), results["server_pred"].flatten(), "server")
visualize.plot_legend()
visualize.plot_save("results/server_model/windows_prc.pdf")
visualize.plot_clf()
visualize.plot_precision_recall(server_raw.flatten(), results["server_pred"].flatten(), "server")
visualize.plot_legend()
visualize.plot_save("results/server_model/windows_prc.pdf")
visualize.plot_clf()
visualize.plot_roc_curve(server_raw.flatten(), results["server_pred"].flatten(), "server")
visualize.plot_legend()
visualize.plot_save("results/server_model/windows_roc.pdf")