diff --git a/Makefile b/Makefile index 892ef15..c5d817c 100644 --- a/Makefile +++ b/Makefile @@ -1,27 +1,27 @@ run: - python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client --epochs 2 --depth flat1 \ - --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ - --dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client --runs 1 + python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client --epochs 2 \ + --filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \ + --dense_embd 8 --domain_embd 8 --batch 64 --type final --model_output client --runs 1 - python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_final --epochs 2 --depth flat1 \ - --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ - --dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output both --runs 1 + python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_final --epochs 2 \ + --filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \ + --dense_embd 8 --domain_embd 8 --batch 64 --type final --model_output both --runs 1 - python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_inter --epochs 2 --depth flat1 \ - --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ - --dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output both --runs 1 + python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_inter --epochs 2 \ + --filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \ + --dense_embd 8 --domain_embd 8 --batch 64 --type inter --model_output both --runs 1 - python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_soft --epochs 2 --depth flat1 \ - --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ - --dense_embd 16 --domain_embd 8 --batch 64 --type soft --model_output both --runs 1 + python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_soft --epochs 2 \ + --filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \ + --dense_embd 8 --domain_embd 8 --batch 64 --type soft --model_output both --runs 1 - python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_long --epochs 2 --depth flat1 \ - --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ - --dense_embd 16 --domain_embd 8 --batch 64 --type long --model_output both --runs 1 + python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_long --epochs 2 \ + --filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \ + --dense_embd 8 --domain_embd 8 --batch 64 --type long --model_output both --runs 1 - python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_staggered --epochs 2 --depth flat1 \ - --filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ - --dense_embd 16 --domain_embd 8 --batch 64 --type staggered --model_output both --runs 1 + python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_staggered --epochs 2 \ + --filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \ + --dense_embd 8 --domain_embd 8 --batch 64 --type staggered --model_output both --runs 1 test: diff --git a/main.py b/main.py index 78cb65a..3a9b9ef 100644 --- a/main.py +++ b/main.py @@ -163,11 +163,9 @@ def main_hyperband(data, domain_length, window_size, model_type, result_file, ma return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, max_iter, result_file) -def run_hyperband(dist_size, domain, flow, client, server, max_iter, savefile): +def run_hyperband(dist_size, features, labels, max_iter, savefile): param_dist = get_param_dist(dist_size) - hp = hyperband.Hyperband(param_dist, - [domain, flow], - [client, server], + hp = hyperband.Hyperband(param_dist, features, labels, max_iter=max_iter, savefile=savefile) results = hp.run() @@ -191,7 +189,27 @@ def load_data(data, domain_length, window_size, model_type, shuffled=False): return domain_tr, flow_tr, client_tr, server_tr -def get_weighting(class_weights, sample_weights, client, server): +def load_training_data(data, model_output, domain_length, window_size, model_type, shuffled=False): + domain_tr, flow_tr, client_tr, server_tr = load_data(data, domain_length, + window_size, model_type, shuffled) + features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value} + if model_output == "both": + labels = {"client": client_tr.value, "server": server_tr} + loss_weights = {"client": 1.0, "server": 1.0} + elif model_output == "client": + labels = {"client": client_tr.value} + loss_weights = {"client": 1.0} + elif model_output == "server": + labels = {"server": server_tr} + loss_weights = {"server": 1.0} + else: + raise ValueError("unknown model output") + return features, labels, loss_weights + + +def get_weighting(class_weights, sample_weights, labels): + return None, None + client, server = labels["client"], labels["server"] if class_weights: logger.info("class weights: compute custom weights") custom_class_weights = get_custom_class_weights(client, server) @@ -217,16 +235,16 @@ def main_train(param=None): logger.info(f"Use command line arguments: {args}") # data preparation - domain_tr, flow_tr, client_tr, server_tr = load_data(args.data, args.domain_length, - args.window, args.model_type) + features, labels, loss_weights = load_training_data(args.data, args.model_output, args.domain_length, + args.window, args.model_type) - # call hyperband if used + # call hyperband if results are not accessible if args.hyperband_results: try: hyper_results = joblib.load(args.hyperband_results) except Exception: logger.info("start hyperband parameter search") - hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, args.hyper_max_iter, + hyper_results = run_hyperband("small", features, labels, args.hyper_max_iter, args.hyperband_results) param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]["params"] param["type"] = args.model_type @@ -235,8 +253,8 @@ def main_train(param=None): param = PARAMS # custom class or sample weights - custom_class_weights, custom_sample_weights = get_weighting(args.class_weights, args.sample_weights, - client_tr.value, server_tr) + # TODO: should throw an error when using weights with only the client labels + custom_class_weights, custom_sample_weights = get_weighting(args.class_weights, args.sample_weights, labels) for i in range(args.runs): model_path = os.path.join(args.model_path, f"clf_{i}.h5") @@ -259,19 +277,6 @@ def main_train(param=None): logger.info(f"Generator model with params: {param}") model = models.get_models_by_params(param) - features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value} - if args.model_output == "both": - labels = {"client": client_tr.value, "server": server_tr} - loss_weights = {"client": 1.0, "server": 1.0} - elif args.model_output == "client": - labels = {"client": client_tr.value} - loss_weights = {"client": 1.0} - elif args.model_output == "server": - labels = {"server": server_tr} - loss_weights = {"server": 1.0} - else: - raise ValueError("unknown model output") - logger.info(f"select model: {args.model_type}") if args.model_type == "staggered": logger.info("compile and pre-train server model")