import logging

from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint

import arguments
import dataset
import models
import visualize
from arguments import get_model_args
from utils import exists_or_make_path, load_model

# create logger
logger = logging.getLogger('cisco_logger')

# Command-line arguments are parsed once at import time and shared by every
# entry point below.
args = arguments.parse()


def train_server_only(params):
    """Train a server-only model on the windowed HDF5 training data.

    The best model (lowest training loss) is checkpointed to ``args.clf_model``
    and per-epoch metrics are written to ``args.train_log``.

    Args:
        params: model hyper-parameters, forwarded unchanged to
            ``models.get_server_model_by_params``.

    Raises:
        ValueError: if ``args.model_output`` is not "both", "client" or
            "server".
    """
    # Validate the requested outputs before doing any expensive work.
    if args.model_output not in ("both", "client", "server"):
        raise ValueError("unknown model output")

    logger.info(f"Create model path {args.model_path}")
    exists_or_make_path(args.model_path)
    logger.info(f"Use command line arguments: {args}")

    domain_tr, flow_tr, _, client_tr, server_windows_tr = dataset.load_or_generate_h5data(
        args.data, args.data, args.domain_length, args.window)
    # Flatten the window axis so that every row is one (domain, flow) sample.
    # NOTE(review): 40 and 3 are hard-coded feature widths; 40 presumably
    # equals args.domain_length -- confirm before changing either side.
    # NOTE(review): `.value` was removed in h5py 3.0 (`[()]` replaces it) --
    # confirm which h5py version/object type the dataset module returns.
    domain_tr = domain_tr.value.reshape(-1, 40)
    flow_tr = flow_tr.value.reshape(-1, 3)
    server_tr = server_windows_tr.value.reshape(-1)

    logger.info("define callbacks")
    callbacks = [
        ModelCheckpoint(filepath=args.clf_model, monitor='loss', verbose=False, save_best_only=True),
        CSVLogger(args.train_log),
    ]
    logger.info(f"Use early stopping: {args.stop_early}")
    if args.stop_early:
        # fit() below receives no validation data, so 'val_loss' is never
        # reported and an EarlyStopping on it would never trigger. Monitor the
        # training loss instead, consistent with the ModelCheckpoint above.
        callbacks.append(EarlyStopping(monitor='loss', patience=5, verbose=False))

    custom_metrics = models.get_metric_functions()
    model = models.get_server_model_by_params(params=params)

    features = {"ipt_domains": domain_tr, "ipt_flows": flow_tr}
    # Select which model heads receive training targets.
    # NOTE(review): client_tr is not flattened like server_tr; with "both"
    # the two label arrays may differ in length -- confirm against the model.
    label_sets = {
        "both": {"client": client_tr, "server": server_tr},
        "client": {"client": client_tr},
        "server": {"server": server_tr},
    }
    labels = label_sets[args.model_output]

    logger.info("compile and train model")
    logger.info(model.get_config())
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'] + custom_metrics)
    model.summary()
    model.fit(features, labels,
              batch_size=args.batch_size,
              epochs=args.epochs,
              callbacks=callbacks)


def test_server_only():
    """Run every configured model over the raw HDF5 data and store its predictions.

    For each model yielded by ``get_model_args(args)`` the classifier output is
    stored under "server_pred" and the embedding model's output for all encoded
    domains under "domain_embds"; both are written with
    ``dataset.save_predictions`` into that model's ``model_path``.
    """
    logger.info("start test: load data")
    domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(
        args.data, args.data, args.domain_length, args.window)
    # Same flattening (and same hard-coded widths) as in train_server_only.
    domain_val = domain_val.value.reshape(-1, 40)
    flow_val = flow_val.value.reshape(-1, 3)
    domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)

    for model_args in get_model_args(args):
        logger.info(f"process model {model_args['model_path']}")
        embd_model, clf_model = load_model(model_args["clf_model"],
                                           custom_objects=models.get_custom_objects())
        # Dict literal entries are evaluated in order: classifier first, then embeddings.
        results = {
            "server_pred": clf_model.predict([domain_val, flow_val],
                                             batch_size=args.batch_size, verbose=1),
            "domain_embds": embd_model.predict(domain_encs,
                                               batch_size=args.batch_size, verbose=1),
        }
        dataset.save_predictions(model_args["model_path"], results)


def _plot_curve(plot_fn, y_true, y_pred, label, path):
    """Draw one curve with ``plot_fn`` on a fresh figure, add a legend and save it to ``path``."""
    visualize.plot_clf()
    plot_fn(y_true, y_pred, label)
    visualize.plot_legend()
    visualize.plot_save(path)


def vis_server():
    """Plot precision-recall and ROC curves of the stored server predictions.

    Ground truth is the server label array of the raw HDF5 data; predictions
    are the "server_pred" entry written by ``test_server_only``. The figures
    are saved as PDFs under ``results/server_model/``.
    """
    _, _, _, _, _, server_raw = dataset.load_or_generate_raw_h5data(
        args.data, args.data, args.domain_length, args.window)
    # test_server_only writes predictions with
    # dataset.save_predictions(model_path, ...), so read them back from the
    # model path rather than from the classifier checkpoint file (args.clf_model).
    results = dataset.load_predictions(args.model_path)

    y_true = server_raw.flatten()
    y_pred = results["server_pred"].flatten()
    _plot_curve(visualize.plot_precision_recall, y_true, y_pred, "server",
                "results/server_model/windows_prc.pdf")
    _plot_curve(visualize.plot_roc_curve, y_true, y_pred, "server",
                "results/server_model/windows_roc.pdf")