refactor class weights
This commit is contained in:
parent
461d4cab8f
commit
d58dbcb101
36
Makefile
36
Makefile
@ -1,27 +1,27 @@
|
|||||||
run:
|
run:
|
||||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client --epochs 2 --depth flat1 \
|
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client --epochs 2 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
--filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client --runs 1
|
--dense_embd 8 --domain_embd 8 --batch 64 --type final --model_output client --runs 1
|
||||||
|
|
||||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_final --epochs 2 --depth flat1 \
|
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_final --epochs 2 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
--filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output both --runs 1
|
--dense_embd 8 --domain_embd 8 --batch 64 --type final --model_output both --runs 1
|
||||||
|
|
||||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_inter --epochs 2 --depth flat1 \
|
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_inter --epochs 2 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
--filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output both --runs 1
|
--dense_embd 8 --domain_embd 8 --batch 64 --type inter --model_output both --runs 1
|
||||||
|
|
||||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_soft --epochs 2 --depth flat1 \
|
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_soft --epochs 2 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
--filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --type soft --model_output both --runs 1
|
--dense_embd 8 --domain_embd 8 --batch 64 --type soft --model_output both --runs 1
|
||||||
|
|
||||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_long --epochs 2 --depth flat1 \
|
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_long --epochs 2 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
--filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --type long --model_output both --runs 1
|
--dense_embd 8 --domain_embd 8 --batch 64 --type long --model_output both --runs 1
|
||||||
|
|
||||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_staggered --epochs 2 --depth flat1 \
|
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_staggered --epochs 2 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
--filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --type staggered --model_output both --runs 1
|
--dense_embd 8 --domain_embd 8 --batch 64 --type staggered --model_output both --runs 1
|
||||||
|
|
||||||
|
|
||||||
test:
|
test:
|
||||||
|
53
main.py
53
main.py
@ -163,11 +163,9 @@ def main_hyperband(data, domain_length, window_size, model_type, result_file, ma
|
|||||||
return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, max_iter, result_file)
|
return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, max_iter, result_file)
|
||||||
|
|
||||||
|
|
||||||
def run_hyperband(dist_size, domain, flow, client, server, max_iter, savefile):
|
def run_hyperband(dist_size, features, labels, max_iter, savefile):
|
||||||
param_dist = get_param_dist(dist_size)
|
param_dist = get_param_dist(dist_size)
|
||||||
hp = hyperband.Hyperband(param_dist,
|
hp = hyperband.Hyperband(param_dist, features, labels,
|
||||||
[domain, flow],
|
|
||||||
[client, server],
|
|
||||||
max_iter=max_iter,
|
max_iter=max_iter,
|
||||||
savefile=savefile)
|
savefile=savefile)
|
||||||
results = hp.run()
|
results = hp.run()
|
||||||
@ -191,7 +189,27 @@ def load_data(data, domain_length, window_size, model_type, shuffled=False):
|
|||||||
return domain_tr, flow_tr, client_tr, server_tr
|
return domain_tr, flow_tr, client_tr, server_tr
|
||||||
|
|
||||||
|
|
||||||
def get_weighting(class_weights, sample_weights, client, server):
|
def load_training_data(data, model_output, domain_length, window_size, model_type, shuffled=False):
|
||||||
|
domain_tr, flow_tr, client_tr, server_tr = load_data(data, domain_length,
|
||||||
|
window_size, model_type, shuffled)
|
||||||
|
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
|
||||||
|
if model_output == "both":
|
||||||
|
labels = {"client": client_tr.value, "server": server_tr}
|
||||||
|
loss_weights = {"client": 1.0, "server": 1.0}
|
||||||
|
elif model_output == "client":
|
||||||
|
labels = {"client": client_tr.value}
|
||||||
|
loss_weights = {"client": 1.0}
|
||||||
|
elif model_output == "server":
|
||||||
|
labels = {"server": server_tr}
|
||||||
|
loss_weights = {"server": 1.0}
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown model output")
|
||||||
|
return features, labels, loss_weights
|
||||||
|
|
||||||
|
|
||||||
|
def get_weighting(class_weights, sample_weights, labels):
|
||||||
|
return None, None
|
||||||
|
client, server = labels["client"], labels["server"]
|
||||||
if class_weights:
|
if class_weights:
|
||||||
logger.info("class weights: compute custom weights")
|
logger.info("class weights: compute custom weights")
|
||||||
custom_class_weights = get_custom_class_weights(client, server)
|
custom_class_weights = get_custom_class_weights(client, server)
|
||||||
@ -217,16 +235,16 @@ def main_train(param=None):
|
|||||||
logger.info(f"Use command line arguments: {args}")
|
logger.info(f"Use command line arguments: {args}")
|
||||||
|
|
||||||
# data preparation
|
# data preparation
|
||||||
domain_tr, flow_tr, client_tr, server_tr = load_data(args.data, args.domain_length,
|
features, labels, loss_weights = load_training_data(args.data, args.model_output, args.domain_length,
|
||||||
args.window, args.model_type)
|
args.window, args.model_type)
|
||||||
|
|
||||||
# call hyperband if used
|
# call hyperband if results are not accessible
|
||||||
if args.hyperband_results:
|
if args.hyperband_results:
|
||||||
try:
|
try:
|
||||||
hyper_results = joblib.load(args.hyperband_results)
|
hyper_results = joblib.load(args.hyperband_results)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.info("start hyperband parameter search")
|
logger.info("start hyperband parameter search")
|
||||||
hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, args.hyper_max_iter,
|
hyper_results = run_hyperband("small", features, labels, args.hyper_max_iter,
|
||||||
args.hyperband_results)
|
args.hyperband_results)
|
||||||
param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]["params"]
|
param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]["params"]
|
||||||
param["type"] = args.model_type
|
param["type"] = args.model_type
|
||||||
@ -235,8 +253,8 @@ def main_train(param=None):
|
|||||||
param = PARAMS
|
param = PARAMS
|
||||||
|
|
||||||
# custom class or sample weights
|
# custom class or sample weights
|
||||||
custom_class_weights, custom_sample_weights = get_weighting(args.class_weights, args.sample_weights,
|
# TODO: should throw an error when using weights with only the client labels
|
||||||
client_tr.value, server_tr)
|
custom_class_weights, custom_sample_weights = get_weighting(args.class_weights, args.sample_weights, labels)
|
||||||
|
|
||||||
for i in range(args.runs):
|
for i in range(args.runs):
|
||||||
model_path = os.path.join(args.model_path, f"clf_{i}.h5")
|
model_path = os.path.join(args.model_path, f"clf_{i}.h5")
|
||||||
@ -259,19 +277,6 @@ def main_train(param=None):
|
|||||||
logger.info(f"Generator model with params: {param}")
|
logger.info(f"Generator model with params: {param}")
|
||||||
model = models.get_models_by_params(param)
|
model = models.get_models_by_params(param)
|
||||||
|
|
||||||
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
|
|
||||||
if args.model_output == "both":
|
|
||||||
labels = {"client": client_tr.value, "server": server_tr}
|
|
||||||
loss_weights = {"client": 1.0, "server": 1.0}
|
|
||||||
elif args.model_output == "client":
|
|
||||||
labels = {"client": client_tr.value}
|
|
||||||
loss_weights = {"client": 1.0}
|
|
||||||
elif args.model_output == "server":
|
|
||||||
labels = {"server": server_tr}
|
|
||||||
loss_weights = {"server": 1.0}
|
|
||||||
else:
|
|
||||||
raise ValueError("unknown model output")
|
|
||||||
|
|
||||||
logger.info(f"select model: {args.model_type}")
|
logger.info(f"select model: {args.model_type}")
|
||||||
if args.model_type == "staggered":
|
if args.model_type == "staggered":
|
||||||
logger.info("compile and pre-train server model")
|
logger.info("compile and pre-train server model")
|
||||||
|
Loading…
Reference in New Issue
Block a user