refactor argparser into separate file, add logger

René Knaebel 2017-07-12 10:25:55 +02:00
parent 9f0bae33d5
commit 2afaccc84b
5 changed files with 143 additions and 100 deletions

Makefile

@@ -1,5 +1,5 @@
test:
python3 main.py --modes train --epochs 1 --batch 64 --train data/rk_data.csv.gz
python3 main.py --modes train --epochs 1 --batch 128 --train data/rk_mini.csv.gz
hyper:
python3 main.py --modes hyperband --epochs 1 --batch 64 --train data/rk_data.csv.gz

arguments.py (new file, 78 lines)

@@ -0,0 +1,78 @@
import argparse
import os
parser = argparse.ArgumentParser()
parser.add_argument("--modes", action="store", dest="modes", nargs="+",
default=[])
parser.add_argument("--train", action="store", dest="train_data",
default="data/full_dataset.csv.tar.bz2")
parser.add_argument("--test", action="store", dest="test_data",
default="data/full_future_dataset.csv.tar.bz2")
# parser.add_argument("--h5data", action="store", dest="h5data",
# default="")
#
parser.add_argument("--models", action="store", dest="model_path",
default="models/models_x")
# parser.add_argument("--pred", action="store", dest="pred",
# default="")
#
parser.add_argument("--type", action="store", dest="model_type",
default="paul")
parser.add_argument("--batch", action="store", dest="batch_size",
default=64, type=int)
parser.add_argument("--epochs", action="store", dest="epochs",
default=10, type=int)
# parser.add_argument("--samples", action="store", dest="samples",
# default=100000, type=int)
#
# parser.add_argument("--samples_val", action="store", dest="samples_val",
# default=10000, type=int)
#
parser.add_argument("--embd", action="store", dest="embedding",
default=128, type=int)
parser.add_argument("--hidden_char_dims", action="store", dest="hidden_char_dims",
default=256, type=int)
parser.add_argument("--window", action="store", dest="window",
default=10, type=int)
parser.add_argument("--domain_length", action="store", dest="domain_length",
default=40, type=int)
parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
default=512, type=int)
# parser.add_argument("--queue", action="store", dest="queue_size",
# default=50, type=int)
#
# parser.add_argument("--p", action="store", dest="p_train",
# default=0.5, type=float)
#
# parser.add_argument("--p_val", action="store", dest="p_val",
# default=0.01, type=float)
#
# parser.add_argument("--gpu", action="store", dest="gpu",
# default=0, type=int)
#
# parser.add_argument("--tmp", action="store_true", dest="tmp")
#
# parser.add_argument("--test", action="store_true", dest="test")
def parse():
args = parser.parse_args()
args.embedding_model = os.path.join(args.model_path, "embd.h5")
args.clf_model = os.path.join(args.model_path, "clf.h5")
args.train_log = os.path.join(args.model_path, "train.log")
args.h5data = args.train_data + ".h5"
return args
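For orientation, a minimal usage sketch of the new module, mirroring how main.py consumes it further down; the flag values come from whatever sys.argv carries at call time.

import arguments

# parse() wraps parser.parse_args() and derives the model and log paths
args = arguments.parse()
print(args.batch_size)        # 64 unless --batch is given
print(args.embedding_model)   # <model_path>/embd.h5
print(args.clf_model)         # <model_path>/clf.h5
print(args.h5data)            # <train_data>.h5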

dataset.py

@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import logging
import string
import h5py
@@ -7,6 +8,8 @@ import pandas as pd
from keras.utils import np_utils
from tqdm import tqdm
logger = logging.getLogger('logger')
chars = dict((char, idx + 1) for (idx, char) in
enumerate(string.ascii_lowercase + string.punctuation + string.digits))
@@ -36,7 +39,7 @@ def get_user_chunks(dataFrame, windowSize=10, overlapping=False,
userIDs = np.arange(len(dataFrame))
for blockID in np.arange(numBlocks):
curIDs = userIDs[(blockID * windowSize):((blockID + 1) * windowSize)]
# print(curIDs)
# logger.info(curIDs)
useData = dataFrame.iloc[curIDs]
curDomains = useData['domain']
if maxLengthInSeconds != -1:
@@ -88,7 +91,7 @@ def get_all_flow_features(features):
def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
domains = []
features = []
print("get chunks from user data frames")
logger.info("get chunks from user data frames")
for i, user_flow in tqdm(list(enumerate(get_flow_per_user(user_flow_df)))):
(domain_windows, feature_windows) = get_user_chunks(user_flow,
windowSize=window_size,
@@ -97,7 +100,7 @@ def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
domains += domain_windows
features += feature_windows
print("create training dataset")
logger.info("create training dataset")
domain_tr, flow_tr, hits_tr, _, server_tr, trusted_hits_tr = create_dataset_from_lists(domains=domains,
flows=features,
vocab=char_dict,
@@ -150,13 +153,8 @@ def create_dataset_from_lists(domains, flows, vocab, max_len, window_size=10):
:param window_size: size of the flow window
:return:
"""
# sample_size = len(domains)
# domain_features = np.zeros((sample_size, window_size, max_len))
flow_features = get_all_flow_features(flows)
domain_features = np.array([[get_domain_features(d, vocab, max_len) for d in x] for x in domains])
flow_features = get_all_flow_features(flows)
hits = np.max(np.stack(map(lambda f: f.virusTotalHits, flows)), axis=1)
names = np.unique(np.stack(map(lambda f: f.user_hash, flows)), axis=1)
servers = np.max(np.stack(map(lambda f: f.serverLabel, flows)), axis=1)
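get_domain_features itself lies outside this hunk; the sketch below shows how the character dictionary defined above could map a domain to a fixed-length integer sequence (the padding index and case handling are assumptions, not taken from the repository).

import string

chars = dict((char, idx + 1) for (idx, char) in
             enumerate(string.ascii_lowercase + string.punctuation + string.digits))

def encode_domain(domain, vocab, max_len):
    # characters outside the vocabulary (e.g. upper case) fall back to 0,
    # which doubles as the padding index in this sketch
    ids = [vocab.get(c, 0) for c in domain.lower()[:max_len]]
    return ids + [0] * (max_len - len(ids))

print(encode_domain("example.com", chars, 16))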

hyperband.py

@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# implementation of hyperband:
# https://arxiv.org/pdf/1603.06560.pdf
import logging
import random
from math import log, ceil
from random import random as rng
@@ -10,6 +11,8 @@ import numpy as np
import models
logger = logging.getLogger('logger')
def sample_params(param_distribution: dict):
p = {}
@@ -75,7 +78,7 @@ class Hyperband:
n_configs = n * self.eta ** (-i)
n_iterations = r * self.eta ** (i)
print("\n*** {} configurations x {:.1f} iterations each".format(
logger.info("\n*** {} configurations x {:.1f} iterations each".format(
n_configs, n_iterations))
val_losses = []
@@ -84,7 +87,7 @@ class Hyperband:
for t in T:
self.counter += 1
print("\n{} | {} | lowest loss so far: {:.4f} (run {})\n".format(
logger.info("\n{} | {} | lowest loss so far: {:.4f} (run {})\n".format(
self.counter, ctime(), self.best_loss, self.best_counter))
start_time = time()
@@ -98,7 +101,7 @@ class Hyperband:
assert ('loss' in result)
seconds = int(round(time() - start_time))
print("\n{} seconds.".format(seconds))
logger.info("\n{} seconds.".format(seconds))
loss = result['loss']
val_losses.append(loss)
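The configurations-versus-iterations numbers logged above come from the successive-halving schedule of the linked paper; a worked sketch with the usual defaults (max_iter=81, eta=3 are assumptions, the real values are set outside this hunk):

from math import ceil, floor, log

max_iter, eta = 81, 3
s_max = int(floor(log(max_iter) / log(eta)))  # 4 -> five brackets
B = (s_max + 1) * max_iter                    # budget per bracket

for s in reversed(range(s_max + 1)):
    n = int(ceil(B / max_iter / (s + 1) * eta ** s))  # initial configurations
    r = max_iter * eta ** (-s)                        # initial iterations each
    for i in range(s + 1):
        n_configs = n * eta ** (-i)
        n_iterations = r * eta ** (i)
        print("bracket {}: {} configs x {:.1f} iterations".format(
            s, int(n_configs), n_iterations))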

main.py (138 lines changed)

@@ -1,85 +1,46 @@
import argparse
import json
import logging
import os
import numpy as np
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
from keras.models import load_model
import arguments
import dataset
import hyperband
import models
parser = argparse.ArgumentParser()
# create logger
logger = logging.getLogger('logger')
logger.setLevel(logging.DEBUG)
parser.add_argument("--modes", action="store", dest="modes", nargs="+",
default=[])
# create console handler and set level to debug
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
parser.add_argument("--train", action="store", dest="train_data",
default="data/full_dataset.csv.tar.bz2")
# create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
parser.add_argument("--test", action="store", dest="test_data",
default="data/full_future_dataset.csv.tar.bz2")
# add formatter to ch
ch.setFormatter(formatter)
# parser.add_argument("--h5data", action="store", dest="h5data",
# default="")
#
parser.add_argument("--models", action="store", dest="model_path",
default="models/models_x")
# add ch to logger
logger.addHandler(ch)
# parser.add_argument("--pred", action="store", dest="pred",
# default="")
#
parser.add_argument("--type", action="store", dest="model_type",
default="paul")
ch = logging.FileHandler("info.log")
ch.setLevel(logging.DEBUG)
parser.add_argument("--batch", action="store", dest="batch_size",
default=64, type=int)
# create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
parser.add_argument("--epochs", action="store", dest="epochs",
default=10, type=int)
# add formatter to ch
ch.setFormatter(formatter)
# parser.add_argument("--samples", action="store", dest="samples",
# default=100000, type=int)
#
# parser.add_argument("--samples_val", action="store", dest="samples_val",
# default=10000, type=int)
#
parser.add_argument("--embd", action="store", dest="embedding",
default=128, type=int)
# add ch to logger
logger.addHandler(ch)
parser.add_argument("--hidden_char_dims", action="store", dest="hidden_char_dims",
default=256, type=int)
parser.add_argument("--window", action="store", dest="window",
default=10, type=int)
parser.add_argument("--domain_length", action="store", dest="domain_length",
default=40, type=int)
parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
default=512, type=int)
# parser.add_argument("--queue", action="store", dest="queue_size",
# default=50, type=int)
#
# parser.add_argument("--p", action="store", dest="p_train",
# default=0.5, type=float)
#
# parser.add_argument("--p_val", action="store", dest="p_val",
# default=0.01, type=float)
#
# parser.add_argument("--gpu", action="store", dest="gpu",
# default=0, type=int)
#
# parser.add_argument("--tmp", action="store_true", dest="tmp")
#
# parser.add_argument("--test", action="store_true", dest="test")
args = parser.parse_args()
args.embedding_model = os.path.join(args.model_path, "embd.h5")
args.clf_model = os.path.join(args.model_path, "clf.h5")
args.train_log = os.path.join(args.model_path, "train.log")
args.h5data = args.train_data + ".h5"
args = arguments.parse()
# config = tf.ConfigProto(log_device_placement=True)
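A note on the handler wiring: dataset.py and hyperband.py each call logging.getLogger('logger') at import time, and the standard library hands back the same logger object for the same name, so the console and info.log handlers configured here also serve those modules. A minimal sketch of that behaviour:

import logging

a = logging.getLogger('logger')
b = logging.getLogger('logger')
assert a is b                  # named loggers are per-process singletons

a.setLevel(logging.DEBUG)
a.addHandler(logging.StreamHandler())
b.info("emitted through the handler attached via the other reference")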
@@ -125,7 +86,7 @@ def main_hyperband():
params = {
# static params
"type": ["paul"],
"batch_size": [64],
"batch_size": [args.batch_size],
"vocab_size": [len(char_dict) + 1],
"window_size": [10],
"domain_length": [40],
@@ -143,32 +104,35 @@ def main_hyperband():
"dense_main": [16, 32, 64, 128, 256, 512],
}
param = hyperband.sample_params(params)
print(param)
logger.info(param)
print("create training dataset")
logger.info("create training dataset")
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
max_len=args.domain_length,
window_size=args.window)
hp = hyperband.Hyperband(params, [domain_tr, flow_tr], [client_tr, server_tr])
hp.run()
hp = hyperband.Hyperband(params,
[domain_tr, flow_tr],
[client_tr, server_tr])
results = hp.run()
json.dump(results, open("hyperband.json"))
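hyperband.sample_params is truncated in the hunk further up; presumably it draws one candidate per key from the lists in params. A sketch under that assumption (not the repository's implementation):

import random

def sample_params(param_distribution: dict):
    # assumed behaviour: pick one value per hyper-parameter list
    return {key: random.choice(values)
            for key, values in param_distribution.items()}

param = sample_params({"window_size": [10],
                       "dense_main": [16, 32, 64, 128, 256, 512]})
print(param)  # e.g. {'window_size': 10, 'dense_main': 32}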
def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
char_dict = dataset.get_character_dict()
print("check for h5data", h5data)
logger.info(f"check for h5data {h5data}")
try:
open(h5data, "r")
except FileNotFoundError:
print("h5 data not found - load csv file")
logger.info("h5 data not found - load csv file")
user_flow_df = dataset.get_user_flow_data(train_data)
print("create training dataset")
logger.info("create training dataset")
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
max_len=domain_length,
window_size=window_size)
print("store training dataset as h5 file")
logger.info("store training dataset as h5 file")
dataset.store_h5dataset(args.h5data, domain_tr, flow_tr, client_tr, server_tr)
print("load h5 dataset")
logger.info("load h5 dataset")
return dataset.load_h5dataset(h5data)
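dataset.store_h5dataset and dataset.load_h5dataset are not part of this diff; a minimal h5py sketch of what an equivalent round trip could look like (the dataset names and layout are assumptions):

import h5py

def store_h5dataset(path, domain, flow, client, server):
    with h5py.File(path, "w") as f:
        f.create_dataset("domain", data=domain)
        f.create_dataset("flow", data=flow)
        f.create_dataset("client", data=client)
        f.create_dataset("server", data=server)

def load_h5dataset(path):
    with h5py.File(path, "r") as f:
        return (f["domain"][()], f["flow"][()],
                f["client"][()], f["server"][()])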
@@ -204,7 +168,7 @@ def main_train():
embedding, model = models.get_models_by_params(param)
embedding.summary()
model.summary()
print("define callbacks")
logger.info("define callbacks")
cp = ModelCheckpoint(filepath=args.clf_model,
monitor='val_loss',
verbose=False,
@@ -213,11 +177,11 @@ def main_train():
early = EarlyStopping(monitor='val_loss',
patience=5,
verbose=False)
print("compile model")
logger.info("compile model")
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
print("start training")
logger.info("start training")
model.fit([domain_tr, flow_tr],
[client_tr, server_tr],
batch_size=args.batch_size,
@@ -225,40 +189,40 @@ def main_train():
callbacks=[cp, csv, early],
shuffle=True,
validation_split=0.2)
print("save embedding")
logger.info("save embedding")
embedding.save(args.embedding_model)
def main_test():
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.h5data, args.train_data,
args.domain_length, args.window)
# embedding = load_model(args.embedding_model)
clf = load_model(args.clf_model)
loss, _, _, client_acc, server_acc = clf.evaluate([domain_val, flow_val],
[client_val, server_val],
batch_size=args.batch_size)
print(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
logger.info(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
y_pred = clf.predict([domain_val, flow_val],
batch_size=args.batch_size)
np.save(os.path.join(args.model_path, "future_predict.npy"), y_pred)
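The five-way unpacking of clf.evaluate above relies on Keras' ordering for two-output models: total loss first, then one loss per output, then one metric per output. A small sketch that logs every score next to its name instead of unpacking by position:

scores = clf.evaluate([domain_val, flow_val],
                      [client_val, server_val],
                      batch_size=args.batch_size)
for name, value in zip(clf.metrics_names, scores):
    logger.info("{}: {}".format(name, value))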
def main_visualization():
mask = dataset.load_mask_eval(args.data, args.test_image)
y_pred_path = args.model_path + "pred.npy"
print("plot model")
logger.info("plot model")
model = load_model(args.model_path + "model.h5",
custom_objects=evaluation.get_metrics())
visualize.plot_model(model, args.model_path + "model.png")
print("plot training curve")
logger.info("plot training curve")
logs = pd.read_csv(args.model_path + "train.log")
visualize.plot_training_curve(logs, "{}/train.png".format(args.model_path))
pred = np.load(y_pred_path)
print("plot pr curve")
logger.info("plot pr curve")
visualize.plot_precision_recall(mask, pred, "{}/prc.png".format(args.model_path))
visualize.plot_precision_recall_curves(mask, pred, "{}/prc2.png".format(args.model_path))
print("plot roc curve")
logger.info("plot roc curve")
visualize.plot_roc_curve(mask, pred, "{}/roc.png".format(args.model_path))
print("store prediction image")
logger.info("store prediction image")
visualize.save_image_as(pred, "{}/pred.png".format(args.model_path))