remove identical parameter from data loading function; add runs argument
This commit is contained in:
parent
826357a41f
commit
903e81c931
12
Makefile
12
Makefile
@ -1,27 +1,27 @@
|
|||||||
run:
|
run:
|
||||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client --epochs 2 --depth flat1 \
|
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client --epochs 2 --depth flat1 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client
|
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client --runs 1
|
||||||
|
|
||||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_final --epochs 2 --depth flat1 \
|
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_final --epochs 2 --depth flat1 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output both
|
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output both --runs 1
|
||||||
|
|
||||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_inter --epochs 2 --depth flat1 \
|
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_inter --epochs 2 --depth flat1 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output both
|
--dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output both --runs 1
|
||||||
|
|
||||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_soft --epochs 2 --depth flat1 \
|
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_soft --epochs 2 --depth flat1 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --type soft --model_output both
|
--dense_embd 16 --domain_embd 8 --batch 64 --type soft --model_output both --runs 1
|
||||||
|
|
||||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_long --epochs 2 --depth flat1 \
|
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_long --epochs 2 --depth flat1 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --type long --model_output both
|
--dense_embd 16 --domain_embd 8 --batch 64 --type long --model_output both --runs 1
|
||||||
|
|
||||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_staggered --epochs 2 --depth flat1 \
|
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_staggered --epochs 2 --depth flat1 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --type staggered --model_output both
|
--dense_embd 16 --domain_embd 8 --batch 64 --type staggered --model_output both --runs 1
|
||||||
|
|
||||||
|
|
||||||
test:
|
test:
|
||||||
|
@ -49,6 +49,10 @@ parser.add_argument("--epochs", action="store", dest="epochs",
|
|||||||
parser.add_argument("--init_epoch", action="store", dest="initial_epoch",
|
parser.add_argument("--init_epoch", action="store", dest="initial_epoch",
|
||||||
default=0, type=int)
|
default=0, type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--runs", action="store", dest="runs",
|
||||||
|
default=20, type=int)
|
||||||
|
|
||||||
|
|
||||||
# parser.add_argument("--samples", action="store", dest="samples",
|
# parser.add_argument("--samples", action="store", dest="samples",
|
||||||
# default=100000, type=int)
|
# default=100000, type=int)
|
||||||
#
|
#
|
||||||
|
18
dataset.py
18
dataset.py
@ -193,14 +193,14 @@ def get_flow_per_user(df):
|
|||||||
yield df.loc[df.user_hash == user].dropna(axis=0, how="any")
|
yield df.loc[df.user_hash == user].dropna(axis=0, how="any")
|
||||||
|
|
||||||
|
|
||||||
def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
def load_or_generate_h5data(train_data, domain_length, window_size):
|
||||||
logger.info(f"check for h5data {h5data}")
|
logger.info(f"check for h5data {train_data}")
|
||||||
try:
|
try:
|
||||||
check_h5dataset(h5data)
|
check_h5dataset(train_data)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.info("load raw training dataset")
|
logger.info("load raw training dataset")
|
||||||
domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(h5data, train_data,
|
domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(train_data, domain_length,
|
||||||
domain_length, window_size)
|
window_size)
|
||||||
logger.info("filter training dataset")
|
logger.info("filter training dataset")
|
||||||
domain, flow, name, client, server = filter_window_dataset_by_hits(domain.value, flow.value,
|
domain, flow, name, client, server = filter_window_dataset_by_hits(domain.value, flow.value,
|
||||||
name.value, hits.value,
|
name.value, hits.value,
|
||||||
@ -213,14 +213,14 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
|||||||
"client": client.astype(np.bool),
|
"client": client.astype(np.bool),
|
||||||
"server": server.astype(np.bool)
|
"server": server.astype(np.bool)
|
||||||
}
|
}
|
||||||
store_h5dataset(h5data, data)
|
store_h5dataset(train_data, data)
|
||||||
logger.info("load h5 dataset")
|
logger.info("load h5 dataset")
|
||||||
data = load_h5dataset(h5data)
|
data = load_h5dataset(train_data)
|
||||||
return data["domain"], data["flow"], data["name"], data["client"], data["server"]
|
return data["domain"], data["flow"], data["name"], data["client"], data["server"]
|
||||||
|
|
||||||
|
|
||||||
def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size):
|
def load_or_generate_raw_h5data(train_data, domain_length, window_size):
|
||||||
h5data = h5data + "_raw"
|
h5data = train_data + "_raw"
|
||||||
logger.info(f"check for h5data {h5data}")
|
logger.info(f"check for h5data {h5data}")
|
||||||
try:
|
try:
|
||||||
check_h5dataset(h5data)
|
check_h5dataset(h5data)
|
||||||
|
27
main.py
27
main.py
@ -59,7 +59,7 @@ if args.gpu:
|
|||||||
# default parameter
|
# default parameter
|
||||||
PARAMS = {
|
PARAMS = {
|
||||||
"type": args.model_type,
|
"type": args.model_type,
|
||||||
"depth": args.model_depth,
|
# "depth": args.model_depth,
|
||||||
"batch_size": args.batch_size,
|
"batch_size": args.batch_size,
|
||||||
"window_size": args.window,
|
"window_size": args.window,
|
||||||
"domain_length": args.domain_length,
|
"domain_length": args.domain_length,
|
||||||
@ -84,7 +84,7 @@ def get_param_dist(dist_size="small"):
|
|||||||
return {
|
return {
|
||||||
# static params
|
# static params
|
||||||
"type": [args.model_type],
|
"type": [args.model_type],
|
||||||
"depth": [args.model_depth],
|
# "depth": [args.model_depth],
|
||||||
"model_output": [args.model_output],
|
"model_output": [args.model_output],
|
||||||
"batch_size": [args.batch_size],
|
"batch_size": [args.batch_size],
|
||||||
"window_size": [args.window],
|
"window_size": [args.window],
|
||||||
@ -104,7 +104,7 @@ def get_param_dist(dist_size="small"):
|
|||||||
return {
|
return {
|
||||||
# static params
|
# static params
|
||||||
"type": [args.model_type],
|
"type": [args.model_type],
|
||||||
"depth": [args.model_depth],
|
# "depth": [args.model_depth],
|
||||||
"model_output": [args.model_output],
|
"model_output": [args.model_output],
|
||||||
"batch_size": [args.batch_size],
|
"batch_size": [args.batch_size],
|
||||||
"window_size": [args.window],
|
"window_size": [args.window],
|
||||||
@ -159,9 +159,7 @@ def main_hyperband(data, domain_length, window_size, model_type, result_file, di
|
|||||||
param_dist = get_param_dist(dist_size)
|
param_dist = get_param_dist(dist_size)
|
||||||
|
|
||||||
logger.info("create training dataset")
|
logger.info("create training dataset")
|
||||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data,
|
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data, domain_length,
|
||||||
data,
|
|
||||||
domain_length,
|
|
||||||
window)
|
window)
|
||||||
server_tr = np.max(server_windows_tr, axis=1)
|
server_tr = np.max(server_windows_tr, axis=1)
|
||||||
|
|
||||||
@ -191,9 +189,7 @@ def train(parameters, features, labels):
|
|||||||
|
|
||||||
def load_data(data, domain_length, window_size, model_type):
|
def load_data(data, domain_length, window_size, model_type):
|
||||||
# data preparation
|
# data preparation
|
||||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data,
|
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data, domain_length,
|
||||||
data,
|
|
||||||
domain_length,
|
|
||||||
window_size)
|
window_size)
|
||||||
server_tr = np.max(server_windows_tr, axis=1)
|
server_tr = np.max(server_windows_tr, axis=1)
|
||||||
if model_type in ("inter", "staggered"):
|
if model_type in ("inter", "staggered"):
|
||||||
@ -219,9 +215,9 @@ def main_train(param=None):
|
|||||||
if not param:
|
if not param:
|
||||||
param = PARAMS
|
param = PARAMS
|
||||||
|
|
||||||
for i in range(20):
|
for i in range(args.runs):
|
||||||
model_path = os.path.join(args.model_path, f"clf_{i}.h5")
|
model_path = os.path.join(args.model_path, f"clf_{i}.h5")
|
||||||
train_log_path = os.path.join(args.model_path, "train_{i}.log.csv")
|
train_log_path = os.path.join(args.model_path, f"train_{i}.log.csv")
|
||||||
# define training call backs
|
# define training call backs
|
||||||
logger.info("define callbacks")
|
logger.info("define callbacks")
|
||||||
callbacks = []
|
callbacks = []
|
||||||
@ -318,7 +314,6 @@ def main_retrain():
|
|||||||
exists_or_make_path(args.model_destination)
|
exists_or_make_path(args.model_destination)
|
||||||
|
|
||||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
|
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
|
||||||
args.data,
|
|
||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
logger.info("define callbacks")
|
logger.info("define callbacks")
|
||||||
@ -373,10 +368,7 @@ def main_retrain():
|
|||||||
|
|
||||||
def main_test():
|
def main_test():
|
||||||
logger.info("start test: load data")
|
logger.info("start test: load data")
|
||||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data,
|
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
|
||||||
args.data,
|
|
||||||
args.domain_length,
|
|
||||||
args.window)
|
|
||||||
domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
||||||
|
|
||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
@ -425,7 +417,6 @@ def main_visualization():
|
|||||||
visualize.plot_save("{}/{}_{}.pdf".format(model_path, aggregation, curve))
|
visualize.plot_save("{}/{}_{}.pdf".format(model_path, aggregation, curve))
|
||||||
|
|
||||||
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
||||||
args.data,
|
|
||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
|
|
||||||
@ -484,7 +475,6 @@ def main_visualization():
|
|||||||
|
|
||||||
def main_visualize_all():
|
def main_visualize_all():
|
||||||
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
||||||
args.data,
|
|
||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
|
|
||||||
@ -600,7 +590,6 @@ def main_visualize_all_embds():
|
|||||||
|
|
||||||
def main_beta():
|
def main_beta():
|
||||||
domain_val, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
domain_val, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
||||||
args.data,
|
|
||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
||||||
|
@ -17,7 +17,7 @@ def get_models_by_params(params: dict):
|
|||||||
# decomposing param section
|
# decomposing param section
|
||||||
# mainly embedding model
|
# mainly embedding model
|
||||||
network_type = params.get("type")
|
network_type = params.get("type")
|
||||||
network_depth = params.get("depth")
|
# network_depth = params.get("depth")
|
||||||
embedding_size = params.get("embedding")
|
embedding_size = params.get("embedding")
|
||||||
filter_embedding = params.get("filter_embedding")
|
filter_embedding = params.get("filter_embedding")
|
||||||
kernel_embedding = params.get("kernel_embedding")
|
kernel_embedding = params.get("kernel_embedding")
|
||||||
|
11
server.py
11
server.py
@ -21,7 +21,6 @@ def train_server_only(params):
|
|||||||
logger.info(f"Use command line arguments: {args}")
|
logger.info(f"Use command line arguments: {args}")
|
||||||
|
|
||||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
|
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
|
||||||
args.data,
|
|
||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
domain_tr = domain_tr.value.reshape(-1, 40)
|
domain_tr = domain_tr.value.reshape(-1, 40)
|
||||||
@ -69,10 +68,7 @@ def train_server_only(params):
|
|||||||
|
|
||||||
def test_server_only():
|
def test_server_only():
|
||||||
logger.info("start test: load data")
|
logger.info("start test: load data")
|
||||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data,
|
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
|
||||||
args.data,
|
|
||||||
args.domain_length,
|
|
||||||
args.window)
|
|
||||||
domain_val = domain_val.value.reshape(-1, 40)
|
domain_val = domain_val.value.reshape(-1, 40)
|
||||||
flow_val = flow_val.value.reshape(-1, 3)
|
flow_val = flow_val.value.reshape(-1, 3)
|
||||||
domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
||||||
@ -102,10 +98,7 @@ def vis_server():
|
|||||||
return emdb, clf
|
return emdb, clf
|
||||||
|
|
||||||
domain_raw, flow_raw, name_raw, hits_vt_raw, hits_trusted_raw, server_raw = dataset.load_or_generate_raw_h5data(
|
domain_raw, flow_raw, name_raw, hits_vt_raw, hits_trusted_raw, server_raw = dataset.load_or_generate_raw_h5data(
|
||||||
args.data,
|
args.data, args.domain_length, args.window)
|
||||||
args.data,
|
|
||||||
args.domain_length,
|
|
||||||
args.window)
|
|
||||||
|
|
||||||
results = dataset.load_predictions(args.clf_model)
|
results = dataset.load_predictions(args.clf_model)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user