ma_cisco_malware/main.py

251 lines
8.7 KiB
Python

import json
import logging
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
from keras.models import load_model
from sklearn.utils import class_weight
import arguments
import dataset
import hyperband
import models
# create logger
import visualize
from dataset import load_or_generate_h5data
# Shared "logger": records at DEBUG and above are written both to the console
# and to "info.log" in the working directory, using one common format.
logger = logging.getLogger('logger')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
for ch in (logging.StreamHandler(), logging.FileHandler("info.log")):
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
# Deliberately shadow the builtin so print(...) calls in this module go to the logger.
print = logger.info
# Command-line options read by every entry point below.
args = arguments.parse()
if args.gpu:
    # Creating a tf.Session is not enough for standalone Keras to use it: the
    # session has to be installed as the Keras session, otherwise this GPU
    # configuration (memory fraction / allow_growth) never takes effect.
    from keras import backend as K

    config = tf.ConfigProto(log_device_placement=True)
    config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
    K.set_session(session)
def exists_or_make_path(p):
    """Ensure that directory *p* exists, creating missing parent directories.

    ``exist_ok=True`` makes the call idempotent. It also removes the race
    between a separate ``os.path.exists`` check and ``os.makedirs``: another
    process creating the directory in between no longer causes an error.

    Args:
        p: path of the directory to create.

    Raises:
        FileExistsError: if *p* exists but is not a directory.
    """
    os.makedirs(p, exist_ok=True)
def main_paul_best():
    """Train using the best configuration stored in ``models.pauls_networks``.

    The stored configuration is copied before ``vocab_size`` is filled in, so
    the shared module-level ``best_config`` is not mutated as a side effect.
    """
    char_dict = dataset.get_character_dict()
    pauls_best_params = dict(models.pauls_networks.best_config)
    pauls_best_params["vocab_size"] = len(char_dict) + 1
    main_train(pauls_best_params)
def main_hyperband():
    """Run a Hyperband search over the "paul" model family and save the results.

    Each entry of ``params`` lists candidate values; entries under "static
    params" have a single candidate. The search results are written as JSON
    to ``hyperband.json`` in the current working directory.
    """
    char_dict = dataset.get_character_dict()
    params = {
        # static params
        "type": ["paul"],
        "batch_size": [args.batch_size],
        "vocab_size": [len(char_dict) + 1],
        # Use the same window/domain sizes the training data below is loaded
        # with (as main_train does), so the model input shapes match the data.
        "window_size": [args.window],
        "domain_length": [args.domain_length],
        "flow_features": [3],
        "input_length": [40],
        # model params
        "embedding_size": [16, 32, 64, 128, 256, 512],
        "filter_embedding": [16, 32, 64, 128, 256, 512],
        "kernel_embedding": [1, 3, 5, 7, 9],
        "hidden_embedding": [16, 32, 64, 128, 256, 512],
        "dropout": [0.5],
        "domain_features": [16, 32, 64, 128, 256, 512],
        "filter_main": [16, 32, 64, 128, 256, 512],
        "kernels_main": [1, 3, 5, 7, 9],
        "dense_main": [16, 32, 64, 128, 256, 512],
    }
    logger.info("create training dataset")
    domain_tr, flow_tr, client_tr, server_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
                                                                       args.domain_length, args.window)
    hp = hyperband.Hyperband(params,
                             [domain_tr, flow_tr],
                             [client_tr, server_tr])
    results = hp.run()
    # The file must be opened for writing (open() defaults to read-only mode),
    # and the context manager guarantees it is flushed and closed.
    with open("hyperband.json", "w") as f:
        json.dump(results, f)
def get_custom_class_weights(client_tr, server_tr):
    """Compute 'balanced' class weights for the client and server outputs.

    Args:
        client_tr: client label matrix (h5py dataset or numpy array); the
            class of each row is its argmax over axis 1 (presumably one-hot
            labels — confirm in dataset).
        server_tr: server label matrix, same layout as ``client_tr``.

    Returns:
        ``{"client": {class: weight}, "server": {class: weight}}``. Keras'
        ``fit(class_weight=...)`` expects a ``class -> weight`` dict per output;
        a bare weight array there is not applied. The outer keys are presumably
        the model's output layer names — confirm in ``models``.
    """
    def _balanced(labels):
        # ``[()]`` reads the full dataset (the replacement for the removed
        # h5py ``.value``); on a numpy array it is a no-op view.
        y = labels[()].argmax(1)
        classes = np.unique(y)
        weights = class_weight.compute_class_weight('balanced', classes=classes, y=y)
        # Map each observed class label to its weight (labels need not be 0..k-1).
        return {int(c): float(w) for c, w in zip(classes, weights)}

    return {
        "client": _balanced(client_tr),
        "server": _balanced(server_tr)
    }
def _default_train_params():
    """Return the default "paul" model configuration built from the command-line args."""
    char_dict = dataset.get_character_dict()
    return {
        "type": "paul",
        "batch_size": 64,
        "window_size": args.window,
        "domain_length": args.domain_length,
        "flow_features": 3,
        # one index beyond the character dictionary (presumably padding/unknown — confirm in dataset)
        "vocab_size": len(char_dict) + 1,
        # model hyper-parameters
        "dropout": 0.5,
        "domain_features": args.domain_embedding,
        "embedding_size": args.embedding,
        "filter_main": 128,
        "dense_main": 512,
        "filter_embedding": args.hidden_char_dims,
        "hidden_embedding": args.domain_embedding,
        "kernel_embedding": 3,
        "kernels_main": 3,
        "input_length": 40,
    }


def _train_callbacks():
    """Return main_train's callbacks: best-val_loss checkpoint, CSV log, optional early stop."""
    callbacks = [
        ModelCheckpoint(filepath=args.clf_model,
                        monitor="val_loss",
                        verbose=False,
                        save_best_only=True),
        CSVLogger(args.train_log),
    ]
    if args.stop_early:
        callbacks.append(EarlyStopping(monitor="val_loss",
                                       patience=5,
                                       verbose=False))
    return callbacks


def main_train(param=None):
    """Train the client/server classifier and save the embedding model.

    Args:
        param: model configuration passed to ``models.get_models_by_params``.
            When falsy (the default), a configuration built from the
            command-line arguments is used.

    Side effects: creates ``args.model_path``; the checkpoint callback writes
    the model with the best ``val_loss`` to ``args.clf_model``; the training
    log goes to ``args.train_log``; the embedding model is saved to
    ``args.embedding_model``.
    """
    exists_or_make_path(args.model_path)
    if not param:
        param = _default_train_params()
    domain_tr, flow_tr, client_tr, server_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
                                                                       args.domain_length, args.window)

    embedding, model = models.get_models_by_params(param)
    embedding.summary()
    model.summary()

    logger.info("define callbacks")
    callbacks = _train_callbacks()

    logger.info("compile model")
    model.compile(optimizer="adam",
                  loss="categorical_crossentropy",
                  metrics=["accuracy"] + models.get_metric_functions())

    if args.class_weights:
        logger.info("class weights: compute custom weights")
        custom_class_weights = get_custom_class_weights(client_tr, server_tr)
    else:
        logger.info("class weights: set default")
        custom_class_weights = None

    logger.info("start training")
    model.fit([domain_tr, flow_tr],
              [client_tr, server_tr],
              batch_size=args.batch_size,
              epochs=args.epochs,
              callbacks=callbacks,
              shuffle=True,
              validation_split=0.2,
              class_weight=custom_class_weights)

    logger.info("save embedding")
    embedding.save(args.embedding_model)
def main_test():
    """Run the saved classifier over the test data and store its raw predictions.

    Loads (or generates) the test HDF5 data, restores the classifier from
    ``args.clf_model`` and writes the ``predict`` output to
    ``args.future_prediction`` with ``np.save``.
    """
    test_domains, test_flows, _client_labels, _server_labels = load_or_generate_h5data(
        args.test_h5data, args.test_data, args.domain_length, args.window)
    classifier = load_model(args.clf_model, custom_objects=models.get_metrics())
    predictions = classifier.predict([test_domains, test_flows], batch_size=args.batch_size)
    np.save(args.future_prediction, predictions)
def main_visualization():
    """Render diagnostic plots for the trained classifier into ``args.model_path``.

    Produces a model diagram, client/server training curves from
    ``args.train_log``, and precision/recall, ROC and confusion-matrix plots
    comparing the test labels with the predictions stored by ``main_test`` in
    ``args.future_prediction``.
    """
    _, _, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
                                                          args.domain_length, args.window)
    # Read each label set into memory once instead of once per plot call
    # (``[()]`` replaces the removed h5py ``.value``).
    client_true = client_val[()]
    server_true = server_val[()]

    def out(name):
        # Join so every file lands inside model_path whether or not it ends in "/".
        return os.path.join(args.model_path, name)

    logger.info("plot model")
    model = load_model(args.clf_model, custom_objects=models.get_metrics())
    visualize.plot_model(model, out("model.png"))

    logger.info("plot training curve")
    logs = pd.read_csv(args.train_log)
    visualize.plot_training_curve(logs, "client", out("client_train.png"))
    visualize.plot_training_curve(logs, "server", out("server_train.png"))

    # NOTE(review): assumes main_test saved two equally shaped prediction
    # arrays, so np.load returns one array that unpacks into client/server.
    client_pred, server_pred = np.load(args.future_prediction)

    logger.info("plot pr curve")
    visualize.plot_precision_recall(client_true, client_pred, out("client_prc.png"))
    visualize.plot_precision_recall(server_true, server_pred, out("server_prc.png"))
    visualize.plot_precision_recall_curves(client_true, client_pred, out("client_prc2.png"))
    visualize.plot_precision_recall_curves(server_true, server_pred, out("server_prc2.png"))

    logger.info("plot roc curve")
    visualize.plot_roc_curve(client_true, client_pred, out("client_roc.png"))
    visualize.plot_roc_curve(server_true, server_pred, out("server_roc.png"))

    visualize.plot_confusion_matrix(client_true.argmax(1), client_pred.argmax(1),
                                    out("client_cov.png"),
                                    normalize=False, title="Client Confusion Matrix")
    visualize.plot_confusion_matrix(server_true.argmax(1), server_pred.argmax(1),
                                    out("server_cov.png"),
                                    normalize=False, title="Server Confusion Matrix")
def main_score():
    """Placeholder for the "score" mode; scoring is not implemented yet.

    Does no work and returns ``None``.
    """
    # TODO: implement scoring. An earlier draft loaded an evaluation mask and
    # saved predictions and passed them to ``visualize.score_model``.
    return None
def main():
    """Run every entry point whose mode name appears in ``args.modes``.

    Modes are checked and run in a fixed order:
    train, hyperband, test, fancy, score, paul.
    """
    pipeline = (
        ("train", main_train),
        ("hyperband", main_hyperband),
        ("test", main_test),
        ("fancy", main_visualization),
        ("score", main_score),
        ("paul", main_paul_best),
    )
    for mode, step in pipeline:
        if mode in args.modes:
            step()


if __name__ == "__main__":
    main()