124 lines
4.9 KiB
Python
124 lines
4.9 KiB
Python
import logging
|
|
|
|
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
|
|
|
|
import arguments
|
|
import dataset
|
|
import models
|
|
# create logger
|
|
import visualize
|
|
from arguments import get_model_args
|
|
from utils import exists_or_make_path, load_model
|
|
|
|
# Module-level logger shared by every function in this file.
logger = logging.getLogger('cisco_logger')

# Parsed once at import time; the functions below read their configuration
# (paths, batch size, epochs, ...) from this namespace.
args = arguments.parse()
|
|
|
|
|
|
def train_server_only(params):
    """Train the server model described by ``params`` and checkpoint it.

    Loads the training data via ``dataset.load_or_generate_h5data``, builds the
    model with ``models.get_server_model_by_params``, compiles it with Adam and
    binary cross-entropy, and fits it. The best model (by training loss) is
    written to ``args.clf_model`` and the per-epoch log to ``args.train_log``.

    Args:
        params: Forwarded unchanged to ``models.get_server_model_by_params``.

    Raises:
        ValueError: If ``args.model_output`` is not one of ``"both"``,
            ``"client"`` or ``"server"``.
    """
    logger.info(f"Create model path {args.model_path}")
    exists_or_make_path(args.model_path)
    logger.info(f"Use command line arguments: {args}")

    # The third returned value (names) is not needed for training.
    domain_tr, flow_tr, _, client_tr, server_windows_tr = dataset.load_or_generate_h5data(
        args.data,
        args.data,
        args.domain_length,
        args.window)
    # NOTE(review): `.value` looks like the h5py Dataset accessor, which was
    # removed in h5py 3.0 in favour of `ds[()]` -- confirm the return type.
    # NOTE(review): 40 presumably equals args.domain_length -- TODO confirm.
    domain_tr = domain_tr.value.reshape(-1, 40)
    flow_tr = flow_tr.value.reshape(-1, 3)
    server_tr = server_windows_tr.value.reshape(-1)

    logger.info("define callbacks")
    callbacks = [
        ModelCheckpoint(filepath=args.clf_model,
                        monitor='loss',
                        verbose=False,
                        save_best_only=True),
        CSVLogger(args.train_log),
    ]
    logger.info(f"Use early stopping: {args.stop_early}")
    if args.stop_early:
        # fit() below is called without validation data, so `val_loss` is never
        # reported and an EarlyStopping on it would only warn and never stop.
        # Monitor the training loss, the same quantity ModelCheckpoint tracks.
        callbacks.append(EarlyStopping(monitor='loss',
                                       patience=5,
                                       verbose=False))
    custom_metrics = models.get_metric_functions()

    model = models.get_server_model_by_params(params=params)

    features = {"ipt_domains": domain_tr, "ipt_flows": flow_tr}
    if args.model_output == "both":
        labels = {"client": client_tr, "server": server_tr}
    elif args.model_output == "client":
        labels = {"client": client_tr}
    elif args.model_output == "server":
        labels = {"server": server_tr}
    else:
        raise ValueError("unknown model output")

    logger.info("compile and train model")
    logger.info(model.get_config())
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'] + custom_metrics)

    model.summary()
    model.fit(features, labels,
              batch_size=args.batch_size,
              epochs=args.epochs,
              callbacks=callbacks)
|
|
|
|
|
|
def test_server_only():
    """Run every configured model on the raw data and store its predictions.

    For each entry yielded by ``get_model_args(args)`` the classifier is
    applied to the domain/flow inputs and the embedding model to the encoded
    domains; both outputs are saved with ``dataset.save_predictions`` under the
    entry's ``model_path``.
    """
    logger.info("start test: load data")
    # Only the first two of the six returned values are used here.
    domains, flows, _, _, _, _ = dataset.load_or_generate_raw_h5data(
        args.data,
        args.data,
        args.domain_length,
        args.window,
    )
    domains = domains.value.reshape(-1, 40)
    flows = flows.value.reshape(-1, 3)
    encoded_domains, _ = dataset.load_or_generate_domains(args.data, args.domain_length)

    for cfg in get_model_args(args):
        logger.info(f"process model {cfg['model_path']}")
        embedder, classifier = load_model(
            cfg["clf_model"],
            custom_objects=models.get_custom_objects(),
        )

        server_pred = classifier.predict(
            [domains, flows],
            batch_size=args.batch_size,
            verbose=1,
        )
        embeddings = embedder.predict(
            encoded_domains,
            batch_size=args.batch_size,
            verbose=1,
        )

        # Same keys and insertion order as before: predictions, then embeddings.
        outputs = {
            "server_pred": server_pred,
            "domain_embds": embeddings,
        }
        dataset.save_predictions(cfg["model_path"], outputs)
|
|
|
|
|
|
def vis_server():
    """Plot precision-recall and ROC curves for the stored server predictions.

    Loads the raw data and the predictions saved for ``args.clf_model``, then
    writes ``results/server_model/windows_prc.pdf`` and
    ``results/server_model/windows_roc.pdf``.
    """
    # Only the server labels (last returned value) are needed for plotting.
    _, _, _, _, _, server_raw = dataset.load_or_generate_raw_h5data(
        args.data,
        args.data,
        args.domain_length,
        args.window)

    # NOTE(review): test_server_only saves predictions under
    # model_args["model_path"], while this loads with args.clf_model --
    # confirm both resolve to the same location.
    results = dataset.load_predictions(args.clf_model)

    # NOTE(review): server_raw is flattened directly here, whereas
    # test_server_only reads `.value` from the same loader's results --
    # confirm which accessor matches the actual return type.
    y_true = server_raw.flatten()
    y_pred = results["server_pred"].flatten()

    # Precision-recall curve. The original repeated this section twice
    # verbatim, saving to the same path both times; one pass is kept.
    visualize.plot_clf()
    visualize.plot_precision_recall(y_true, y_pred, "server")
    visualize.plot_legend()
    visualize.plot_save("results/server_model/windows_prc.pdf")

    # ROC curve.
    visualize.plot_clf()
    visualize.plot_roc_curve(y_true, y_pred, "server")
    visualize.plot_legend()
    visualize.plot_save("results/server_model/windows_roc.pdf")
|