refactor dataset generation, add callbacks
This commit is contained in:
parent
a196daa895
commit
9f0bae33d5
57
dataset.py
57
dataset.py
@ -25,6 +25,7 @@ def encode_char(c):
|
|||||||
encode_char = np.vectorize(encode_char)
|
encode_char = np.vectorize(encode_char)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: refactor
|
||||||
def get_user_chunks(dataFrame, windowSize=10, overlapping=False,
|
def get_user_chunks(dataFrame, windowSize=10, overlapping=False,
|
||||||
maxLengthInSeconds=300):
|
maxLengthInSeconds=300):
|
||||||
maxMilliSeconds = maxLengthInSeconds * 1000
|
maxMilliSeconds = maxLengthInSeconds * 1000
|
||||||
@ -66,32 +67,17 @@ def get_user_chunks(dataFrame, windowSize=10, overlapping=False,
|
|||||||
if len(outDomainLists[-1]) != windowSize:
|
if len(outDomainLists[-1]) != windowSize:
|
||||||
outDomainLists.pop(-1)
|
outDomainLists.pop(-1)
|
||||||
outDFFrames.pop(-1)
|
outDFFrames.pop(-1)
|
||||||
return (outDomainLists, outDFFrames)
|
return outDomainLists, outDFFrames
|
||||||
|
|
||||||
|
|
||||||
def get_domain_features(domain, vocab, max_length=40):
|
def get_domain_features(domain, vocab: dict, max_length=40):
|
||||||
encoding = np.zeros((max_length,))
|
encoding = np.zeros((max_length,))
|
||||||
for j in range(np.min([len(domain), max_length])):
|
for j in range(min(len(domain), max_length)):
|
||||||
char = domain[-j]
|
char = domain[-j] # TODO: why -j -> order reversed for domain url?
|
||||||
if char in vocab:
|
encoding[j] = vocab.get(char, 0)
|
||||||
encoding[j] = vocab[char]
|
|
||||||
return encoding
|
return encoding
|
||||||
|
|
||||||
|
|
||||||
def get_flow_features(flow):
|
|
||||||
keys = ['duration', 'bytes_down', 'bytes_up']
|
|
||||||
features = np.zeros([len(keys), ])
|
|
||||||
for i, key in enumerate(keys):
|
|
||||||
# TODO: does it still works after exceptions occur -- default: zero!
|
|
||||||
# i wonder whether something brokes
|
|
||||||
# if there are exceptions regarding to inconsistent feature length
|
|
||||||
try:
|
|
||||||
features[i] = np.log1p(flow[key]).astype(float)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
return features
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_flow_features(features):
|
def get_all_flow_features(features):
|
||||||
flows = np.stack(list(
|
flows = np.stack(list(
|
||||||
map(lambda f: f[["duration", "bytes_up", "bytes_down"]], features))
|
map(lambda f: f[["duration", "bytes_up", "bytes_down"]], features))
|
||||||
@ -99,7 +85,7 @@ def get_all_flow_features(features):
|
|||||||
return np.log1p(flows)
|
return np.log1p(flows)
|
||||||
|
|
||||||
|
|
||||||
def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10, use_cisco_features=False):
|
def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
|
||||||
domains = []
|
domains = []
|
||||||
features = []
|
features = []
|
||||||
print("get chunks from user data frames")
|
print("get chunks from user data frames")
|
||||||
@ -112,7 +98,7 @@ def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10,
|
|||||||
features += feature_windows
|
features += feature_windows
|
||||||
|
|
||||||
print("create training dataset")
|
print("create training dataset")
|
||||||
domain_tr, flow_tr, hits_tr, names_tr, server_tr, trusted_hits_tr = create_dataset_from_lists(domains=domains,
|
domain_tr, flow_tr, hits_tr, _, server_tr, trusted_hits_tr = create_dataset_from_lists(domains=domains,
|
||||||
flows=features,
|
flows=features,
|
||||||
vocab=char_dict,
|
vocab=char_dict,
|
||||||
max_len=max_len,
|
max_len=max_len,
|
||||||
@ -164,27 +150,20 @@ def create_dataset_from_lists(domains, flows, vocab, max_len, window_size=10):
|
|||||||
:param window_size: size of the flow window
|
:param window_size: size of the flow window
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
numFeatures = 3
|
# sample_size = len(domains)
|
||||||
sample_size = len(domains)
|
|
||||||
hits = []
|
|
||||||
names = []
|
|
||||||
servers = []
|
|
||||||
trusted_hits = []
|
|
||||||
|
|
||||||
domain_features = np.zeros((sample_size, window_size, max_len))
|
# domain_features = np.zeros((sample_size, window_size, max_len))
|
||||||
flow_features = np.zeros((sample_size, window_size, numFeatures))
|
flow_features = get_all_flow_features(flows)
|
||||||
|
|
||||||
for i in tqdm(np.arange(sample_size), miniters=10):
|
domain_features = np.array([[get_domain_features(d, vocab, max_len) for d in x] for x in domains])
|
||||||
for j in range(window_size):
|
|
||||||
domain_features[i, j, :] = get_domain_features(domains[i][j], vocab, max_len)
|
hits = np.max(np.stack(map(lambda f: f.virusTotalHits, flows)), axis=1)
|
||||||
flow_features[i, j, :] = get_flow_features(flows[i].iloc[j])
|
names = np.unique(np.stack(map(lambda f: f.user_hash, flows)), axis=1)
|
||||||
|
servers = np.max(np.stack(map(lambda f: f.serverLabel, flows)), axis=1)
|
||||||
|
trusted_hits = np.max(np.stack(map(lambda f: f.trustedHits, flows)), axis=1)
|
||||||
|
|
||||||
hits.append(np.max(flows[i]['virusTotalHits']))
|
|
||||||
names.append(np.unique(flows[i]['user_hash']))
|
|
||||||
servers.append(np.max(flows[i]['serverLabel']))
|
|
||||||
trusted_hits.append(np.max(flows[i]['trustedHits']))
|
|
||||||
return (domain_features, flow_features,
|
return (domain_features, flow_features,
|
||||||
np.array(hits), np.array(names), np.array(servers), np.array(trusted_hits))
|
hits, names, servers, trusted_hits)
|
||||||
|
|
||||||
|
|
||||||
def discretize_label(values, threshold):
|
def discretize_label(values, threshold):
|
||||||
|
86
main.py
86
main.py
@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
|
||||||
from keras.models import load_model
|
from keras.models import load_model
|
||||||
|
|
||||||
import dataset
|
import dataset
|
||||||
@ -21,8 +22,8 @@ parser.add_argument("--test", action="store", dest="test_data",
|
|||||||
# parser.add_argument("--h5data", action="store", dest="h5data",
|
# parser.add_argument("--h5data", action="store", dest="h5data",
|
||||||
# default="")
|
# default="")
|
||||||
#
|
#
|
||||||
parser.add_argument("--models", action="store", dest="models",
|
parser.add_argument("--models", action="store", dest="model_path",
|
||||||
default="models/model_x")
|
default="models/models_x")
|
||||||
|
|
||||||
# parser.add_argument("--pred", action="store", dest="pred",
|
# parser.add_argument("--pred", action="store", dest="pred",
|
||||||
# default="")
|
# default="")
|
||||||
@ -75,8 +76,9 @@ parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
args.embedding_model = args.models + "_embd.h5"
|
args.embedding_model = os.path.join(args.model_path, "embd.h5")
|
||||||
args.clf_model = args.models + "_clf.h5"
|
args.clf_model = os.path.join(args.model_path, "clf.h5")
|
||||||
|
args.train_log = os.path.join(args.model_path, "train.log")
|
||||||
args.h5data = args.train_data + ".h5"
|
args.h5data = args.train_data + ".h5"
|
||||||
|
|
||||||
|
|
||||||
@ -93,21 +95,8 @@ def exists_or_make_path(p):
|
|||||||
|
|
||||||
def main_paul_best():
|
def main_paul_best():
|
||||||
char_dict = dataset.get_character_dict()
|
char_dict = dataset.get_character_dict()
|
||||||
print("check for h5data")
|
domain_tr, flow_tr, client_tr, server_tr = load_or_generate_h5data(args.h5data, args.train_data,
|
||||||
try:
|
args.domain_length, args.window)
|
||||||
open(args.h5data, "r")
|
|
||||||
raise FileNotFoundError()
|
|
||||||
except FileNotFoundError:
|
|
||||||
print("h5 data not found - load csv file")
|
|
||||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
|
||||||
print("create training dataset")
|
|
||||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
|
|
||||||
user_flow_df, char_dict,
|
|
||||||
max_len=args.domain_length, window_size=args.window)
|
|
||||||
print("store training dataset as h5 file")
|
|
||||||
dataset.store_h5dataset(args.h5data, domain_tr, flow_tr, client_tr, server_tr)
|
|
||||||
print("load h5 dataset")
|
|
||||||
domain_tr, flow_tr, client_tr, server_tr = dataset.load_h5dataset(args.h5data)
|
|
||||||
|
|
||||||
param = models.pauls_networks.best_config
|
param = models.pauls_networks.best_config
|
||||||
param["vocab_size"] = len(char_dict) + 1
|
param["vocab_size"] = len(char_dict) + 1
|
||||||
@ -157,8 +146,7 @@ def main_hyperband():
|
|||||||
print(param)
|
print(param)
|
||||||
|
|
||||||
print("create training dataset")
|
print("create training dataset")
|
||||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
|
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
|
||||||
user_flow_df, char_dict,
|
|
||||||
max_len=args.domain_length,
|
max_len=args.domain_length,
|
||||||
window_size=args.window)
|
window_size=args.window)
|
||||||
|
|
||||||
@ -166,24 +154,30 @@ def main_hyperband():
|
|||||||
hp.run()
|
hp.run()
|
||||||
|
|
||||||
|
|
||||||
def main_train():
|
def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
||||||
# exists_or_make_path(args.clf_model)
|
|
||||||
char_dict = dataset.get_character_dict()
|
char_dict = dataset.get_character_dict()
|
||||||
print("check for h5data")
|
print("check for h5data", h5data)
|
||||||
try:
|
try:
|
||||||
open(args.h5data, "r")
|
open(h5data, "r")
|
||||||
raise FileNotFoundError()
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print("h5 data not found - load csv file")
|
print("h5 data not found - load csv file")
|
||||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
user_flow_df = dataset.get_user_flow_data(train_data)
|
||||||
print("create training dataset")
|
print("create training dataset")
|
||||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
|
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
|
||||||
user_flow_df, char_dict,
|
max_len=domain_length,
|
||||||
max_len=args.domain_length, window_size=args.window)
|
window_size=window_size)
|
||||||
print("store training dataset as h5 file")
|
print("store training dataset as h5 file")
|
||||||
dataset.store_h5dataset(args.h5data, domain_tr, flow_tr, client_tr, server_tr)
|
dataset.store_h5dataset(args.h5data, domain_tr, flow_tr, client_tr, server_tr)
|
||||||
print("load h5 dataset")
|
print("load h5 dataset")
|
||||||
domain_tr, flow_tr, client_tr, server_tr = dataset.load_h5dataset(args.h5data)
|
return dataset.load_h5dataset(h5data)
|
||||||
|
|
||||||
|
|
||||||
|
def main_train():
|
||||||
|
exists_or_make_path(args.model_path)
|
||||||
|
|
||||||
|
char_dict = dataset.get_character_dict()
|
||||||
|
domain_tr, flow_tr, client_tr, server_tr = load_or_generate_h5data(args.h5data, args.train_data,
|
||||||
|
args.domain_length, args.window)
|
||||||
|
|
||||||
# parameter
|
# parameter
|
||||||
param = {
|
param = {
|
||||||
@ -210,34 +204,42 @@ def main_train():
|
|||||||
embedding, model = models.get_models_by_params(param)
|
embedding, model = models.get_models_by_params(param)
|
||||||
embedding.summary()
|
embedding.summary()
|
||||||
model.summary()
|
model.summary()
|
||||||
|
print("define callbacks")
|
||||||
|
cp = ModelCheckpoint(filepath=args.clf_model,
|
||||||
|
monitor='val_loss',
|
||||||
|
verbose=False,
|
||||||
|
save_best_only=True)
|
||||||
|
csv = CSVLogger(args.train_log)
|
||||||
|
early = EarlyStopping(monitor='val_loss',
|
||||||
|
patience=5,
|
||||||
|
verbose=False)
|
||||||
|
print("compile model")
|
||||||
model.compile(optimizer='adam',
|
model.compile(optimizer='adam',
|
||||||
loss='categorical_crossentropy',
|
loss='categorical_crossentropy',
|
||||||
metrics=['accuracy'])
|
metrics=['accuracy'])
|
||||||
|
print("start training")
|
||||||
model.fit([domain_tr, flow_tr],
|
model.fit([domain_tr, flow_tr],
|
||||||
[client_tr, server_tr],
|
[client_tr, server_tr],
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
|
callbacks=[cp, csv, early],
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
validation_split=0.2)
|
validation_split=0.2)
|
||||||
|
print("save embedding")
|
||||||
embedding.save(args.embedding_model)
|
embedding.save(args.embedding_model)
|
||||||
model.save(args.clf_model)
|
|
||||||
|
|
||||||
|
|
||||||
def main_test():
|
def main_test():
|
||||||
char_dict = dataset.get_character_dict()
|
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.h5data, args.train_data,
|
||||||
user_flow_df = dataset.get_user_flow_data(args.test_data)
|
args.domain_length, args.window)
|
||||||
domain_val, flow_val, client_val, server_val = dataset.create_dataset_from_flows(
|
|
||||||
user_flow_df, char_dict,
|
|
||||||
max_len=args.domain_length, window_size=args.window)
|
|
||||||
# embedding = load_model(args.embedding_model)
|
# embedding = load_model(args.embedding_model)
|
||||||
clf = load_model(args.clf_model)
|
clf = load_model(args.clf_model)
|
||||||
|
|
||||||
print(clf.evaluate([domain_val, flow_val],
|
loss, _, _, client_acc, server_acc = clf.evaluate([domain_val, flow_val],
|
||||||
[client_val, server_val],
|
[client_val, server_val],
|
||||||
batch_size=args.batch_size))
|
batch_size=args.batch_size)
|
||||||
|
|
||||||
|
print(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
|
||||||
|
|
||||||
|
|
||||||
def main_visualization():
|
def main_visualization():
|
||||||
|
Loading…
Reference in New Issue
Block a user