From edc75f4f44702a00827da11fcdaa12fc5ac6a890 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20Knaebel?=
Date: Fri, 8 Sep 2017 17:11:13 +0200
Subject: [PATCH] refactor dataset creation, split up functions

---
 Makefile   | 66 ++++++++++++++++++++++++++++++++++++++++--------------
 dataset.py | 61 ++++++++++++++++++-------------------------------
 main.py    |  2 +-
 3 files changed, 72 insertions(+), 57 deletions(-)

diff --git a/Makefile b/Makefile
index 4984f3e..786ea62 100644
--- a/Makefile
+++ b/Makefile
@@ -1,37 +1,69 @@
 run:
-	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test1 --epochs 2 --depth small \
-		--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final
+	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_1 --epochs 2 --depth small \
+		--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
+		--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both
 
-	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test2 --epochs 2 --depth small \
-		--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
+	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_2 --epochs 2 --depth small \
+		--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
+		--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both
 
-	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test3 --epochs 2 --depth medium \
-		--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final
+	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_3 --epochs 2 --depth medium \
+		--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
+		--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both
 
-	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test4 --epochs 2 --depth medium \
-		--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
+	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_4 --epochs 2 --depth medium \
+		--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
+		--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both
 
-	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test5 --epochs 2 --depth small \
-		--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type staggered
+	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_5 --epochs 2 --depth small \
+		--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
+		--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type staggered --model_output both
+
+	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_1 --epochs 2 --depth small \
+		--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
+		--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output client
+
+	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_2 --epochs 2 --depth small \
+		--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
+		--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output client
+
+	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_3 --epochs 2 --depth medium \
+		--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
+		--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output client
+
+	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_4 --epochs 2 --depth medium \
+		--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
+		--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output client
 
 test:
-	python3 main.py --mode test --batch 128 --models results/test* --test data/rk_mini.csv.gz
+	python3 main.py --mode test --batch 128 --models results/test/test_both_* --test data/rk_mini.csv.gz --model_output both
+	python3 main.py --mode test --batch 128 --models results/test/test_client_* --test data/rk_mini.csv.gz --model_output client
 
 fancy:
-	python3 main.py --mode fancy --batch 128 --model results/test1 --test data/rk_mini.csv.gz
+	python3 main.py --mode fancy --batch 128 --model results/test/test_both_1 --test data/rk_mini.csv.gz
 
-	python3 main.py --mode fancy --batch 128 --model results/test2 --test data/rk_mini.csv.gz
+	python3 main.py --mode fancy --batch 128 --model results/test/test_both_2 --test data/rk_mini.csv.gz
 
-	python3 main.py --mode fancy --batch 128 --model results/test3 --test data/rk_mini.csv.gz
+	python3 main.py --mode fancy --batch 128 --model results/test/test_both_3 --test data/rk_mini.csv.gz
 
-	python3 main.py --mode fancy --batch 128 --model results/test4 --test data/rk_mini.csv.gz
+	python3 main.py --mode fancy --batch 128 --model results/test/test_both_4 --test data/rk_mini.csv.gz
+
+	python3 main.py --mode fancy --batch 128 --model results/test/test_both_5 --test data/rk_mini.csv.gz
+
+	python3 main.py --mode fancy --batch 128 --model results/test/test_client_1 --test data/rk_mini.csv.gz
+
+	python3 main.py --mode fancy --batch 128 --model results/test/test_client_2 --test data/rk_mini.csv.gz
+
+	python3 main.py --mode fancy --batch 128 --model results/test/test_client_3 --test data/rk_mini.csv.gz
+
+	python3 main.py --mode fancy --batch 128 --model results/test/test_client_4 --test data/rk_mini.csv.gz
 
 all-fancy:
-	python3 main.py --mode all_fancy --batch 128 --models results/test* --test data/rk_mini.csv.gz
+	python3 main.py --mode all_fancy --batch 128 --models results/test/test* --test data/rk_mini.csv.gz
 
 hyper:
 	python3 main.py --mode hyperband --batch 64 --train data/rk_data.csv.gz
 
 clean:
-	rm -r results/test*
+	rm -r results/test/test*
 	rm data/rk_mini.csv.gz.h5
diff --git a/dataset.py b/dataset.py
index a9c6851..9cf22f0 100644
--- a/dataset.py
+++ b/dataset.py
@@ -99,18 +99,13 @@ def get_all_flow_features(features):
 
 
 def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
-    logger.info("get chunks from user data frames")
-    with Pool() as pool:
-        results = []
-        for user_flow in tqdm(get_flow_per_user(user_flow_df), total=len(user_flow_df['user_hash'].unique().tolist())):
-            results.append(pool.apply_async(get_user_chunks, (user_flow, window_size)))
-        windows = [window for res in results for window in res.get()]
-    logger.info("create training dataset")
-    domain, flow, hits, names, server, trusted_hits = create_dataset_from_lists(chunks=windows,
-                                                                                vocab=char_dict,
-                                                                                max_len=max_len)
-    # make client labels discrete with 4 different values
-    hits = np.apply_along_axis(lambda x: discretize_label(x, 3), 0, np.atleast_2d(hits))
+    domain, flow, name, hits, trusted_hits, server = create_raw_dataset_from_flows(user_flow_df, char_dict,
+                                                                                   max_len, window_size)
+    domain, flow, name, client, server = filter_window_dataset_by_hits(domain, flow, name, hits, trusted_hits, server)
+    return domain, flow, name, client, server
+
+
+def filter_window_dataset_by_hits(domain, flow, name, hits, trusted_hits, server):
     # select only 1.0 and 0.0 from training data
     pos_idx = np.where(np.logical_or(hits == 1.0, trusted_hits >= 1.0))[0]
     neg_idx = np.where(hits == 0.0)[0]
@@ -118,15 +113,15 @@ def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
     # choose selected sample to train on
     domain = domain[idx]
     flow = flow[idx]
-    client_tr = np.zeros_like(idx, float)
-    client_tr[:pos_idx.shape[-1]] = 1.0
+    client = np.zeros_like(idx, float)
+    client[:pos_idx.shape[-1]] = 1.0
     server = server[idx]
-    names = names[idx]
+    name = name[idx]
 
-    return domain, flow, names, client_tr, server
+    return domain, flow, name, client, server
 
 
-def create_testset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
+def create_raw_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
     logger.info("get chunks from user data frames")
     with Pool() as pool:
         results = []
@@ -134,24 +129,13 @@ def create_testset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
             results.append(pool.apply_async(get_user_chunks, (user_flow, window_size)))
         windows = [window for res in results for window in res.get()]
     logger.info("create training dataset")
-    domain, flow, hits, names, server, trusted_hits = create_dataset_from_lists(chunks=windows,
-                                                                                vocab=char_dict,
-                                                                                max_len=max_len)
+    domain, flow, hits, name, server, trusted_hits = create_dataset_from_windows(chunks=windows,
+                                                                                 vocab=char_dict,
+                                                                                 max_len=max_len)
     # make client labels discrete with 4 different values
-    hits = np.apply_along_axis(lambda x: discretize_label(x, 3), 0, np.atleast_2d(hits))
-    # select only 1.0 and 0.0 from training data
-    pos_idx = np.where(np.logical_or(hits == 1.0, trusted_hits >= 1.0))[0]
-    neg_idx = np.where(hits == 0.0)[0]
-    idx = np.concatenate((pos_idx, neg_idx))
-    # choose selected sample to train on
-    domain = domain[idx]
-    flow = flow[idx]
-    client_tr = np.zeros_like(idx, float)
-    client_tr[:pos_idx.shape[-1]] = 1.0
-    server = server[idx]
-    names = names[idx]
+    hits = np.apply_along_axis(lambda x: make_label_discrete(x, 3), 0, np.atleast_2d(hits))
 
-    return domain, flow, names, client_tr, server
+    return domain, flow, name, hits, trusted_hits, server
 
 
 def store_h5dataset(path, data: dict):
@@ -163,14 +147,13 @@ def store_h5dataset(path, data: dict):
 
 def load_h5dataset(path):
     f = h5py.File(path, "r")
-    keys = f.keys()
     data = {}
-    for k in keys:
+    for k in f.keys():
         data[k] = f[k]
     return data
 
 
-def create_dataset_from_lists(chunks, vocab, max_len):
+def create_dataset_from_windows(chunks, vocab, max_len):
     """
     combines domain and feature windows to sequential training data
     :param chunks: list of flow feature windows
@@ -204,7 +187,7 @@ def create_dataset_from_lists(chunks, vocab, max_len):
             hits, names, servers, trusted_hits)
 
 
-def discretize_label(values, threshold):
+def make_label_discrete(values, threshold):
     max_val = np.max(values)
     if max_val >= threshold:
         return 1.0
@@ -251,8 +234,8 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
         user_flow_df = get_user_flow_data(train_data)
         logger.info("create training dataset")
         domain, flow, name, client, server = create_dataset_from_flows(user_flow_df, char_dict,
-                                                        max_len=domain_length,
-                                                        window_size=window_size)
+                                                                       max_len=domain_length,
+                                                                       window_size=window_size)
         logger.info("store training dataset as h5 file")
         data = {
             "domain": domain.astype(np.int8),
diff --git a/main.py b/main.py
index 2fe7467..08f88af 100644
--- a/main.py
+++ b/main.py
@@ -275,7 +275,7 @@ def main_test():
         # np.save(model_args["model_path"] + "/domain_embds.npy", domain_embeddings)
         results["domain_embds"] = domain_embeddings
 
-    joblib.dump(results, model_args["model_path"] + "results.joblib", compress=3)
+    joblib.dump(results, model_args["model_path"] + "/results.joblib", compress=3)
 
 
 def main_visualization():
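Note: the refactor makes create_dataset_from_flows a thin composition of create_raw_dataset_from_flows (windowing plus hit discretization) and the new filter_window_dataset_by_hits (positive/negative sample selection), which lets the near-duplicate create_testset_from_flows body be dropped. The selection logic can be sanity-checked in isolation; a minimal sketch with toy arrays (the values are made up, the indexing mirrors the hunk above):

    import numpy as np

    # toy window labels; in the patch these come from create_raw_dataset_from_flows
    hits = np.array([1.0, 0.0, -1.0, 0.0, 1.0])
    trusted_hits = np.array([0.0, 0.0, 2.0, 0.0, 0.0])

    # keep clear positives (hits == 1.0 or trusted_hits >= 1.0) and clear
    # negatives (hits == 0.0); ambiguous windows like -1.0 are dropped
    pos_idx = np.where(np.logical_or(hits == 1.0, trusted_hits >= 1.0))[0]  # [0 2 4]
    neg_idx = np.where(hits == 0.0)[0]                                      # [1 3]
    idx = np.concatenate((pos_idx, neg_idx))

    # client labels follow the ordering of idx: positives first, then negatives
    client = np.zeros_like(idx, float)
    client[:pos_idx.shape[-1]] = 1.0
    print(idx, client)  # [0 2 4 1 3] [1. 1. 1. 0. 0.]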