remove identical parameter from data loading function; add runs argument

This commit is contained in:
René Knaebel 2017-11-07 20:47:41 +01:00
parent 826357a41f
commit 903e81c931
6 changed files with 31 additions and 45 deletions

View File

@ -1,27 +1,27 @@
run:
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client --epochs 2 --depth flat1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client --runs 1
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_final --epochs 2 --depth flat1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output both
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output both --runs 1
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_inter --epochs 2 --depth flat1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output both
--dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output both --runs 1
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_soft --epochs 2 --depth flat1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --type soft --model_output both
--dense_embd 16 --domain_embd 8 --batch 64 --type soft --model_output both --runs 1
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_long --epochs 2 --depth flat1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --type long --model_output both
--dense_embd 16 --domain_embd 8 --batch 64 --type long --model_output both --runs 1
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_staggered --epochs 2 --depth flat1 \
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
--dense_embd 16 --domain_embd 8 --batch 64 --type staggered --model_output both
--dense_embd 16 --domain_embd 8 --batch 64 --type staggered --model_output both --runs 1
test:

View File

@ -49,6 +49,10 @@ parser.add_argument("--epochs", action="store", dest="epochs",
parser.add_argument("--init_epoch", action="store", dest="initial_epoch",
default=0, type=int)
parser.add_argument("--runs", action="store", dest="runs",
default=20, type=int)
# parser.add_argument("--samples", action="store", dest="samples",
# default=100000, type=int)
#

View File

@ -193,14 +193,14 @@ def get_flow_per_user(df):
yield df.loc[df.user_hash == user].dropna(axis=0, how="any")
def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
logger.info(f"check for h5data {h5data}")
def load_or_generate_h5data(train_data, domain_length, window_size):
logger.info(f"check for h5data {train_data}")
try:
check_h5dataset(h5data)
check_h5dataset(train_data)
except FileNotFoundError:
logger.info("load raw training dataset")
domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(h5data, train_data,
domain_length, window_size)
domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(train_data, domain_length,
window_size)
logger.info("filter training dataset")
domain, flow, name, client, server = filter_window_dataset_by_hits(domain.value, flow.value,
name.value, hits.value,
@ -213,14 +213,14 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
"client": client.astype(np.bool),
"server": server.astype(np.bool)
}
store_h5dataset(h5data, data)
store_h5dataset(train_data, data)
logger.info("load h5 dataset")
data = load_h5dataset(h5data)
data = load_h5dataset(train_data)
return data["domain"], data["flow"], data["name"], data["client"], data["server"]
def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size):
h5data = h5data + "_raw"
def load_or_generate_raw_h5data(train_data, domain_length, window_size):
h5data = train_data + "_raw"
logger.info(f"check for h5data {h5data}")
try:
check_h5dataset(h5data)

27
main.py
View File

@ -59,7 +59,7 @@ if args.gpu:
# default parameter
PARAMS = {
"type": args.model_type,
"depth": args.model_depth,
# "depth": args.model_depth,
"batch_size": args.batch_size,
"window_size": args.window,
"domain_length": args.domain_length,
@ -84,7 +84,7 @@ def get_param_dist(dist_size="small"):
return {
# static params
"type": [args.model_type],
"depth": [args.model_depth],
# "depth": [args.model_depth],
"model_output": [args.model_output],
"batch_size": [args.batch_size],
"window_size": [args.window],
@ -104,7 +104,7 @@ def get_param_dist(dist_size="small"):
return {
# static params
"type": [args.model_type],
"depth": [args.model_depth],
# "depth": [args.model_depth],
"model_output": [args.model_output],
"batch_size": [args.batch_size],
"window_size": [args.window],
@ -159,9 +159,7 @@ def main_hyperband(data, domain_length, window_size, model_type, result_file, di
param_dist = get_param_dist(dist_size)
logger.info("create training dataset")
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data,
data,
domain_length,
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data, domain_length,
window)
server_tr = np.max(server_windows_tr, axis=1)
@ -191,9 +189,7 @@ def train(parameters, features, labels):
def load_data(data, domain_length, window_size, model_type):
# data preparation
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data,
data,
domain_length,
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data, domain_length,
window_size)
server_tr = np.max(server_windows_tr, axis=1)
if model_type in ("inter", "staggered"):
@ -219,9 +215,9 @@ def main_train(param=None):
if not param:
param = PARAMS
for i in range(20):
for i in range(args.runs):
model_path = os.path.join(args.model_path, f"clf_{i}.h5")
train_log_path = os.path.join(args.model_path, "train_{i}.log.csv")
train_log_path = os.path.join(args.model_path, f"train_{i}.log.csv")
# define training call backs
logger.info("define callbacks")
callbacks = []
@ -318,7 +314,6 @@ def main_retrain():
exists_or_make_path(args.model_destination)
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
args.data,
args.domain_length,
args.window)
logger.info("define callbacks")
@ -373,10 +368,7 @@ def main_retrain():
def main_test():
logger.info("start test: load data")
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data,
args.data,
args.domain_length,
args.window)
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
for model_args in get_model_args(args):
@ -425,7 +417,6 @@ def main_visualization():
visualize.plot_save("{}/{}_{}.pdf".format(model_path, aggregation, curve))
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
args.data,
args.domain_length,
args.window)
@ -484,7 +475,6 @@ def main_visualization():
def main_visualize_all():
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
args.data,
args.domain_length,
args.window)
@ -600,7 +590,6 @@ def main_visualize_all_embds():
def main_beta():
domain_val, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
args.data,
args.domain_length,
args.window)
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))

View File

@ -17,7 +17,7 @@ def get_models_by_params(params: dict):
# decomposing param section
# mainly embedding model
network_type = params.get("type")
network_depth = params.get("depth")
# network_depth = params.get("depth")
embedding_size = params.get("embedding")
filter_embedding = params.get("filter_embedding")
kernel_embedding = params.get("kernel_embedding")

View File

@ -21,7 +21,6 @@ def train_server_only(params):
logger.info(f"Use command line arguments: {args}")
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
args.data,
args.domain_length,
args.window)
domain_tr = domain_tr.value.reshape(-1, 40)
@ -69,10 +68,7 @@ def train_server_only(params):
def test_server_only():
logger.info("start test: load data")
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data,
args.data,
args.domain_length,
args.window)
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
domain_val = domain_val.value.reshape(-1, 40)
flow_val = flow_val.value.reshape(-1, 3)
domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
@ -102,10 +98,7 @@ def vis_server():
return emdb, clf
domain_raw, flow_raw, name_raw, hits_vt_raw, hits_trusted_raw, server_raw = dataset.load_or_generate_raw_h5data(
args.data,
args.data,
args.domain_length,
args.window)
args.data, args.domain_length, args.window)
results = dataset.load_predictions(args.clf_model)