From d1da3d6ca383db8a2ef39fd43b3a10febc55ce4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Mon, 9 Oct 2017 15:10:15 +0200 Subject: [PATCH] fix model selection --- fancy.sh | 5 +++++ main.py | 1 - models/__init__.py | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/fancy.sh b/fancy.sh index cdb05ff..30cc4ff 100644 --- a/fancy.sh +++ b/fancy.sh @@ -22,3 +22,8 @@ python3 main.py --mode beta --batch 1024 --models ${RESDIR}/client_final_*/ --da python3 main.py --mode beta --batch 1024 --models ${RESDIR}/both_final_*/ --data ${DATADIR} --model_output both --out-prefix ${RESDIR}/both_final python3 main.py --mode beta --batch 1024 --models ${RESDIR}/both_inter_*/ --data ${DATADIR} --model_output both --out-prefix ${RESDIR}/both_inter python3 main.py --mode beta --batch 1024 --models ${RESDIR}/both_staggered_*/ --data ${DATADIR} --model_output both --out-prefix ${RESDIR}/both_staggered + +python3 main.py --mode embedding --batch 1024 --model ${RESDIR}/client_final_*/ --data ${DATADIR} --model_output client --out-prefix --model ${RESDIR}/client_final +python3 main.py --mode embedding --batch 1024 --model ${RESDIR}/both_final_*/ --data ${DATADIR} --model_output both --out-prefix --model ${RESDIR}/both_final +python3 main.py --mode embedding --batch 1024 --model ${RESDIR}/both_inter_*/ --data ${DATADIR} --model_output both --out-prefix --model ${RESDIR}/both_inter +python3 main.py --mode embedding --batch 1024 --model ${RESDIR}/both_staggered_*/ --data ${DATADIR} --model_output both --out-prefix --model ${RESDIR}/both_staggered \ No newline at end of file diff --git a/main.py b/main.py index 89580c3..6a2efe2 100644 --- a/main.py +++ b/main.py @@ -63,7 +63,6 @@ PARAMS = { "flow_features": 3, # 'dropout': 0.5, # currently fix - 'domain_features': args.domain_embedding, 'embedding': args.embedding, 'flow_features': 3, 'filter_embedding': args.filter_embedding, diff --git a/models/__init__.py b/models/__init__.py index 3cd662b..f9384be 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -29,6 +29,8 @@ def get_models_by_params(params: dict): elif network_depth == "flat2": networks = flat_2 elif network_depth == "deep1": + networks = deep1 + elif network_depth == "deep2": networks = renes_networks else: raise Exception("network not found")