# ma_cisco_malware/main.py
#
# (listing metadata: 412 lines, 17 KiB, Python)

import json
import logging
import os
import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
from keras.models import load_model, Model
import arguments
import dataset
import hyperband
import models
# create logger
import visualize
from dataset import load_or_generate_h5data
from utils import exists_or_make_path, get_custom_class_weights
from arguments import get_model_args
# create logger: every record at DEBUG or above is emitted both to the console
# and to "info.log" in the current working directory, each handler with its
# own formatter instance using the same layout.
logger = logging.getLogger('logger')
logger.setLevel(logging.DEBUG)

_HANDLER_FACTORIES = (
    logging.StreamHandler,                    # console
    lambda: logging.FileHandler("info.log"),  # persistent log file
)
for _make_handler in _HANDLER_FACTORIES:
    ch = _make_handler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)
args = arguments.parse()

if args.gpu:
    # TF1-style session config: log op device placement, cap this process at
    # half of the GPU memory and let the allocation grow on demand.
    config = tf.ConfigProto(log_device_placement=True)
    config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
    # Merely constructing a tf.Session does not make Keras use it (it is not
    # installed as the default session), so register it explicitly; otherwise
    # Keras builds its own session and ignores the options above.
    from keras import backend as keras_backend
    keras_backend.set_session(session)
# Default model hyper-parameters, mostly taken from the command line.
# The batch size is intentionally not part of this dict: training passes
# args.batch_size to fit() directly.
PARAMS = {
    "type": args.model_type,
    "depth": args.model_depth,
    "window_size": args.window,
    "domain_length": args.domain_length,
    "flow_features": 3,
    "dropout": 0.5,  # currently fixed, not exposed on the command line
    "domain_features": args.domain_embedding,
    "embedding_size": args.embedding,
    "filter_embedding": args.filter_embedding,
    "dense_embedding": args.dense_embedding,
    "kernel_embedding": args.kernel_embedding,
    "filter_main": args.filter_main,
    "dense_main": args.dense_main,
    "kernel_main": args.kernel_main,
    "input_length": 40,
    "model_output": args.model_output,
}
def create_model(model, output_type):
    """Wrap a model definition's tensors into a Keras ``Model``.

    Args:
        model: object exposing the input tensors ``in_domains`` and ``in_flows``
            and the output tensors ``out_client`` / ``out_server`` (e.g. one of
            the definitions returned by ``models.get_models_by_params``).
        output_type: which heads to expose: ``"both"`` (client and server),
            ``"client"`` or ``"server"``.

    Returns:
        A Keras ``Model`` with inputs ``[domains, flows]`` and the selected outputs.

    Raises:
        ValueError: if ``output_type`` is not one of the values above.
    """
    # Validate first so an unsupported value fails before any tensor is touched.
    if output_type == "both":
        outputs = (model.out_client, model.out_server)
    elif output_type == "client":
        outputs = (model.out_client,)
    elif output_type == "server":
        # Server-only head; the training and test code already accept "server".
        outputs = (model.out_server,)
    else:
        raise ValueError(f"unknown model output: {output_type!r}")
    return Model(inputs=[model.in_domains, model.in_flows], outputs=outputs)
def main_paul_best():
    """Train a single model using the stored best configuration from ``models.pauls_networks``."""
    main_train(param=models.pauls_networks.best_config)
def main_hyperband():
    """Run a Hyperband search over the "paul" architecture and save the results.

    The search space is the ``params`` grid below (each value is a list of
    candidates). Training data is loaded from ``args.train_h5data`` /
    ``args.train_data``; the search results are written as JSON to
    ``hyperband.json`` in the current working directory.
    """
    params = {
        # static params -- shape-related values must match the data loaded below
        "type": ["paul"],
        "batch_size": [args.batch_size],
        "window_size": [args.window],
        "domain_length": [args.domain_length],
        "flow_features": [3],
        "input_length": [40],
        # model params
        "embedding_size": [8, 16, 32, 64, 128, 256],
        "filter_embedding": [8, 16, 32, 64, 128, 256],
        "kernel_embedding": [1, 3, 5, 7, 9],
        # NOTE(review): PARAMS in this file uses "dense_embedding" and
        # "kernel_main" where this grid uses "hidden_embedding" and
        # "kernels_main"; confirm which spelling the model builder reads.
        "hidden_embedding": [8, 16, 32, 64, 128, 256],
        "dropout": [0.5],
        "domain_features": [8, 16, 32, 64, 128, 256],
        "filter_main": [8, 16, 32, 64, 128, 256],
        "kernels_main": [1, 3, 5, 7, 9],
        "dense_main": [8, 16, 32, 64, 128, 256],
    }
    logger.info("create training dataset")
    domain_tr, flow_tr, _, client_tr, server_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
                                                                          args.domain_length, args.window)
    hp = hyperband.Hyperband(params,
                             [domain_tr, flow_tr],
                             [client_tr, server_tr])
    results = hp.run()
    # Open for writing (the default mode "r" cannot be dumped into) and make
    # sure the handle is closed even if serialisation fails.
    with open("hyperband.json", "w") as out_file:
        json.dump(results, out_file)
def main_train(param=None):
    """Train a classifier on the training set and save its embedding sub-model.

    Args:
        param: hyper-parameter dict handed to ``models.get_models_by_params``;
            when falsy, the module-level ``PARAMS`` is used instead.

    Side effects:
        * creates ``args.model_path`` via ``exists_or_make_path``;
        * ModelCheckpoint keeps the best model by ``val_loss`` at
          ``args.clf_model`` and CSVLogger writes per-epoch metrics to
          ``args.train_log``;
        * saves the embedding model to ``args.embedding_model``.

    Raises:
        Whatever ``create_model`` raises for an unsupported ``args.model_output``;
        the non-staggered path also raises ValueError for values other than
        "both", "client" or "server".
    """
    logger.info(f"Create model path {args.model_path}")
    exists_or_make_path(args.model_path)
    logger.info(f"Use command line arguments: {args}")
    domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = load_or_generate_h5data(args.train_h5data,
                                                                                       args.train_data,
                                                                                       args.domain_length,
                                                                                       args.window)
    logger.info("define callbacks")
    callbacks = []
    # Keep only the model with the best validation loss seen so far.
    callbacks.append(ModelCheckpoint(filepath=args.clf_model,
                                     monitor='val_loss',
                                     verbose=False,
                                     save_best_only=True))
    callbacks.append(CSVLogger(args.train_log))
    logger.info(f"Use early stopping: {args.stop_early}")
    if args.stop_early:
        callbacks.append(EarlyStopping(monitor='val_loss',
                                       patience=5,
                                       verbose=False))
    custom_metrics = models.get_metric_functions()
    # Collapse the window-level server labels with a max over axis 1
    # (presumably "any window positive" per sample -- confirm axis meaning).
    server_tr = np.max(server_windows_tr, axis=1)
    if args.class_weights:
        logger.info("class weights: compute custom weights")
        # `.value` materialises the labels (h5py-style API; presumably client_tr
        # is an h5py Dataset -- confirm). Weights use the collapsed server labels.
        custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
        logger.info(custom_class_weights)
    else:
        logger.info("class weights: set default")
        custom_class_weights = None
    logger.info(f"select model: {args.model_type}")
    if args.model_type == "staggered":
        if not param:
            param = PARAMS
        logger.info(f"Generator model with params: {param}")
        embedding, model, new_model = models.get_models_by_params(param)
        # The staggered path always trains the `new_model` variant.
        model = create_model(new_model, args.model_output)
        # Train on the per-window server labels; the trailing axis is presumably
        # required by the server head's output shape -- confirm.
        server_tr = np.expand_dims(server_windows_tr, 2)
        logger.info("compile and train model")
        embedding.summary()
        model.summary()
        logger.info(model.get_config())
        # Phase 1: loss weights client=0, server=1 -> only the server head
        # contributes to the loss. No callbacks are passed here, so this phase is
        # neither checkpointed, CSV-logged nor early-stopped.
        model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      loss_weights={"client": 0.0, "server": 1.0},
                      metrics=['accuracy'] + custom_metrics)
        model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
                  {"client": client_tr, "server": server_tr},
                  batch_size=args.batch_size,
                  epochs=args.epochs,
                  shuffle=True,
                  validation_split=0.2,
                  class_weight=custom_class_weights)
        # Phase 2: freeze the server layers, then recompile -- Keras only picks
        # up changed `trainable` flags at compile time -- and train the client
        # head (loss weights client=1, server=0) with the callbacks attached.
        model.get_layer("dense_server").trainable = False
        model.get_layer("server").trainable = False
        model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      loss_weights={"client": 1.0, "server": 0.0},
                      metrics=['accuracy'] + custom_metrics)
        model.summary()
        model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
                  {"client": client_tr, "server": server_tr},
                  batch_size=args.batch_size,
                  epochs=args.epochs,
                  callbacks=callbacks,
                  shuffle=True,
                  validation_split=0.2,
                  class_weight=custom_class_weights)
    else:
        if not param:
            param = PARAMS
        logger.info(f"Generator model with params: {param}")
        embedding, model, new_model = models.get_models_by_params(param)
        # Both variants are wrapped; "inter" trains `new_model` on the
        # per-window server labels, every other type trains `model` on the
        # collapsed ones computed above.
        model = create_model(model, args.model_output)
        new_model = create_model(new_model, args.model_output)
        if args.model_type == "inter":
            server_tr = np.expand_dims(server_windows_tr, 2)
            model = new_model
        logger.info("compile and train model")
        embedding.summary()
        model.summary()
        logger.info(model.get_config())
        model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['accuracy'] + custom_metrics)
        # Label list must line up with the model's outputs.
        if args.model_output == "both":
            labels = [client_tr, server_tr]
        elif args.model_output == "client":
            labels = [client_tr]
        elif args.model_output == "server":
            labels = [server_tr]
        else:
            raise ValueError("unknown model output")
        model.fit([domain_tr, flow_tr],
                  labels,
                  batch_size=args.batch_size,
                  epochs=args.epochs,
                  callbacks=callbacks,
                  shuffle=True,
                  validation_split=0.2,
                  class_weight=custom_class_weights)
    logger.info("save embedding")
    embedding.save(args.embedding_model)
def main_test():
    """Run every configured model on the test set and store its outputs.

    For each entry of ``get_model_args(args)`` the classifier and the embedding
    model are loaded, and ``<model_path>/results.joblib`` is written. It holds
    the classifier predictions ("client_pred" and/or "server_pred", depending
    on ``args.model_output``) and the domain embeddings ("domain_embds").
    """
    logger.info("start test: load data")
    domains, flows, _names, _client_true, _server_true = load_or_generate_h5data(args.test_h5data,
                                                                                args.test_data,
                                                                                args.domain_length,
                                                                                args.window)
    domain_encs, _domain_labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
    for model_args in get_model_args(args):
        model_dir = model_args["model_path"]
        logger.info(f"process model {model_dir}")
        classifier = load_model(model_args["clf_model"], custom_objects=models.get_metrics())
        predictions = classifier.predict([domains, flows],
                                         batch_size=args.batch_size,
                                         verbose=1)
        results = {}
        output_kind = args.model_output
        if output_kind == "both":
            # Two-headed model: predict() yields one array per output.
            results["client_pred"], results["server_pred"] = predictions
        else:
            results["client_pred" if output_kind == "client" else "server_pred"] = predictions
        embedder = load_model(model_args["embedding_model"])
        results["domain_embds"] = embedder.predict(domain_encs, batch_size=args.batch_size, verbose=1)
        joblib.dump(results, model_dir + "/results.joblib", compress=3)
def main_visualization():
    """Render evaluation plots for the single model configured in ``args``.

    Writes into ``args.model_path``:
        * the model diagram and, best effort, the training curves;
        * window-level and per-name ("user") client precision/recall and ROC curves;
        * client and user confusion matrices;
        * a plot of the domain embeddings.
    """
    _, _, name_val, client_val, _ = load_or_generate_h5data(args.test_h5data,
                                                            args.test_data,
                                                            args.domain_length,
                                                            args.window)
    client_val = client_val.value

    def max_per_name(values):
        """Collapse window-level values to one value per distinct ``name_val`` entry (max)."""
        frame = pd.DataFrame(data={"names": name_val, "client_val": values})
        # ``as_matrix`` was removed in pandas 1.0; ``.values`` works across versions.
        return frame.groupby(frame.names).max().client_val.values.astype(float)

    logger.info("plot model")
    model = load_model(args.clf_model, custom_objects=models.get_metrics())
    visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))

    try:
        logger.info("plot training curve")
        logs = pd.read_csv(args.train_log)
        if args.model_output == "client":
            visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
        else:
            visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
            visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
    except Exception as e:
        # Best effort: a missing or incompatible training log must not abort the remaining plots.
        logger.warning(f"could not generate training curves: {e}")

    # NOTE(review): main_test in this file only writes <model_path>/results.joblib
    # (its lines saving predictions and domain_embds.npy are commented out), so
    # args.future_prediction and domain_embds.npy must come from another run -- confirm.
    client_pred, _ = dataset.load_predictions(args.future_prediction)
    client_pred = client_pred.value.flatten()

    logger.info("plot pr curve")
    visualize.plot_clf()
    visualize.plot_precision_recall(client_val, client_pred, args.model_path)
    visualize.plot_legend()
    visualize.plot_save("{}/window_client_prc.png".format(args.model_path))

    logger.info("plot roc curve")
    visualize.plot_clf()
    visualize.plot_roc_curve(client_val, client_pred, args.model_path)
    visualize.plot_legend()
    visualize.plot_save("{}/window_client_roc.png".format(args.model_path))

    logger.info(f"names {name_val.shape} vals {client_val.shape} preds {client_pred.shape}")
    user_vals = max_per_name(client_val)
    user_preds = max_per_name(client_pred)

    visualize.plot_clf()
    visualize.plot_precision_recall(user_vals, user_preds, args.model_path)
    visualize.plot_legend()
    visualize.plot_save("{}/user_client_prc.png".format(args.model_path))

    visualize.plot_clf()
    visualize.plot_roc_curve(user_vals, user_preds, args.model_path)
    visualize.plot_legend()
    visualize.plot_save("{}/user_client_roc.png".format(args.model_path))

    # Both prediction arrays are already 1-D, so rounding gives hard 0/1 labels directly.
    visualize.plot_confusion_matrix(client_val, client_pred.round(),
                                    "{}/client_cov.png".format(args.model_path),
                                    normalize=False, title="Client Confusion Matrix")
    visualize.plot_confusion_matrix(user_vals, user_preds.round(),
                                    "{}/user_cov.png".format(args.model_path),
                                    normalize=False, title="User Confusion Matrix")

    logger.info("visualize embedding")
    _, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
    domain_embedding = np.load(os.path.join(args.model_path, "domain_embds.npy"))
    visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
def main_visualize_all():
    """Overlay the client PR/ROC curves of every configured model in shared figures.

    Uses the stored client predictions of each entry in ``get_model_args(args)``
    and writes four figures, ``{args.output_prefix}_window_client_prc.png``,
    ``..._window_client_roc.png``, ``..._user_client_prc.png`` and
    ``..._user_client_roc.png``, one curve per model in each.
    """
    _, _, name_val, client_val, _ = load_or_generate_h5data(args.test_h5data,
                                                            args.test_data,
                                                            args.domain_length,
                                                            args.window)
    # Materialise the labels once; they feed both the window-level curves and
    # the per-name aggregation below.
    client_val = client_val.value

    def max_per_name(values):
        """Collapse window-level values to one value per distinct ``name_val`` entry (max)."""
        frame = pd.DataFrame(data={"names": name_val, "client_val": values})
        # ``as_matrix`` was removed in pandas 1.0; ``.values`` works across versions.
        return frame.groupby(frame.names).max().client_val.values.astype(float)

    # Load every model's client predictions once and reuse them for all four
    # figures instead of re-reading them per figure.
    runs = []
    for model_args in get_model_args(args):
        client_pred, _ = dataset.load_predictions(model_args["future_prediction"])
        window_pred = client_pred.value
        runs.append((model_args["model_path"], window_pred, max_per_name(window_pred.flatten())))
    user_vals = max_per_name(client_val)

    logger.info("plot pr curves")
    visualize.plot_clf()
    for model_path, window_pred, _ in runs:
        visualize.plot_precision_recall(client_val, window_pred, model_path)
    visualize.plot_legend()
    visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")

    logger.info("plot roc curves")
    visualize.plot_clf()
    for model_path, window_pred, _ in runs:
        visualize.plot_roc_curve(client_val, window_pred, model_path)
    visualize.plot_legend()
    visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")

    logger.info("plot user pr curves")
    visualize.plot_clf()
    for model_path, _, user_preds in runs:
        visualize.plot_precision_recall(user_vals, user_preds, model_path)
    visualize.plot_legend()
    visualize.plot_save(f"{args.output_prefix}_user_client_prc.png")

    logger.info("plot user roc curves")
    visualize.plot_clf()
    for model_path, _, user_preds in runs:
        visualize.plot_roc_curve(user_vals, user_preds, model_path)
    visualize.plot_legend()
    visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")
def main():
    """Dispatch to the entry point selected by ``args.mode``.

    An unrecognised mode is reported through the logger instead of silently
    doing nothing.
    """
    modes = {
        "train": main_train,
        "hyperband": main_hyperband,
        "test": main_test,
        "fancy": main_visualization,
        "all_fancy": main_visualize_all,
        "paul": main_paul_best,
    }
    action = modes.get(args.mode)
    if action is None:
        logger.error(f"unknown mode {args.mode!r}; expected one of {sorted(modes)}")
        return
    action()


if __name__ == "__main__":
    main()