change argument interface
- add more properties for network specification - change names for consistency
This commit is contained in:
parent
71f218888d
commit
595c2ea894
10
Makefile
10
Makefile
@ -1,18 +1,18 @@
|
||||
run:
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test1 --epochs 2 --depth small \
|
||||
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type final
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final
|
||||
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test2 --epochs 2 --depth small \
|
||||
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
|
||||
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test3 --epochs 2 --depth medium \
|
||||
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type final
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final
|
||||
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test4 --epochs 2 --depth medium \
|
||||
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
|
||||
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test5 --epochs 2 --depth small \
|
||||
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type staggered
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type staggered
|
||||
|
||||
test:
|
||||
python3 main.py --mode test --batch 128 --models results/test* --test data/rk_mini.csv.gz
|
||||
|
20
arguments.py
20
arguments.py
@ -46,8 +46,24 @@ parser.add_argument("--epochs", action="store", dest="epochs",
|
||||
parser.add_argument("--embd", action="store", dest="embedding",
|
||||
default=128, type=int)
|
||||
|
||||
parser.add_argument("--hidden_char_dims", action="store", dest="hidden_char_dims",
|
||||
default=256, type=int)
|
||||
parser.add_argument("--filter_embd", action="store", dest="filter_embedding",
|
||||
default=128, type=int)
|
||||
|
||||
parser.add_argument("--dense_embd", action="store", dest="dense_embedding",
|
||||
default=128, type=int)
|
||||
|
||||
parser.add_argument("--kernel_embd", action="store", dest="kernel_embedding",
|
||||
default=3, type=int)
|
||||
|
||||
parser.add_argument("--filter_main", action="store", dest="filter_main",
|
||||
default=128, type=int)
|
||||
|
||||
parser.add_argument("--dense_main", action="store", dest="dense_main",
|
||||
default=128, type=int)
|
||||
|
||||
parser.add_argument("--kernel_main", action="store", dest="kernel_main",
|
||||
default=3, type=int)
|
||||
|
||||
|
||||
parser.add_argument("--window", action="store", dest="window",
|
||||
default=10, type=int)
|
||||
|
16
main.py
16
main.py
@ -58,21 +58,21 @@ if args.gpu:
|
||||
PARAMS = {
|
||||
"type": args.model_type,
|
||||
"depth": args.model_depth,
|
||||
"batch_size": 64,
|
||||
# "batch_size": 64,
|
||||
"window_size": args.window,
|
||||
"domain_length": args.domain_length,
|
||||
"flow_features": 3,
|
||||
#
|
||||
'dropout': 0.5,
|
||||
'dropout': 0.5, # currently fix
|
||||
'domain_features': args.domain_embedding,
|
||||
'embedding_size': args.embedding,
|
||||
'flow_features': 3,
|
||||
'filter_embedding': args.hidden_char_dims,
|
||||
'hidden_embedding': args.domain_embedding,
|
||||
'kernel_embedding': 3,
|
||||
'filter_main': 128,
|
||||
'dense_main': 128,
|
||||
'kernels_main': 3,
|
||||
'filter_embedding': args.filter_embedding,
|
||||
'dense_embedding': args.dense_embedding,
|
||||
'kernel_embedding': args.kernel_embedding,
|
||||
'filter_main': args.filter_main,
|
||||
'dense_main': args.dense_main,
|
||||
'kernel_main': args.kernel_main,
|
||||
'input_length': 40,
|
||||
'model_output': args.model_output
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ def get_models_by_params(params: dict):
|
||||
input_length = params.get("input_length")
|
||||
filter_embedding = params.get("filter_embedding")
|
||||
kernel_embedding = params.get("kernel_embedding")
|
||||
hidden_embedding = params.get("hidden_embedding")
|
||||
hidden_embedding = params.get("dense_embedding")
|
||||
dropout = params.get("dropout")
|
||||
# mainly prediction model
|
||||
flow_features = params.get("flow_features")
|
||||
@ -21,7 +21,7 @@ def get_models_by_params(params: dict):
|
||||
window_size = params.get("window_size")
|
||||
domain_length = params.get("domain_length")
|
||||
filter_main = params.get("filter_main")
|
||||
kernel_main = params.get("kernels_main")
|
||||
kernel_main = params.get("kernel_main")
|
||||
dense_dim = params.get("dense_main")
|
||||
model_output = params.get("model_output", "both")
|
||||
# create models
|
||||
@ -32,12 +32,12 @@ def get_models_by_params(params: dict):
|
||||
else:
|
||||
raise Exception("network not found")
|
||||
embedding_model = networks.get_embedding(embedding_size, input_length, filter_embedding, kernel_embedding,
|
||||
hidden_embedding, dropout)
|
||||
hidden_embedding, 0.5)
|
||||
|
||||
old_model = networks.get_model(dropout, flow_features, domain_features, window_size, domain_length,
|
||||
old_model = networks.get_model(0.25, flow_features, domain_features, window_size, domain_length,
|
||||
filter_main, kernel_main, dense_dim, embedding_model, model_output)
|
||||
|
||||
new_model = networks.get_new_model(dropout, flow_features, domain_features, window_size, domain_length,
|
||||
new_model = networks.get_new_model(0.25, flow_features, domain_features, window_size, domain_length,
|
||||
filter_main, kernel_main, dense_dim, embedding_model, model_output)
|
||||
|
||||
return embedding_model, old_model, new_model
|
||||
|
10
run.sh
10
run.sh
@ -16,9 +16,10 @@ do
|
||||
--train ${DATADIR}/currentData.csv \
|
||||
--model ${RESDIR}/${output}_${depth}_${mtype} \
|
||||
--epochs 50 \
|
||||
--embd 64 \
|
||||
--hidden_char_dims 128 \
|
||||
--embd 128 \
|
||||
--filter_embd 256 --kernel_embd 8 --dense_embd 128 \
|
||||
--domain_embd 32 \
|
||||
--filter_main 32 --kernel_main 8 --dense_main 1024 \
|
||||
--batch 256 \
|
||||
--balanced_weights \
|
||||
--model_output ${output} \
|
||||
@ -35,9 +36,10 @@ do
|
||||
--train ${DATADIR}/currentData.csv \
|
||||
--model ${RESDIR}/both_${depth}_inter \
|
||||
--epochs 50 \
|
||||
--embd 64 \
|
||||
--hidden_char_dims 128 \
|
||||
--embd 128 \
|
||||
--filter_embd 256 --kernel_embd 8 --dense_embd 128 \
|
||||
--domain_embd 32 \
|
||||
--filter_main 32 --kernel_main 8 --dense_main 1024 \
|
||||
--batch 256 \
|
||||
--balanced_weights \
|
||||
--model_output both \
|
||||
|
Loading…
x
Reference in New Issue
Block a user