change argument interface
- add more properties for network specification - change names for consistency
This commit is contained in:
parent
71f218888d
commit
595c2ea894
10
Makefile
10
Makefile
@ -1,18 +1,18 @@
|
|||||||
run:
|
run:
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test1 --epochs 2 --depth small \
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test1 --epochs 2 --depth small \
|
||||||
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type final
|
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final
|
||||||
|
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test2 --epochs 2 --depth small \
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test2 --epochs 2 --depth small \
|
||||||
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
|
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
|
||||||
|
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test3 --epochs 2 --depth medium \
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test3 --epochs 2 --depth medium \
|
||||||
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type final
|
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final
|
||||||
|
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test4 --epochs 2 --depth medium \
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test4 --epochs 2 --depth medium \
|
||||||
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
|
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
|
||||||
|
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test5 --epochs 2 --depth small \
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test5 --epochs 2 --depth small \
|
||||||
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type staggered
|
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type staggered
|
||||||
|
|
||||||
test:
|
test:
|
||||||
python3 main.py --mode test --batch 128 --models results/test* --test data/rk_mini.csv.gz
|
python3 main.py --mode test --batch 128 --models results/test* --test data/rk_mini.csv.gz
|
||||||
|
20
arguments.py
20
arguments.py
@ -46,8 +46,24 @@ parser.add_argument("--epochs", action="store", dest="epochs",
|
|||||||
parser.add_argument("--embd", action="store", dest="embedding",
|
parser.add_argument("--embd", action="store", dest="embedding",
|
||||||
default=128, type=int)
|
default=128, type=int)
|
||||||
|
|
||||||
parser.add_argument("--hidden_char_dims", action="store", dest="hidden_char_dims",
|
parser.add_argument("--filter_embd", action="store", dest="filter_embedding",
|
||||||
default=256, type=int)
|
default=128, type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--dense_embd", action="store", dest="dense_embedding",
|
||||||
|
default=128, type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--kernel_embd", action="store", dest="kernel_embedding",
|
||||||
|
default=3, type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--filter_main", action="store", dest="filter_main",
|
||||||
|
default=128, type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--dense_main", action="store", dest="dense_main",
|
||||||
|
default=128, type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--kernel_main", action="store", dest="kernel_main",
|
||||||
|
default=3, type=int)
|
||||||
|
|
||||||
|
|
||||||
parser.add_argument("--window", action="store", dest="window",
|
parser.add_argument("--window", action="store", dest="window",
|
||||||
default=10, type=int)
|
default=10, type=int)
|
||||||
|
16
main.py
16
main.py
@ -58,21 +58,21 @@ if args.gpu:
|
|||||||
PARAMS = {
|
PARAMS = {
|
||||||
"type": args.model_type,
|
"type": args.model_type,
|
||||||
"depth": args.model_depth,
|
"depth": args.model_depth,
|
||||||
"batch_size": 64,
|
# "batch_size": 64,
|
||||||
"window_size": args.window,
|
"window_size": args.window,
|
||||||
"domain_length": args.domain_length,
|
"domain_length": args.domain_length,
|
||||||
"flow_features": 3,
|
"flow_features": 3,
|
||||||
#
|
#
|
||||||
'dropout': 0.5,
|
'dropout': 0.5, # currently fix
|
||||||
'domain_features': args.domain_embedding,
|
'domain_features': args.domain_embedding,
|
||||||
'embedding_size': args.embedding,
|
'embedding_size': args.embedding,
|
||||||
'flow_features': 3,
|
'flow_features': 3,
|
||||||
'filter_embedding': args.hidden_char_dims,
|
'filter_embedding': args.filter_embedding,
|
||||||
'hidden_embedding': args.domain_embedding,
|
'dense_embedding': args.dense_embedding,
|
||||||
'kernel_embedding': 3,
|
'kernel_embedding': args.kernel_embedding,
|
||||||
'filter_main': 128,
|
'filter_main': args.filter_main,
|
||||||
'dense_main': 128,
|
'dense_main': args.dense_main,
|
||||||
'kernels_main': 3,
|
'kernel_main': args.kernel_main,
|
||||||
'input_length': 40,
|
'input_length': 40,
|
||||||
'model_output': args.model_output
|
'model_output': args.model_output
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,7 @@ def get_models_by_params(params: dict):
|
|||||||
input_length = params.get("input_length")
|
input_length = params.get("input_length")
|
||||||
filter_embedding = params.get("filter_embedding")
|
filter_embedding = params.get("filter_embedding")
|
||||||
kernel_embedding = params.get("kernel_embedding")
|
kernel_embedding = params.get("kernel_embedding")
|
||||||
hidden_embedding = params.get("hidden_embedding")
|
hidden_embedding = params.get("dense_embedding")
|
||||||
dropout = params.get("dropout")
|
dropout = params.get("dropout")
|
||||||
# mainly prediction model
|
# mainly prediction model
|
||||||
flow_features = params.get("flow_features")
|
flow_features = params.get("flow_features")
|
||||||
@ -21,7 +21,7 @@ def get_models_by_params(params: dict):
|
|||||||
window_size = params.get("window_size")
|
window_size = params.get("window_size")
|
||||||
domain_length = params.get("domain_length")
|
domain_length = params.get("domain_length")
|
||||||
filter_main = params.get("filter_main")
|
filter_main = params.get("filter_main")
|
||||||
kernel_main = params.get("kernels_main")
|
kernel_main = params.get("kernel_main")
|
||||||
dense_dim = params.get("dense_main")
|
dense_dim = params.get("dense_main")
|
||||||
model_output = params.get("model_output", "both")
|
model_output = params.get("model_output", "both")
|
||||||
# create models
|
# create models
|
||||||
@ -32,12 +32,12 @@ def get_models_by_params(params: dict):
|
|||||||
else:
|
else:
|
||||||
raise Exception("network not found")
|
raise Exception("network not found")
|
||||||
embedding_model = networks.get_embedding(embedding_size, input_length, filter_embedding, kernel_embedding,
|
embedding_model = networks.get_embedding(embedding_size, input_length, filter_embedding, kernel_embedding,
|
||||||
hidden_embedding, dropout)
|
hidden_embedding, 0.5)
|
||||||
|
|
||||||
old_model = networks.get_model(dropout, flow_features, domain_features, window_size, domain_length,
|
old_model = networks.get_model(0.25, flow_features, domain_features, window_size, domain_length,
|
||||||
filter_main, kernel_main, dense_dim, embedding_model, model_output)
|
filter_main, kernel_main, dense_dim, embedding_model, model_output)
|
||||||
|
|
||||||
new_model = networks.get_new_model(dropout, flow_features, domain_features, window_size, domain_length,
|
new_model = networks.get_new_model(0.25, flow_features, domain_features, window_size, domain_length,
|
||||||
filter_main, kernel_main, dense_dim, embedding_model, model_output)
|
filter_main, kernel_main, dense_dim, embedding_model, model_output)
|
||||||
|
|
||||||
return embedding_model, old_model, new_model
|
return embedding_model, old_model, new_model
|
||||||
|
10
run.sh
10
run.sh
@ -16,9 +16,10 @@ do
|
|||||||
--train ${DATADIR}/currentData.csv \
|
--train ${DATADIR}/currentData.csv \
|
||||||
--model ${RESDIR}/${output}_${depth}_${mtype} \
|
--model ${RESDIR}/${output}_${depth}_${mtype} \
|
||||||
--epochs 50 \
|
--epochs 50 \
|
||||||
--embd 64 \
|
--embd 128 \
|
||||||
--hidden_char_dims 128 \
|
--filter_embd 256 --kernel_embd 8 --dense_embd 128 \
|
||||||
--domain_embd 32 \
|
--domain_embd 32 \
|
||||||
|
--filter_main 32 --kernel_main 8 --dense_main 1024 \
|
||||||
--batch 256 \
|
--batch 256 \
|
||||||
--balanced_weights \
|
--balanced_weights \
|
||||||
--model_output ${output} \
|
--model_output ${output} \
|
||||||
@ -35,9 +36,10 @@ do
|
|||||||
--train ${DATADIR}/currentData.csv \
|
--train ${DATADIR}/currentData.csv \
|
||||||
--model ${RESDIR}/both_${depth}_inter \
|
--model ${RESDIR}/both_${depth}_inter \
|
||||||
--epochs 50 \
|
--epochs 50 \
|
||||||
--embd 64 \
|
--embd 128 \
|
||||||
--hidden_char_dims 128 \
|
--filter_embd 256 --kernel_embd 8 --dense_embd 128 \
|
||||||
--domain_embd 32 \
|
--domain_embd 32 \
|
||||||
|
--filter_main 32 --kernel_main 8 --dense_main 1024 \
|
||||||
--batch 256 \
|
--batch 256 \
|
||||||
--balanced_weights \
|
--balanced_weights \
|
||||||
--model_output both \
|
--model_output both \
|
||||||
|
Loading…
x
Reference in New Issue
Block a user