remove identical parameter from data loading function; add runs argument
This commit is contained in:
parent
826357a41f
commit
903e81c931
12
Makefile
12
Makefile
@ -1,27 +1,27 @@
|
||||
run:
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client --epochs 2 --depth flat1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client --runs 1
|
||||
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_final --epochs 2 --depth flat1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output both
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output both --runs 1
|
||||
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_inter --epochs 2 --depth flat1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output both
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output both --runs 1
|
||||
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_soft --epochs 2 --depth flat1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type soft --model_output both
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type soft --model_output both --runs 1
|
||||
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_long --epochs 2 --depth flat1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type long --model_output both
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type long --model_output both --runs 1
|
||||
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_staggered --epochs 2 --depth flat1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type staggered --model_output both
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type staggered --model_output both --runs 1
|
||||
|
||||
|
||||
test:
|
||||
|
@ -49,6 +49,10 @@ parser.add_argument("--epochs", action="store", dest="epochs",
|
||||
parser.add_argument("--init_epoch", action="store", dest="initial_epoch",
|
||||
default=0, type=int)
|
||||
|
||||
parser.add_argument("--runs", action="store", dest="runs",
|
||||
default=20, type=int)
|
||||
|
||||
|
||||
# parser.add_argument("--samples", action="store", dest="samples",
|
||||
# default=100000, type=int)
|
||||
#
|
||||
|
18
dataset.py
18
dataset.py
@ -193,14 +193,14 @@ def get_flow_per_user(df):
|
||||
yield df.loc[df.user_hash == user].dropna(axis=0, how="any")
|
||||
|
||||
|
||||
def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
||||
logger.info(f"check for h5data {h5data}")
|
||||
def load_or_generate_h5data(train_data, domain_length, window_size):
|
||||
logger.info(f"check for h5data {train_data}")
|
||||
try:
|
||||
check_h5dataset(h5data)
|
||||
check_h5dataset(train_data)
|
||||
except FileNotFoundError:
|
||||
logger.info("load raw training dataset")
|
||||
domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(h5data, train_data,
|
||||
domain_length, window_size)
|
||||
domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(train_data, domain_length,
|
||||
window_size)
|
||||
logger.info("filter training dataset")
|
||||
domain, flow, name, client, server = filter_window_dataset_by_hits(domain.value, flow.value,
|
||||
name.value, hits.value,
|
||||
@ -213,14 +213,14 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
||||
"client": client.astype(np.bool),
|
||||
"server": server.astype(np.bool)
|
||||
}
|
||||
store_h5dataset(h5data, data)
|
||||
store_h5dataset(train_data, data)
|
||||
logger.info("load h5 dataset")
|
||||
data = load_h5dataset(h5data)
|
||||
data = load_h5dataset(train_data)
|
||||
return data["domain"], data["flow"], data["name"], data["client"], data["server"]
|
||||
|
||||
|
||||
def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size):
|
||||
h5data = h5data + "_raw"
|
||||
def load_or_generate_raw_h5data(train_data, domain_length, window_size):
|
||||
h5data = train_data + "_raw"
|
||||
logger.info(f"check for h5data {h5data}")
|
||||
try:
|
||||
check_h5dataset(h5data)
|
||||
|
27
main.py
27
main.py
@ -59,7 +59,7 @@ if args.gpu:
|
||||
# default parameter
|
||||
PARAMS = {
|
||||
"type": args.model_type,
|
||||
"depth": args.model_depth,
|
||||
# "depth": args.model_depth,
|
||||
"batch_size": args.batch_size,
|
||||
"window_size": args.window,
|
||||
"domain_length": args.domain_length,
|
||||
@ -84,7 +84,7 @@ def get_param_dist(dist_size="small"):
|
||||
return {
|
||||
# static params
|
||||
"type": [args.model_type],
|
||||
"depth": [args.model_depth],
|
||||
# "depth": [args.model_depth],
|
||||
"model_output": [args.model_output],
|
||||
"batch_size": [args.batch_size],
|
||||
"window_size": [args.window],
|
||||
@ -104,7 +104,7 @@ def get_param_dist(dist_size="small"):
|
||||
return {
|
||||
# static params
|
||||
"type": [args.model_type],
|
||||
"depth": [args.model_depth],
|
||||
# "depth": [args.model_depth],
|
||||
"model_output": [args.model_output],
|
||||
"batch_size": [args.batch_size],
|
||||
"window_size": [args.window],
|
||||
@ -159,9 +159,7 @@ def main_hyperband(data, domain_length, window_size, model_type, result_file, di
|
||||
param_dist = get_param_dist(dist_size)
|
||||
|
||||
logger.info("create training dataset")
|
||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data,
|
||||
data,
|
||||
domain_length,
|
||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data, domain_length,
|
||||
window)
|
||||
server_tr = np.max(server_windows_tr, axis=1)
|
||||
|
||||
@ -191,9 +189,7 @@ def train(parameters, features, labels):
|
||||
|
||||
def load_data(data, domain_length, window_size, model_type):
|
||||
# data preparation
|
||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data,
|
||||
data,
|
||||
domain_length,
|
||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data, domain_length,
|
||||
window_size)
|
||||
server_tr = np.max(server_windows_tr, axis=1)
|
||||
if model_type in ("inter", "staggered"):
|
||||
@ -219,9 +215,9 @@ def main_train(param=None):
|
||||
if not param:
|
||||
param = PARAMS
|
||||
|
||||
for i in range(20):
|
||||
for i in range(args.runs):
|
||||
model_path = os.path.join(args.model_path, f"clf_{i}.h5")
|
||||
train_log_path = os.path.join(args.model_path, "train_{i}.log.csv")
|
||||
train_log_path = os.path.join(args.model_path, f"train_{i}.log.csv")
|
||||
# define training call backs
|
||||
logger.info("define callbacks")
|
||||
callbacks = []
|
||||
@ -318,7 +314,6 @@ def main_retrain():
|
||||
exists_or_make_path(args.model_destination)
|
||||
|
||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
|
||||
args.data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
logger.info("define callbacks")
|
||||
@ -373,10 +368,7 @@ def main_retrain():
|
||||
|
||||
def main_test():
|
||||
logger.info("start test: load data")
|
||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data,
|
||||
args.data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
|
||||
domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
||||
|
||||
for model_args in get_model_args(args):
|
||||
@ -425,7 +417,6 @@ def main_visualization():
|
||||
visualize.plot_save("{}/{}_{}.pdf".format(model_path, aggregation, curve))
|
||||
|
||||
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
||||
args.data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
|
||||
@ -484,7 +475,6 @@ def main_visualization():
|
||||
|
||||
def main_visualize_all():
|
||||
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
||||
args.data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
|
||||
@ -600,7 +590,6 @@ def main_visualize_all_embds():
|
||||
|
||||
def main_beta():
|
||||
domain_val, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
||||
args.data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
||||
|
@ -17,7 +17,7 @@ def get_models_by_params(params: dict):
|
||||
# decomposing param section
|
||||
# mainly embedding model
|
||||
network_type = params.get("type")
|
||||
network_depth = params.get("depth")
|
||||
# network_depth = params.get("depth")
|
||||
embedding_size = params.get("embedding")
|
||||
filter_embedding = params.get("filter_embedding")
|
||||
kernel_embedding = params.get("kernel_embedding")
|
||||
|
11
server.py
11
server.py
@ -21,7 +21,6 @@ def train_server_only(params):
|
||||
logger.info(f"Use command line arguments: {args}")
|
||||
|
||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
|
||||
args.data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
domain_tr = domain_tr.value.reshape(-1, 40)
|
||||
@ -69,10 +68,7 @@ def train_server_only(params):
|
||||
|
||||
def test_server_only():
|
||||
logger.info("start test: load data")
|
||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data,
|
||||
args.data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
|
||||
domain_val = domain_val.value.reshape(-1, 40)
|
||||
flow_val = flow_val.value.reshape(-1, 3)
|
||||
domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
||||
@ -102,10 +98,7 @@ def vis_server():
|
||||
return emdb, clf
|
||||
|
||||
domain_raw, flow_raw, name_raw, hits_vt_raw, hits_trusted_raw, server_raw = dataset.load_or_generate_raw_h5data(
|
||||
args.data,
|
||||
args.data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
args.data, args.domain_length, args.window)
|
||||
|
||||
results = dataset.load_predictions(args.clf_model)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user