From 903e81c931d20a075ee7b3b917463fc656b06d6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Tue, 7 Nov 2017 20:47:41 +0100 Subject: [PATCH] remove identical parameter from data loading function; add runs argument --- Makefile | 12 ++++++------ arguments.py | 4 ++++ dataset.py | 18 +++++++++--------- main.py | 29 +++++++++-------------------- models/__init__.py | 2 +- server.py | 11 ++--------- 6 files changed, 31 insertions(+), 45 deletions(-) diff --git a/Makefile b/Makefile index c953a1a..892ef15 100644 --- a/Makefile +++ b/Makefile @@ -1,27 +1,27 @@ run: python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client --epochs 2 --depth flat1 \ --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ - --dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client + --dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client --runs 1 python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_final --epochs 2 --depth flat1 \ --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ - --dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output both + --dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output both --runs 1 python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_inter --epochs 2 --depth flat1 \ --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ - --dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output both + --dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output both --runs 1 python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_soft --epochs 2 --depth flat1 \ --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ - --dense_embd 16 --domain_embd 8 --batch 64 --type soft --model_output both + --dense_embd 16 --domain_embd 8 --batch 64 --type soft --model_output both --runs 1 python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_long --epochs 2 --depth flat1 \ --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ - --dense_embd 16 --domain_embd 8 --batch 64 --type long --model_output both + --dense_embd 16 --domain_embd 8 --batch 64 --type long --model_output both --runs 1 python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_staggered --epochs 2 --depth flat1 \ --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ - --dense_embd 16 --domain_embd 8 --batch 64 --type staggered --model_output both + --dense_embd 16 --domain_embd 8 --batch 64 --type staggered --model_output both --runs 1 test: diff --git a/arguments.py b/arguments.py index ee3e512..d24b37e 100644 --- a/arguments.py +++ b/arguments.py @@ -49,6 +49,10 @@ parser.add_argument("--epochs", action="store", dest="epochs", parser.add_argument("--init_epoch", action="store", dest="initial_epoch", default=0, type=int) +parser.add_argument("--runs", action="store", dest="runs", + default=20, type=int) + + # parser.add_argument("--samples", action="store", dest="samples", # default=100000, type=int) # diff --git a/dataset.py b/dataset.py index 9901589..20482fe 100644 --- a/dataset.py +++ b/dataset.py @@ -193,14 +193,14 @@ def get_flow_per_user(df): yield df.loc[df.user_hash == user].dropna(axis=0, how="any") -def load_or_generate_h5data(h5data, train_data, domain_length, window_size): - logger.info(f"check for h5data {h5data}") +def load_or_generate_h5data(train_data, domain_length, window_size): + logger.info(f"check for h5data {train_data}") try: - check_h5dataset(h5data) + check_h5dataset(train_data) except FileNotFoundError: logger.info("load raw training dataset") - domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(h5data, train_data, - domain_length, window_size) + domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(train_data, domain_length, + window_size) logger.info("filter training dataset") domain, flow, name, client, server = filter_window_dataset_by_hits(domain.value, flow.value, name.value, hits.value, @@ -213,14 +213,14 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size): "client": client.astype(np.bool), "server": server.astype(np.bool) } - store_h5dataset(h5data, data) + store_h5dataset(train_data, data) logger.info("load h5 dataset") - data = load_h5dataset(h5data) + data = load_h5dataset(train_data) return data["domain"], data["flow"], data["name"], data["client"], data["server"] -def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size): - h5data = h5data + "_raw" +def load_or_generate_raw_h5data(train_data, domain_length, window_size): + h5data = train_data + "_raw" logger.info(f"check for h5data {h5data}") try: check_h5dataset(h5data) diff --git a/main.py b/main.py index c5f5dce..e95c7cb 100644 --- a/main.py +++ b/main.py @@ -59,7 +59,7 @@ if args.gpu: # default parameter PARAMS = { "type": args.model_type, - "depth": args.model_depth, + # "depth": args.model_depth, "batch_size": args.batch_size, "window_size": args.window, "domain_length": args.domain_length, @@ -84,7 +84,7 @@ def get_param_dist(dist_size="small"): return { # static params "type": [args.model_type], - "depth": [args.model_depth], + # "depth": [args.model_depth], "model_output": [args.model_output], "batch_size": [args.batch_size], "window_size": [args.window], @@ -104,7 +104,7 @@ def get_param_dist(dist_size="small"): return { # static params "type": [args.model_type], - "depth": [args.model_depth], + # "depth": [args.model_depth], "model_output": [args.model_output], "batch_size": [args.batch_size], "window_size": [args.window], @@ -159,9 +159,7 @@ def main_hyperband(data, domain_length, window_size, model_type, result_file, di param_dist = get_param_dist(dist_size) logger.info("create training dataset") - domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data, - data, - domain_length, + domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data, domain_length, window) server_tr = np.max(server_windows_tr, axis=1) @@ -191,9 +189,7 @@ def train(parameters, features, labels): def load_data(data, domain_length, window_size, model_type): # data preparation - domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data, - data, - domain_length, + domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data, domain_length, window_size) server_tr = np.max(server_windows_tr, axis=1) if model_type in ("inter", "staggered"): @@ -218,10 +214,10 @@ def main_train(param=None): logger.info(f"select params from result: {param}") if not param: param = PARAMS - - for i in range(20): + + for i in range(args.runs): model_path = os.path.join(args.model_path, f"clf_{i}.h5") - train_log_path = os.path.join(args.model_path, "train_{i}.log.csv") + train_log_path = os.path.join(args.model_path, f"train_{i}.log.csv") # define training call backs logger.info("define callbacks") callbacks = [] @@ -318,7 +314,6 @@ def main_retrain(): exists_or_make_path(args.model_destination) domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data, - args.data, args.domain_length, args.window) logger.info("define callbacks") @@ -373,10 +368,7 @@ def main_retrain(): def main_test(): logger.info("start test: load data") - domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, - args.data, - args.domain_length, - args.window) + domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window) domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length) for model_args in get_model_args(args): @@ -425,7 +417,6 @@ def main_visualization(): visualize.plot_save("{}/{}_{}.pdf".format(model_path, aggregation, curve)) _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data, - args.data, args.domain_length, args.window) @@ -484,7 +475,6 @@ def main_visualization(): def main_visualize_all(): _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data, - args.data, args.domain_length, args.window) @@ -600,7 +590,6 @@ def main_visualize_all_embds(): def main_beta(): domain_val, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data, - args.data, args.domain_length, args.window) path, model_prefix = os.path.split(os.path.normpath(args.output_prefix)) diff --git a/models/__init__.py b/models/__init__.py index 1af6697..317bb0d 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -17,7 +17,7 @@ def get_models_by_params(params: dict): # decomposing param section # mainly embedding model network_type = params.get("type") - network_depth = params.get("depth") + # network_depth = params.get("depth") embedding_size = params.get("embedding") filter_embedding = params.get("filter_embedding") kernel_embedding = params.get("kernel_embedding") diff --git a/server.py b/server.py index d89ea80..cb4f98f 100644 --- a/server.py +++ b/server.py @@ -21,7 +21,6 @@ def train_server_only(params): logger.info(f"Use command line arguments: {args}") domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data, - args.data, args.domain_length, args.window) domain_tr = domain_tr.value.reshape(-1, 40) @@ -69,10 +68,7 @@ def train_server_only(params): def test_server_only(): logger.info("start test: load data") - domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, - args.data, - args.domain_length, - args.window) + domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window) domain_val = domain_val.value.reshape(-1, 40) flow_val = flow_val.value.reshape(-1, 3) domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length) @@ -102,10 +98,7 @@ def vis_server(): return emdb, clf domain_raw, flow_raw, name_raw, hits_vt_raw, hits_trusted_raw, server_raw = dataset.load_or_generate_raw_h5data( - args.data, - args.data, - args.domain_length, - args.window) + args.data, args.domain_length, args.window) results = dataset.load_predictions(args.clf_model)