From 2afaccc84b2d63ed69e7c9e7f3502fb2d9a7fba4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20Knaebel?=
Date: Wed, 12 Jul 2017 10:25:55 +0200
Subject: [PATCH] refactor argparser into separate file, add logger

---
 Makefile     |   2 +-
 arguments.py |  78 +++++++++++++++++++++++++++++
 dataset.py   |  16 +++---
 hyperband.py |   9 ++--
 main.py      | 138 +++++++++++++++++++--------------------------------
 5 files changed, 143 insertions(+), 100 deletions(-)
 create mode 100644 arguments.py

diff --git a/Makefile b/Makefile
index 579b107..4a2d47e 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
 test:
-	python3 main.py --modes train --epochs 1 --batch 64 --train data/rk_data.csv.gz
+	python3 main.py --modes train --epochs 1 --batch 128 --train data/rk_mini.csv.gz
 
 hyper:
 	python3 main.py --modes hyperband --epochs 1 --batch 64 --train data/rk_data.csv.gz
diff --git a/arguments.py b/arguments.py
new file mode 100644
index 0000000..740a778
--- /dev/null
+++ b/arguments.py
@@ -0,0 +1,78 @@
+import argparse
+import os
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument("--modes", action="store", dest="modes", nargs="+",
+                    default=[])
+
+parser.add_argument("--train", action="store", dest="train_data",
+                    default="data/full_dataset.csv.tar.bz2")
+
+parser.add_argument("--test", action="store", dest="test_data",
+                    default="data/full_future_dataset.csv.tar.bz2")
+
+# parser.add_argument("--h5data", action="store", dest="h5data",
+#                     default="")
+#
+parser.add_argument("--models", action="store", dest="model_path",
+                    default="models/models_x")
+
+# parser.add_argument("--pred", action="store", dest="pred",
+#                     default="")
+#
+parser.add_argument("--type", action="store", dest="model_type",
+                    default="paul")
+
+parser.add_argument("--batch", action="store", dest="batch_size",
+                    default=64, type=int)
+
+parser.add_argument("--epochs", action="store", dest="epochs",
+                    default=10, type=int)
+
+# parser.add_argument("--samples", action="store", dest="samples",
+#                     default=100000, type=int)
+#
+# parser.add_argument("--samples_val", action="store", dest="samples_val",
+#                     default=10000, type=int)
+#
+parser.add_argument("--embd", action="store", dest="embedding",
+                    default=128, type=int)
+
+parser.add_argument("--hidden_char_dims", action="store", dest="hidden_char_dims",
+                    default=256, type=int)
+
+parser.add_argument("--window", action="store", dest="window",
+                    default=10, type=int)
+
+parser.add_argument("--domain_length", action="store", dest="domain_length",
+                    default=40, type=int)
+
+parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
+                    default=512, type=int)
+
+
+# parser.add_argument("--queue", action="store", dest="queue_size",
+#                     default=50, type=int)
+#
+# parser.add_argument("--p", action="store", dest="p_train",
+#                     default=0.5, type=float)
+#
+# parser.add_argument("--p_val", action="store", dest="p_val",
+#                     default=0.01, type=float)
+#
+# parser.add_argument("--gpu", action="store", dest="gpu",
+#                     default=0, type=int)
+#
+# parser.add_argument("--tmp", action="store_true", dest="tmp")
+#
+# parser.add_argument("--test", action="store_true", dest="test")
+
+
+def parse():
+    args = parser.parse_args()
+    args.embedding_model = os.path.join(args.model_path, "embd.h5")
+    args.clf_model = os.path.join(args.model_path, "clf.h5")
+    args.train_log = os.path.join(args.model_path, "train.log")
+    args.h5data = args.train_data + ".h5"
+    return args
diff --git a/dataset.py b/dataset.py
index b0ecdd6..c759b64 100644
--- a/dataset.py
+++ b/dataset.py
@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+import logging
 import string
 
 import h5py
@@ -7,6 +8,8 @@ import pandas as pd
 from keras.utils import np_utils
 from tqdm import tqdm
 
+logger = logging.getLogger('logger')
+
 chars = dict((char, idx + 1) for (idx, char) in
              enumerate(string.ascii_lowercase + string.punctuation + string.digits))
 
@@ -36,7 +39,7 @@ def get_user_chunks(dataFrame, windowSize=10, overlapping=False,
     userIDs = np.arange(len(dataFrame))
     for blockID in np.arange(numBlocks):
         curIDs = userIDs[(blockID * windowSize):((blockID + 1) * windowSize)]
-        # print(curIDs)
+        # logger.info(curIDs)
         useData = dataFrame.iloc[curIDs]
         curDomains = useData['domain']
         if maxLengthInSeconds != -1:
@@ -88,7 +91,7 @@ def get_all_flow_features(features):
 def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
     domains = []
     features = []
-    print("get chunks from user data frames")
+    logger.info("get chunks from user data frames")
     for i, user_flow in tqdm(list(enumerate(get_flow_per_user(user_flow_df)))):
         (domain_windows, feature_windows) = get_user_chunks(user_flow,
                                                             windowSize=window_size,
@@ -97,7 +100,7 @@ def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
         domains += domain_windows
         features += feature_windows
 
-    print("create training dataset")
+    logger.info("create training dataset")
     domain_tr, flow_tr, hits_tr, _, server_tr, trusted_hits_tr = create_dataset_from_lists(domains=domains,
                                                                                            flows=features,
                                                                                            vocab=char_dict,
@@ -150,13 +153,8 @@ def create_dataset_from_lists(domains, flows, vocab, max_len, window_size=10):
     :param window_size: size of the flow window
     :return:
     """
-    # sample_size = len(domains)
-
-    # domain_features = np.zeros((sample_size, window_size, max_len))
-    flow_features = get_all_flow_features(flows)
-
     domain_features = np.array([[get_domain_features(d, vocab, max_len) for d in x] for x in domains])
-
+    flow_features = get_all_flow_features(flows)
     hits = np.max(np.stack(map(lambda f: f.virusTotalHits, flows)), axis=1)
     names = np.unique(np.stack(map(lambda f: f.user_hash, flows)), axis=1)
     servers = np.max(np.stack(map(lambda f: f.serverLabel, flows)), axis=1)
diff --git a/hyperband.py b/hyperband.py
index 641d4c6..c9f30c2 100644
--- a/hyperband.py
+++ b/hyperband.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # implementation of hyperband:
 #  https://arxiv.org/pdf/1603.06560.pdf
+import logging
 import random
 from math import log, ceil
 from random import random as rng
@@ -10,6 +11,8 @@ import numpy as np
 
 import models
 
+logger = logging.getLogger('logger')
+
 
 def sample_params(param_distribution: dict):
     p = {}
@@ -75,7 +78,7 @@ class Hyperband:
                n_configs = n * self.eta ** (-i)
                n_iterations = r * self.eta ** (i)
 
-               print("\n*** {} configurations x {:.1f} iterations each".format(
+               logger.info("\n*** {} configurations x {:.1f} iterations each".format(
                    n_configs, n_iterations))
 
                val_losses = []
@@ -84,7 +87,7 @@ class Hyperband:
                for t in T:
                    self.counter += 1
 
-                   print("\n{} | {} | lowest loss so far: {:.4f} (run {})\n".format(
+                   logger.info("\n{} | {} | lowest loss so far: {:.4f} (run {})\n".format(
                        self.counter, ctime(), self.best_loss, self.best_counter))
 
                    start_time = time()
@@ -98,7 +101,7 @@ class Hyperband:
                    assert ('loss' in result)
 
                    seconds = int(round(time() - start_time))
-                   print("\n{} seconds.".format(seconds))
+                   logger.info("\n{} seconds.".format(seconds))
 
                    loss = result['loss']
                    val_losses.append(loss)
diff --git a/main.py b/main.py
index ba1455a..86a8bcb 100644
--- a/main.py
+++ b/main.py
@@ -1,85 +1,46 @@
-import argparse
+import json
+import logging
 import os
 
+import numpy as np
 from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
 from keras.models import load_model
 
+import arguments
 import dataset
 import hyperband
 import models
 
-parser = argparse.ArgumentParser()
+# create logger
+logger = logging.getLogger('logger')
+logger.setLevel(logging.DEBUG)
 
-parser.add_argument("--modes", action="store", dest="modes", nargs="+",
-                    default=[])
+# create console handler and set level to debug
+ch = logging.StreamHandler()
+ch.setLevel(logging.DEBUG)
 
-parser.add_argument("--train", action="store", dest="train_data",
-                    default="data/full_dataset.csv.tar.bz2")
+# create formatter
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
-parser.add_argument("--test", action="store", dest="test_data",
-                    default="data/full_future_dataset.csv.tar.bz2")
+# add formatter to ch
+ch.setFormatter(formatter)
 
-# parser.add_argument("--h5data", action="store", dest="h5data",
-#                     default="")
-#
-parser.add_argument("--models", action="store", dest="model_path",
-                    default="models/models_x")
+# add ch to logger
+logger.addHandler(ch)
 
-# parser.add_argument("--pred", action="store", dest="pred",
-#                     default="")
-#
-parser.add_argument("--type", action="store", dest="model_type",
-                    default="paul")
+ch = logging.FileHandler("info.log")
+ch.setLevel(logging.DEBUG)
 
-parser.add_argument("--batch", action="store", dest="batch_size",
-                    default=64, type=int)
+# create formatter
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
-parser.add_argument("--epochs", action="store", dest="epochs",
-                    default=10, type=int)
+# add formatter to ch
+ch.setFormatter(formatter)
 
-# parser.add_argument("--samples", action="store", dest="samples",
-#                     default=100000, type=int)
-#
-# parser.add_argument("--samples_val", action="store", dest="samples_val",
-#                     default=10000, type=int)
-#
-parser.add_argument("--embd", action="store", dest="embedding",
-                    default=128, type=int)
+# add ch to logger
+logger.addHandler(ch)
 
-parser.add_argument("--hidden_char_dims", action="store", dest="hidden_char_dims",
-                    default=256, type=int)
-
-parser.add_argument("--window", action="store", dest="window",
-                    default=10, type=int)
-
-parser.add_argument("--domain_length", action="store", dest="domain_length",
-                    default=40, type=int)
-
-parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
-                    default=512, type=int)
-
-# parser.add_argument("--queue", action="store", dest="queue_size",
-#                     default=50, type=int)
-#
-# parser.add_argument("--p", action="store", dest="p_train",
-#                     default=0.5, type=float)
-#
-# parser.add_argument("--p_val", action="store", dest="p_val",
-#                     default=0.01, type=float)
-#
-# parser.add_argument("--gpu", action="store", dest="gpu",
-#                     default=0, type=int)
-#
-# parser.add_argument("--tmp", action="store_true", dest="tmp")
-#
-# parser.add_argument("--test", action="store_true", dest="test")
-
-args = parser.parse_args()
-
-args.embedding_model = os.path.join(args.model_path, "embd.h5")
-args.clf_model = os.path.join(args.model_path, "clf.h5")
-args.train_log = os.path.join(args.model_path, "train.log")
-args.h5data = args.train_data + ".h5"
+args = arguments.parse()
 
 
 # config = tf.ConfigProto(log_device_placement=True)
@@ -125,7 +86,7 @@ def main_hyperband():
     params = {
         # static params
         "type": ["paul"],
-        "batch_size": [64],
+        "batch_size": [args.batch_size],
         "vocab_size": [len(char_dict) + 1],
         "window_size": [10],
         "domain_length": [40],
@@ -143,32 +104,35 @@ def main_hyperband():
         "dense_main": [16, 32, 64, 128, 256, 512],
     }
     param = hyperband.sample_params(params)
-    print(param)
+    logger.info(param)
 
-    print("create training dataset")
+    logger.info("create training dataset")
     domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
                                                                                  max_len=args.domain_length,
                                                                                  window_size=args.window)
 
-    hp = hyperband.Hyperband(params, [domain_tr, flow_tr], [client_tr, server_tr])
-    hp.run()
+    hp = hyperband.Hyperband(params,
+                             [domain_tr, flow_tr],
+                             [client_tr, server_tr])
+    results = hp.run()
+    json.dump(results, open("hyperband.json"))
 
 
 def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
     char_dict = dataset.get_character_dict()
-    print("check for h5data", h5data)
+    logger.info(f"check for h5data {h5data}")
     try:
         open(h5data, "r")
     except FileNotFoundError:
-        print("h5 data not found - load csv file")
+        logger.info("h5 data not found - load csv file")
        user_flow_df = dataset.get_user_flow_data(train_data)
-        print("create training dataset")
+        logger.info("create training dataset")
        domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
                                                                                      max_len=domain_length,
                                                                                      window_size=window_size)
-        print("store training dataset as h5 file")
+        logger.info("store training dataset as h5 file")
        dataset.store_h5dataset(args.h5data, domain_tr, flow_tr, client_tr, server_tr)
-    print("load h5 dataset")
+    logger.info("load h5 dataset")
     return dataset.load_h5dataset(h5data)
 
 
@@ -204,7 +168,7 @@ def main_train():
     embedding, model = models.get_models_by_params(param)
     embedding.summary()
     model.summary()
-    print("define callbacks")
+    logger.info("define callbacks")
     cp = ModelCheckpoint(filepath=args.clf_model,
                          monitor='val_loss',
                          verbose=False,
@@ -213,11 +177,11 @@ def main_train():
     early = EarlyStopping(monitor='val_loss',
                           patience=5,
                           verbose=False)
-    print("compile model")
+    logger.info("compile model")
     model.compile(optimizer='adam',
                   loss='categorical_crossentropy',
                   metrics=['accuracy'])
-    print("start training")
+    logger.info("start training")
     model.fit([domain_tr, flow_tr],
               [client_tr, server_tr],
               batch_size=args.batch_size,
@@ -225,40 +189,40 @@ def main_train():
              callbacks=[cp, csv, early],
              shuffle=True,
              validation_split=0.2)
-    print("save embedding")
+    logger.info("save embedding")
     embedding.save(args.embedding_model)
 
 
 def main_test():
     domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.h5data, args.train_data,
                                                                            args.domain_length, args.window)
-    # embedding = load_model(args.embedding_model)
     clf = load_model(args.clf_model)
-
     loss, _, _, client_acc, server_acc = clf.evaluate([domain_val, flow_val],
                                                       [client_val, server_val],
                                                       batch_size=args.batch_size)
-
-    print(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
+    logger.info(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
+    y_pred = clf.predict([domain_val, flow_val],
+                         batch_size=args.batch_size)
+    np.save(os.path.join(args.model_path, "future_predict.npy"), y_pred)
 
 
 def main_visualization():
     mask = dataset.load_mask_eval(args.data, args.test_image)
     y_pred_path = args.model_path + "pred.npy"
-    print("plot model")
+    logger.info("plot model")
     model = load_model(args.model_path + "model.h5", custom_objects=evaluation.get_metrics())
     visualize.plot_model(model, args.model_path + "model.png")
-    print("plot training curve")
+    logger.info("plot training curve")
     logs = pd.read_csv(args.model_path + "train.log")
     visualize.plot_training_curve(logs, "{}/train.png".format(args.model_path))
     pred = np.load(y_pred_path)
print("plot pr curve") + logger.info("plot pr curve") visualize.plot_precision_recall(mask, pred, "{}/prc.png".format(args.model_path)) visualize.plot_precision_recall_curves(mask, pred, "{}/prc2.png".format(args.model_path)) - print("plot roc curve") + logger.info("plot roc curve") visualize.plot_roc_curve(mask, pred, "{}/roc.png".format(args.model_path)) - print("store prediction image") + logger.info("store prediction image") visualize.save_image_as(pred, "{}/pred.png".format(args.model_path))