refactor dataset generation, add callbacks
This commit is contained in:
parent
a196daa895
commit
9f0bae33d5
57
dataset.py
57
dataset.py
@ -25,6 +25,7 @@ def encode_char(c):
|
||||
encode_char = np.vectorize(encode_char)
|
||||
|
||||
|
||||
# TODO: refactor
|
||||
def get_user_chunks(dataFrame, windowSize=10, overlapping=False,
|
||||
maxLengthInSeconds=300):
|
||||
maxMilliSeconds = maxLengthInSeconds * 1000
|
||||
@ -66,32 +67,17 @@ def get_user_chunks(dataFrame, windowSize=10, overlapping=False,
|
||||
if len(outDomainLists[-1]) != windowSize:
|
||||
outDomainLists.pop(-1)
|
||||
outDFFrames.pop(-1)
|
||||
return (outDomainLists, outDFFrames)
|
||||
return outDomainLists, outDFFrames
|
||||
|
||||
|
||||
def get_domain_features(domain, vocab, max_length=40):
|
||||
def get_domain_features(domain, vocab: dict, max_length=40):
|
||||
encoding = np.zeros((max_length,))
|
||||
for j in range(np.min([len(domain), max_length])):
|
||||
char = domain[-j]
|
||||
if char in vocab:
|
||||
encoding[j] = vocab[char]
|
||||
for j in range(min(len(domain), max_length)):
|
||||
char = domain[-j] # TODO: why -j -> order reversed for domain url?
|
||||
encoding[j] = vocab.get(char, 0)
|
||||
return encoding
|
||||
|
||||
|
||||
def get_flow_features(flow):
|
||||
keys = ['duration', 'bytes_down', 'bytes_up']
|
||||
features = np.zeros([len(keys), ])
|
||||
for i, key in enumerate(keys):
|
||||
# TODO: does it still work after exceptions occur -- default: zero!
|
||||
# I wonder whether something breaks
|
||||
# if there are exceptions regarding inconsistent feature length
|
||||
try:
|
||||
features[i] = np.log1p(flow[key]).astype(float)
|
||||
except:
|
||||
pass
|
||||
return features
|
||||
|
||||
|
||||
def get_all_flow_features(features):
|
||||
flows = np.stack(list(
|
||||
map(lambda f: f[["duration", "bytes_up", "bytes_down"]], features))
|
||||
@ -99,7 +85,7 @@ def get_all_flow_features(features):
|
||||
return np.log1p(flows)
|
||||
|
||||
|
||||
def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10, use_cisco_features=False):
|
||||
def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
|
||||
domains = []
|
||||
features = []
|
||||
print("get chunks from user data frames")
|
||||
@ -112,7 +98,7 @@ def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10,
|
||||
features += feature_windows
|
||||
|
||||
print("create training dataset")
|
||||
domain_tr, flow_tr, hits_tr, names_tr, server_tr, trusted_hits_tr = create_dataset_from_lists(domains=domains,
|
||||
domain_tr, flow_tr, hits_tr, _, server_tr, trusted_hits_tr = create_dataset_from_lists(domains=domains,
|
||||
flows=features,
|
||||
vocab=char_dict,
|
||||
max_len=max_len,
|
||||
@ -164,27 +150,20 @@ def create_dataset_from_lists(domains, flows, vocab, max_len, window_size=10):
|
||||
:param window_size: size of the flow window
|
||||
:return:
|
||||
"""
|
||||
numFeatures = 3
|
||||
sample_size = len(domains)
|
||||
hits = []
|
||||
names = []
|
||||
servers = []
|
||||
trusted_hits = []
|
||||
# sample_size = len(domains)
|
||||
|
||||
domain_features = np.zeros((sample_size, window_size, max_len))
|
||||
flow_features = np.zeros((sample_size, window_size, numFeatures))
|
||||
# domain_features = np.zeros((sample_size, window_size, max_len))
|
||||
flow_features = get_all_flow_features(flows)
|
||||
|
||||
for i in tqdm(np.arange(sample_size), miniters=10):
|
||||
for j in range(window_size):
|
||||
domain_features[i, j, :] = get_domain_features(domains[i][j], vocab, max_len)
|
||||
flow_features[i, j, :] = get_flow_features(flows[i].iloc[j])
|
||||
domain_features = np.array([[get_domain_features(d, vocab, max_len) for d in x] for x in domains])
|
||||
|
||||
hits = np.max(np.stack(map(lambda f: f.virusTotalHits, flows)), axis=1)
|
||||
names = np.unique(np.stack(map(lambda f: f.user_hash, flows)), axis=1)
|
||||
servers = np.max(np.stack(map(lambda f: f.serverLabel, flows)), axis=1)
|
||||
trusted_hits = np.max(np.stack(map(lambda f: f.trustedHits, flows)), axis=1)
|
||||
|
||||
hits.append(np.max(flows[i]['virusTotalHits']))
|
||||
names.append(np.unique(flows[i]['user_hash']))
|
||||
servers.append(np.max(flows[i]['serverLabel']))
|
||||
trusted_hits.append(np.max(flows[i]['trustedHits']))
|
||||
return (domain_features, flow_features,
|
||||
np.array(hits), np.array(names), np.array(servers), np.array(trusted_hits))
|
||||
hits, names, servers, trusted_hits)
|
||||
|
||||
|
||||
def discretize_label(values, threshold):
|
||||
|
86
main.py
86
main.py
@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
|
||||
from keras.models import load_model
|
||||
|
||||
import dataset
|
||||
@ -21,8 +22,8 @@ parser.add_argument("--test", action="store", dest="test_data",
|
||||
# parser.add_argument("--h5data", action="store", dest="h5data",
|
||||
# default="")
|
||||
#
|
||||
parser.add_argument("--models", action="store", dest="models",
|
||||
default="models/model_x")
|
||||
parser.add_argument("--models", action="store", dest="model_path",
|
||||
default="models/models_x")
|
||||
|
||||
# parser.add_argument("--pred", action="store", dest="pred",
|
||||
# default="")
|
||||
@ -75,8 +76,9 @@ parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
args.embedding_model = args.models + "_embd.h5"
|
||||
args.clf_model = args.models + "_clf.h5"
|
||||
args.embedding_model = os.path.join(args.model_path, "embd.h5")
|
||||
args.clf_model = os.path.join(args.model_path, "clf.h5")
|
||||
args.train_log = os.path.join(args.model_path, "train.log")
|
||||
args.h5data = args.train_data + ".h5"
|
||||
|
||||
|
||||
@ -93,21 +95,8 @@ def exists_or_make_path(p):
|
||||
|
||||
def main_paul_best():
|
||||
char_dict = dataset.get_character_dict()
|
||||
print("check for h5data")
|
||||
try:
|
||||
open(args.h5data, "r")
|
||||
raise FileNotFoundError()
|
||||
except FileNotFoundError:
|
||||
print("h5 data not found - load csv file")
|
||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
||||
print("create training dataset")
|
||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
|
||||
user_flow_df, char_dict,
|
||||
max_len=args.domain_length, window_size=args.window)
|
||||
print("store training dataset as h5 file")
|
||||
dataset.store_h5dataset(args.h5data, domain_tr, flow_tr, client_tr, server_tr)
|
||||
print("load h5 dataset")
|
||||
domain_tr, flow_tr, client_tr, server_tr = dataset.load_h5dataset(args.h5data)
|
||||
domain_tr, flow_tr, client_tr, server_tr = load_or_generate_h5data(args.h5data, args.train_data,
|
||||
args.domain_length, args.window)
|
||||
|
||||
param = models.pauls_networks.best_config
|
||||
param["vocab_size"] = len(char_dict) + 1
|
||||
@ -157,8 +146,7 @@ def main_hyperband():
|
||||
print(param)
|
||||
|
||||
print("create training dataset")
|
||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
|
||||
user_flow_df, char_dict,
|
||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
|
||||
max_len=args.domain_length,
|
||||
window_size=args.window)
|
||||
|
||||
@ -166,24 +154,30 @@ def main_hyperband():
|
||||
hp.run()
|
||||
|
||||
|
||||
def main_train():
|
||||
# exists_or_make_path(args.clf_model)
|
||||
def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
||||
char_dict = dataset.get_character_dict()
|
||||
print("check for h5data")
|
||||
print("check for h5data", h5data)
|
||||
try:
|
||||
open(args.h5data, "r")
|
||||
raise FileNotFoundError()
|
||||
open(h5data, "r")
|
||||
except FileNotFoundError:
|
||||
print("h5 data not found - load csv file")
|
||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
||||
user_flow_df = dataset.get_user_flow_data(train_data)
|
||||
print("create training dataset")
|
||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
|
||||
user_flow_df, char_dict,
|
||||
max_len=args.domain_length, window_size=args.window)
|
||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
|
||||
max_len=domain_length,
|
||||
window_size=window_size)
|
||||
print("store training dataset as h5 file")
|
||||
dataset.store_h5dataset(args.h5data, domain_tr, flow_tr, client_tr, server_tr)
|
||||
print("load h5 dataset")
|
||||
domain_tr, flow_tr, client_tr, server_tr = dataset.load_h5dataset(args.h5data)
|
||||
return dataset.load_h5dataset(h5data)
|
||||
|
||||
|
||||
def main_train():
|
||||
exists_or_make_path(args.model_path)
|
||||
|
||||
char_dict = dataset.get_character_dict()
|
||||
domain_tr, flow_tr, client_tr, server_tr = load_or_generate_h5data(args.h5data, args.train_data,
|
||||
args.domain_length, args.window)
|
||||
|
||||
# parameter
|
||||
param = {
|
||||
@ -210,34 +204,42 @@ def main_train():
|
||||
embedding, model = models.get_models_by_params(param)
|
||||
embedding.summary()
|
||||
model.summary()
|
||||
|
||||
print("define callbacks")
|
||||
cp = ModelCheckpoint(filepath=args.clf_model,
|
||||
monitor='val_loss',
|
||||
verbose=False,
|
||||
save_best_only=True)
|
||||
csv = CSVLogger(args.train_log)
|
||||
early = EarlyStopping(monitor='val_loss',
|
||||
patience=5,
|
||||
verbose=False)
|
||||
print("compile model")
|
||||
model.compile(optimizer='adam',
|
||||
loss='categorical_crossentropy',
|
||||
metrics=['accuracy'])
|
||||
|
||||
print("start training")
|
||||
model.fit([domain_tr, flow_tr],
|
||||
[client_tr, server_tr],
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
callbacks=[cp, csv, early],
|
||||
shuffle=True,
|
||||
validation_split=0.2)
|
||||
|
||||
print("save embedding")
|
||||
embedding.save(args.embedding_model)
|
||||
model.save(args.clf_model)
|
||||
|
||||
|
||||
def main_test():
|
||||
char_dict = dataset.get_character_dict()
|
||||
user_flow_df = dataset.get_user_flow_data(args.test_data)
|
||||
domain_val, flow_val, client_val, server_val = dataset.create_dataset_from_flows(
|
||||
user_flow_df, char_dict,
|
||||
max_len=args.domain_length, window_size=args.window)
|
||||
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.h5data, args.train_data,
|
||||
args.domain_length, args.window)
|
||||
# embedding = load_model(args.embedding_model)
|
||||
clf = load_model(args.clf_model)
|
||||
|
||||
print(clf.evaluate([domain_val, flow_val],
|
||||
loss, _, _, client_acc, server_acc = clf.evaluate([domain_val, flow_val],
|
||||
[client_val, server_val],
|
||||
batch_size=args.batch_size))
|
||||
batch_size=args.batch_size)
|
||||
|
||||
print(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
|
||||
|
||||
|
||||
def main_visualization():
|
||||
|
Loading…
Reference in New Issue
Block a user