refactor server training into separate file; add additional info to hyperband log
This commit is contained in:
parent
d1da3d6ca3
commit
a860f0da34
11
hyperband.py
11
hyperband.py
@ -14,7 +14,7 @@ from keras.callbacks import EarlyStopping
|
||||
import models
|
||||
from main import create_model
|
||||
|
||||
logger = logging.getLogger('logger')
|
||||
logger = logging.getLogger('cisco_logger')
|
||||
|
||||
|
||||
def sample_params(param_distribution: dict):
|
||||
@ -72,7 +72,8 @@ class Hyperband:
|
||||
validation_split=0.4)
|
||||
|
||||
return {"loss": np.min(history.history['val_loss']),
|
||||
"early_stop": len(history.history["loss"]) < n_iterations}
|
||||
"early_stop": len(history.history["loss"]) < n_iterations,
|
||||
"stop_after": len(history.history["val_loss"])}
|
||||
|
||||
# can be called multiple times
|
||||
def run(self, skip_last=0, dry_run=False):
|
||||
@ -96,7 +97,7 @@ class Hyperband:
|
||||
n_configs = n * self.eta ** (-i)
|
||||
n_iterations = r * self.eta ** (i)
|
||||
|
||||
logger.info("\n*** {} configurations x {:.1f} iterations each".format(
|
||||
logger.info("*** {} configurations x {:.1f} iterations each".format(
|
||||
n_configs, n_iterations))
|
||||
|
||||
val_losses = []
|
||||
@ -105,7 +106,7 @@ class Hyperband:
|
||||
for t in random_configs:
|
||||
|
||||
self.counter += 1
|
||||
logger.info("\n{} | {} | lowest loss so far: {:.4f} (run {})\n".format(
|
||||
logger.info("Config {} | {} | lowest loss so far: {:.4f} (run {})".format(
|
||||
self.counter, ctime(), self.best_loss, self.best_counter))
|
||||
|
||||
start_time = time()
|
||||
@ -119,7 +120,7 @@ class Hyperband:
|
||||
assert ('loss' in result)
|
||||
|
||||
seconds = int(round(time() - start_time))
|
||||
logger.info("\n{} seconds.".format(seconds))
|
||||
logger.info("{} seconds.".format(seconds))
|
||||
|
||||
loss = result['loss']
|
||||
val_losses.append(loss)
|
||||
|
@ -10,11 +10,11 @@ Model = namedtuple("Model", ["in_domains", "in_flows", "out_client", "out_server
|
||||
|
||||
|
||||
def get_embedding(embedding_size, input_length, filter_size, kernel_size, hidden_dims, drop_out=0.5):
|
||||
x = y = Input(shape=(input_length,))
|
||||
y = Embedding(input_dim=dataset.get_vocab_size(), output_dim=embedding_size)(y)
|
||||
y = Conv1D(filter_size, kernel_size=kernel_size, activation="relu")(y)
|
||||
y = Conv1D(filter_size, kernel_size=3, activation="relu")(y)
|
||||
y = Conv1D(filter_size, kernel_size=3, activation="relu")(y)
|
||||
x = Input(shape=(input_length,))
|
||||
y = Embedding(input_dim=dataset.get_vocab_size(), output_dim=embedding_size)(x)
|
||||
y = Conv1D(filter_size, kernel_size=kernel_size, activation="relu", padding="same")(y)
|
||||
y = Conv1D(filter_size, kernel_size=3, activation="relu", padding="same")(y)
|
||||
y = Conv1D(filter_size, kernel_size=3, activation="relu", padding="same")(y)
|
||||
y = GlobalAveragePooling1D()(y)
|
||||
y = Dense(hidden_dims, activation="relu")(y)
|
||||
return KerasModel(x, y)
|
||||
|
@ -2,7 +2,7 @@ from collections import namedtuple
|
||||
|
||||
import keras
|
||||
from keras.engine import Input, Model as KerasModel
|
||||
from keras.layers import Activation, Conv1D, Dense, Dropout, Embedding, GlobalMaxPooling1D, TimeDistributed
|
||||
from keras.layers import Conv1D, Dense, Dropout, Embedding, GlobalMaxPooling1D, TimeDistributed
|
||||
|
||||
import dataset
|
||||
|
||||
@ -38,8 +38,7 @@ def get_embedding(embedding_size, input_length, filter_size, kernel_size, hidden
|
||||
activation='relu')(y)
|
||||
y = GlobalMaxPooling1D()(y)
|
||||
y = Dropout(drop_out)(y)
|
||||
y = Dense(hidden_dims)(y)
|
||||
y = Activation('relu')(y)
|
||||
y = Dense(hidden_dims, activation="relu")(y)
|
||||
return KerasModel(x, y)
|
||||
|
||||
|
||||
|
123
server.py
Normal file
123
server.py
Normal file
@ -0,0 +1,123 @@
|
||||
import logging
|
||||
|
||||
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
|
||||
|
||||
import arguments
|
||||
import dataset
|
||||
import models
|
||||
# create logger
|
||||
import visualize
|
||||
from arguments import get_model_args
|
||||
from utils import exists_or_make_path, load_model
|
||||
|
||||
logger = logging.getLogger('cisco_logger')
|
||||
|
||||
args = arguments.parse()
|
||||
|
||||
|
||||
def train_server_only(params):
|
||||
logger.info(f"Create model path {args.model_path}")
|
||||
exists_or_make_path(args.model_path)
|
||||
logger.info(f"Use command line arguments: {args}")
|
||||
|
||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
|
||||
args.data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
domain_tr = domain_tr.value.reshape(-1, 40)
|
||||
flow_tr = flow_tr.value.reshape(-1, 3)
|
||||
server_tr = server_windows_tr.value.reshape(-1)
|
||||
|
||||
logger.info("define callbacks")
|
||||
callbacks = []
|
||||
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
|
||||
monitor='loss',
|
||||
verbose=False,
|
||||
save_best_only=True))
|
||||
callbacks.append(CSVLogger(args.train_log))
|
||||
logger.info(f"Use early stopping: {args.stop_early}")
|
||||
if args.stop_early:
|
||||
callbacks.append(EarlyStopping(monitor='val_loss',
|
||||
patience=5,
|
||||
verbose=False))
|
||||
custom_metrics = models.get_metric_functions()
|
||||
|
||||
model = models.get_server_model_by_params(params=params)
|
||||
|
||||
features = {"ipt_domains": domain_tr, "ipt_flows": flow_tr}
|
||||
if args.model_output == "both":
|
||||
labels = {"client": client_tr, "server": server_tr}
|
||||
elif args.model_output == "client":
|
||||
labels = {"client": client_tr}
|
||||
elif args.model_output == "server":
|
||||
labels = {"server": server_tr}
|
||||
else:
|
||||
raise ValueError("unknown model output")
|
||||
|
||||
logger.info("compile and train model")
|
||||
logger.info(model.get_config())
|
||||
model.compile(optimizer='adam',
|
||||
loss='binary_crossentropy',
|
||||
metrics=['accuracy'] + custom_metrics)
|
||||
|
||||
model.summary()
|
||||
model.fit(features, labels,
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
callbacks=callbacks)
|
||||
|
||||
|
||||
def test_server_only():
|
||||
logger.info("start test: load data")
|
||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data,
|
||||
args.data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
domain_val = domain_val.value.reshape(-1, 40)
|
||||
flow_val = flow_val.value.reshape(-1, 3)
|
||||
domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
||||
|
||||
for model_args in get_model_args(args):
|
||||
results = {}
|
||||
logger.info(f"process model {model_args['model_path']}")
|
||||
embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_custom_objects())
|
||||
|
||||
pred = clf_model.predict([domain_val, flow_val],
|
||||
batch_size=args.batch_size,
|
||||
verbose=1)
|
||||
|
||||
results["server_pred"] = pred
|
||||
|
||||
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||
results["domain_embds"] = domain_embeddings
|
||||
|
||||
dataset.save_predictions(model_args["model_path"], results)
|
||||
|
||||
|
||||
def vis_server():
|
||||
def load_model(m, c):
|
||||
from keras.models import load_model
|
||||
clf = load_model(m, custom_objects=c)
|
||||
emdb = clf.layers[1]
|
||||
return emdb, clf
|
||||
|
||||
domain_raw, flow_raw, name_raw, hits_vt_raw, hits_trusted_raw, server_raw = dataset.load_or_generate_raw_h5data(
|
||||
args.data,
|
||||
args.data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
|
||||
results = dataset.load_predictions(args.clf_model)
|
||||
|
||||
visualize.plot_clf()
|
||||
visualize.plot_precision_recall(server_raw.flatten(), results["server_pred"].flatten(), "server")
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("results/server_model/windows_prc.pdf")
|
||||
visualize.plot_clf()
|
||||
visualize.plot_precision_recall(server_raw.flatten(), results["server_pred"].flatten(), "server")
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("results/server_model/windows_prc.pdf")
|
||||
visualize.plot_clf()
|
||||
visualize.plot_roc_curve(server_raw.flatten(), results["server_pred"].flatten(), "server")
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("results/server_model/windows_roc.pdf")
|
Loading…
Reference in New Issue
Block a user