ma_cisco_malware/server.py

124 lines
4.9 KiB
Python
Raw Normal View History

import logging
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
import arguments
import dataset
import models
# create logger
import visualize
from arguments import get_model_args
from utils import exists_or_make_path, load_model
# Shared logger, looked up by the fixed name 'cisco_logger'.
logger = logging.getLogger('cisco_logger')
# Arguments are parsed once, at import time; the functions below read this
# module-level `args` object instead of receiving it as a parameter.
args = arguments.parse()
def train_server_only(params):
    """Build, compile and fit the server model on the training data.

    All configuration except ``params`` is read from the module-level
    ``args`` object (data path, model path, batch size, epochs, ...).

    Args:
        params: forwarded unchanged to ``models.get_server_model_by_params``.

    Raises:
        ValueError: if ``args.model_output`` is not one of ``"both"``,
            ``"client"`` or ``"server"``.
    """
    logger.info(f"Create model path {args.model_path}")
    exists_or_make_path(args.model_path)
    logger.info(f"Use command line arguments: {args}")

    domain_tr, flow_tr, _, client_tr, server_windows_tr = dataset.load_or_generate_h5data(
        args.data,
        args.data,
        args.domain_length,
        args.window)

    # `dataset[()]` reads the whole dataset into memory. It replaces the
    # `.value` accessor, which is deprecated and was removed in h5py 3.0.
    # NOTE(review): presumably these are h5py datasets -- confirm.
    # NOTE(review): the literals 40 and 3 look like the domain length and the
    # number of flow features; 40 probably should follow args.domain_length.
    domain_tr = domain_tr[()].reshape(-1, 40)
    flow_tr = flow_tr[()].reshape(-1, 3)
    server_tr = server_windows_tr[()].reshape(-1)

    logger.info("define callbacks")
    callbacks = [
        ModelCheckpoint(filepath=args.clf_model,
                        monitor='loss',
                        verbose=False,
                        save_best_only=True),
        CSVLogger(args.train_log),
    ]
    logger.info(f"Use early stopping: {args.stop_early}")
    if args.stop_early:
        # fit() below is called without validation data, so 'val_loss' is
        # never reported and early stopping on it could never trigger.
        # Monitor the training loss instead, like ModelCheckpoint does.
        callbacks.append(EarlyStopping(monitor='loss',
                                       patience=5,
                                       verbose=False))

    custom_metrics = models.get_metric_functions()
    model = models.get_server_model_by_params(params=params)

    features = {"ipt_domains": domain_tr, "ipt_flows": flow_tr}
    # NOTE(review): client_tr is passed as loaded, without the flattening
    # applied to the server labels -- verify its sample count matches.
    if args.model_output == "both":
        labels = {"client": client_tr, "server": server_tr}
    elif args.model_output == "client":
        labels = {"client": client_tr}
    elif args.model_output == "server":
        labels = {"server": server_tr}
    else:
        raise ValueError("unknown model output")

    logger.info("compile and train model")
    logger.info(model.get_config())
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'] + custom_metrics)
    model.summary()
    model.fit(features, labels,
              batch_size=args.batch_size,
              epochs=args.epochs,
              callbacks=callbacks)
def test_server_only():
    """Run every configured model on the raw data and save its predictions.

    For each entry yielded by ``get_model_args(args)`` the classifier and its
    embedding sub-model are loaded, the classifier predicts on the flattened
    domain/flow arrays, the embedding model predicts on the domain encodings,
    and both outputs are stored with ``dataset.save_predictions`` under the
    keys ``"server_pred"`` and ``"domain_embds"``.
    """
    logger.info("start test: load data")
    domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(
        args.data,
        args.data,
        args.domain_length,
        args.window)

    # `dataset[()]` reads the whole dataset into memory. It replaces the
    # `.value` accessor, which is deprecated and was removed in h5py 3.0.
    # NOTE(review): presumably these are h5py datasets -- confirm.
    domain_val = domain_val[()].reshape(-1, 40)
    flow_val = flow_val[()].reshape(-1, 3)
    domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)

    for model_args in get_model_args(args):
        logger.info(f"process model {model_args['model_path']}")
        embd_model, clf_model = load_model(model_args["clf_model"],
                                           custom_objects=models.get_custom_objects())
        # Same call order as before: classifier first, then embeddings.
        results = {
            "server_pred": clf_model.predict([domain_val, flow_val],
                                             batch_size=args.batch_size,
                                             verbose=1),
            "domain_embds": embd_model.predict(domain_encs,
                                               batch_size=args.batch_size,
                                               verbose=1),
        }
        dataset.save_predictions(model_args["model_path"], results)
def vis_server():
    """Plot precision-recall and ROC curves for the stored server predictions.

    Loads the raw data and the predictions stored for ``args.clf_model``, then
    writes ``results/server_model/windows_prc.pdf`` and
    ``results/server_model/windows_roc.pdf``.
    """
    # Only the server labels (last returned element) are needed here.
    _, _, _, _, _, server_raw = dataset.load_or_generate_raw_h5data(
        args.data,
        args.data,
        args.domain_length,
        args.window)
    results = dataset.load_predictions(args.clf_model)

    # Materialise the labels before flattening: an h5py dataset has no
    # `flatten` method, while `[()]` on a numpy array just returns the array.
    # NOTE(review): presumably server_raw is an h5py dataset -- confirm.
    server_true = server_raw[()].flatten()
    server_pred = results["server_pred"].flatten()

    # The precision-recall figure used to be rendered and saved twice to the
    # same path; a single pass yields the same file.
    visualize.plot_clf()
    visualize.plot_precision_recall(server_true, server_pred, "server")
    visualize.plot_legend()
    visualize.plot_save("results/server_model/windows_prc.pdf")

    visualize.plot_clf()
    visualize.plot_roc_curve(server_true, server_pred, "server")
    visualize.plot_legend()
    visualize.plot_save("results/server_model/windows_roc.pdf")