refactor class weights

This commit is contained in:
René Knaebel 2017-11-10 14:31:32 +01:00
parent 461d4cab8f
commit d58dbcb101
2 changed files with 47 additions and 42 deletions

View File

@ -1,27 +1,27 @@
run: run:
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client --epochs 2 --depth flat1 \ python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client --epochs 2 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ --filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client --runs 1 --dense_embd 8 --domain_embd 8 --batch 64 --type final --model_output client --runs 1
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_final --epochs 2 --depth flat1 \ python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_final --epochs 2 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ --filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output both --runs 1 --dense_embd 8 --domain_embd 8 --batch 64 --type final --model_output both --runs 1
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_inter --epochs 2 --depth flat1 \ python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_inter --epochs 2 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ --filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
--dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output both --runs 1 --dense_embd 8 --domain_embd 8 --batch 64 --type inter --model_output both --runs 1
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_soft --epochs 2 --depth flat1 \ python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_soft --epochs 2 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ --filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
--dense_embd 16 --domain_embd 8 --batch 64 --type soft --model_output both --runs 1 --dense_embd 8 --domain_embd 8 --batch 64 --type soft --model_output both --runs 1
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_long --epochs 2 --depth flat1 \ python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_long --epochs 2 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ --filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
--dense_embd 16 --domain_embd 8 --batch 64 --type long --model_output both --runs 1 --dense_embd 8 --domain_embd 8 --batch 64 --type long --model_output both --runs 1
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_staggered --epochs 2 --depth flat1 \ python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_staggered --epochs 2 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \ --filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
--dense_embd 16 --domain_embd 8 --batch 64 --type staggered --model_output both --runs 1 --dense_embd 8 --domain_embd 8 --batch 64 --type staggered --model_output both --runs 1
test: test:

53
main.py
View File

@ -163,11 +163,9 @@ def main_hyperband(data, domain_length, window_size, model_type, result_file, ma
return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, max_iter, result_file) return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, max_iter, result_file)
def run_hyperband(dist_size, domain, flow, client, server, max_iter, savefile): def run_hyperband(dist_size, features, labels, max_iter, savefile):
param_dist = get_param_dist(dist_size) param_dist = get_param_dist(dist_size)
hp = hyperband.Hyperband(param_dist, hp = hyperband.Hyperband(param_dist, features, labels,
[domain, flow],
[client, server],
max_iter=max_iter, max_iter=max_iter,
savefile=savefile) savefile=savefile)
results = hp.run() results = hp.run()
@ -191,7 +189,27 @@ def load_data(data, domain_length, window_size, model_type, shuffled=False):
return domain_tr, flow_tr, client_tr, server_tr return domain_tr, flow_tr, client_tr, server_tr
def get_weighting(class_weights, sample_weights, client, server): def load_training_data(data, model_output, domain_length, window_size, model_type, shuffled=False):
domain_tr, flow_tr, client_tr, server_tr = load_data(data, domain_length,
window_size, model_type, shuffled)
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
if model_output == "both":
labels = {"client": client_tr.value, "server": server_tr}
loss_weights = {"client": 1.0, "server": 1.0}
elif model_output == "client":
labels = {"client": client_tr.value}
loss_weights = {"client": 1.0}
elif model_output == "server":
labels = {"server": server_tr}
loss_weights = {"server": 1.0}
else:
raise ValueError("unknown model output")
return features, labels, loss_weights
def get_weighting(class_weights, sample_weights, labels):
return None, None
client, server = labels["client"], labels["server"]
if class_weights: if class_weights:
logger.info("class weights: compute custom weights") logger.info("class weights: compute custom weights")
custom_class_weights = get_custom_class_weights(client, server) custom_class_weights = get_custom_class_weights(client, server)
@ -217,16 +235,16 @@ def main_train(param=None):
logger.info(f"Use command line arguments: {args}") logger.info(f"Use command line arguments: {args}")
# data preparation # data preparation
domain_tr, flow_tr, client_tr, server_tr = load_data(args.data, args.domain_length, features, labels, loss_weights = load_training_data(args.data, args.model_output, args.domain_length,
args.window, args.model_type) args.window, args.model_type)
# call hyperband if used # call hyperband if results are not accessible
if args.hyperband_results: if args.hyperband_results:
try: try:
hyper_results = joblib.load(args.hyperband_results) hyper_results = joblib.load(args.hyperband_results)
except Exception: except Exception:
logger.info("start hyperband parameter search") logger.info("start hyperband parameter search")
hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, args.hyper_max_iter, hyper_results = run_hyperband("small", features, labels, args.hyper_max_iter,
args.hyperband_results) args.hyperband_results)
param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]["params"] param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]["params"]
param["type"] = args.model_type param["type"] = args.model_type
@ -235,8 +253,8 @@ def main_train(param=None):
param = PARAMS param = PARAMS
# custom class or sample weights # custom class or sample weights
custom_class_weights, custom_sample_weights = get_weighting(args.class_weights, args.sample_weights, # TODO: should throw an error when using weights with only the client labels
client_tr.value, server_tr) custom_class_weights, custom_sample_weights = get_weighting(args.class_weights, args.sample_weights, labels)
for i in range(args.runs): for i in range(args.runs):
model_path = os.path.join(args.model_path, f"clf_{i}.h5") model_path = os.path.join(args.model_path, f"clf_{i}.h5")
@ -259,19 +277,6 @@ def main_train(param=None):
logger.info(f"Generator model with params: {param}") logger.info(f"Generator model with params: {param}")
model = models.get_models_by_params(param) model = models.get_models_by_params(param)
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
if args.model_output == "both":
labels = {"client": client_tr.value, "server": server_tr}
loss_weights = {"client": 1.0, "server": 1.0}
elif args.model_output == "client":
labels = {"client": client_tr.value}
loss_weights = {"client": 1.0}
elif args.model_output == "server":
labels = {"server": server_tr}
loss_weights = {"server": 1.0}
else:
raise ValueError("unknown model output")
logger.info(f"select model: {args.model_type}") logger.info(f"select model: {args.model_type}")
if args.model_type == "staggered": if args.model_type == "staggered":
logger.info("compile and pre-train server model") logger.info("compile and pre-train server model")