add parser argument for naming in multi model modes, minor fixes, re-set fix vals for network - need to make them flexible
This commit is contained in:
parent
ed4f478bad
commit
2080444fb7
102
main.py
102
main.py
@ -66,13 +66,12 @@ PARAMS = {
|
|||||||
'dropout': 0.5,
|
'dropout': 0.5,
|
||||||
'domain_features': args.domain_embedding,
|
'domain_features': args.domain_embedding,
|
||||||
'embedding_size': args.embedding,
|
'embedding_size': args.embedding,
|
||||||
'filter_main': 64,
|
|
||||||
'flow_features': 3,
|
'flow_features': 3,
|
||||||
# 'dense_main': 512,
|
|
||||||
'dense_main': 64,
|
|
||||||
'filter_embedding': args.hidden_char_dims,
|
'filter_embedding': args.hidden_char_dims,
|
||||||
'hidden_embedding': args.domain_embedding,
|
'hidden_embedding': args.domain_embedding,
|
||||||
'kernel_embedding': 3,
|
'kernel_embedding': 3,
|
||||||
|
'filter_main': 128,
|
||||||
|
'dense_main': 128,
|
||||||
'kernels_main': 3,
|
'kernels_main': 3,
|
||||||
'input_length': 40,
|
'input_length': 40,
|
||||||
'model_output': args.model_output
|
'model_output': args.model_output
|
||||||
@ -154,34 +153,63 @@ def main_train(param=None):
|
|||||||
custom_class_weights = None
|
custom_class_weights = None
|
||||||
|
|
||||||
logger.info(f"select model: {args.model_type}")
|
logger.info(f"select model: {args.model_type}")
|
||||||
if args.model_type == "inter":
|
if args.model_type == "staggered":
|
||||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||||
model = new_model
|
model = new_model
|
||||||
logger.info("compile and train model")
|
logger.info("compile and train model")
|
||||||
embedding.summary()
|
embedding.summary()
|
||||||
model.summary()
|
model.summary()
|
||||||
logger.info(model.get_config())
|
logger.info(model.get_config())
|
||||||
model.compile(optimizer='adam',
|
|
||||||
loss='binary_crossentropy',
|
model.outputs
|
||||||
metrics=['accuracy'] + custom_metrics)
|
|
||||||
|
model.compile(optimizer='adam',
|
||||||
|
loss='binary_crossentropy',
|
||||||
|
metrics=['accuracy'] + custom_metrics)
|
||||||
|
|
||||||
|
if args.model_output == "both":
|
||||||
|
labels = [client_tr, server_tr]
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown model output")
|
||||||
|
|
||||||
|
model.fit([domain_tr, flow_tr],
|
||||||
|
labels,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
epochs=args.epochs,
|
||||||
|
callbacks=callbacks,
|
||||||
|
shuffle=True,
|
||||||
|
validation_split=0.2,
|
||||||
|
class_weight=custom_class_weights)
|
||||||
|
|
||||||
if args.model_output == "both":
|
|
||||||
labels = [client_tr, server_tr]
|
|
||||||
elif args.model_output == "client":
|
|
||||||
labels = [client_tr]
|
|
||||||
elif args.model_output == "server":
|
|
||||||
labels = [server_tr]
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("unknown model output")
|
if args.model_type == "inter":
|
||||||
|
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||||
|
model = new_model
|
||||||
|
logger.info("compile and train model")
|
||||||
|
embedding.summary()
|
||||||
|
model.summary()
|
||||||
|
logger.info(model.get_config())
|
||||||
|
model.compile(optimizer='adam',
|
||||||
|
loss='binary_crossentropy',
|
||||||
|
metrics=['accuracy'] + custom_metrics)
|
||||||
|
|
||||||
model.fit([domain_tr, flow_tr],
|
if args.model_output == "both":
|
||||||
labels,
|
labels = [client_tr, server_tr]
|
||||||
batch_size=args.batch_size,
|
elif args.model_output == "client":
|
||||||
epochs=args.epochs,
|
labels = [client_tr]
|
||||||
callbacks=callbacks,
|
elif args.model_output == "server":
|
||||||
shuffle=True,
|
labels = [server_tr]
|
||||||
validation_split=0.2,
|
else:
|
||||||
class_weight=custom_class_weights)
|
raise ValueError("unknown model output")
|
||||||
|
|
||||||
|
model.fit([domain_tr, flow_tr],
|
||||||
|
labels,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
epochs=args.epochs,
|
||||||
|
callbacks=callbacks,
|
||||||
|
shuffle=True,
|
||||||
|
validation_split=0.2,
|
||||||
|
class_weight=custom_class_weights)
|
||||||
logger.info("save embedding")
|
logger.info("save embedding")
|
||||||
embedding.save(args.embedding_model)
|
embedding.save(args.embedding_model)
|
||||||
|
|
||||||
@ -225,9 +253,9 @@ def main_visualization():
|
|||||||
# client_val, server_val = client_val.value, server_val.value
|
# client_val, server_val = client_val.value, server_val.value
|
||||||
client_val = client_val.value
|
client_val = client_val.value
|
||||||
|
|
||||||
# logger.info("plot model")
|
logger.info("plot model")
|
||||||
# model = load_model(model_args.clf_model, custom_objects=models.get_metrics())
|
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||||
# visualize.plot_model(model, os.path.join(args.model_path, "model.png"))
|
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("plot training curve")
|
logger.info("plot training curve")
|
||||||
@ -276,10 +304,10 @@ def main_visualization():
|
|||||||
# visualize.plot_confusion_matrix(server_val.argmax(1), server_pred.argmax(1),
|
# visualize.plot_confusion_matrix(server_val.argmax(1), server_pred.argmax(1),
|
||||||
# "{}/server_cov.png".format(args.model_path),
|
# "{}/server_cov.png".format(args.model_path),
|
||||||
# normalize=False, title="Server Confusion Matrix")
|
# normalize=False, title="Server Confusion Matrix")
|
||||||
# logger.info("visualize embedding")
|
logger.info("visualize embedding")
|
||||||
# domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||||
# domain_embedding = np.load(args.model_path + "/domain_embds.npy")
|
domain_embedding = np.load(args.model_path + "/domain_embds.npy")
|
||||||
# visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
|
visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
|
||||||
|
|
||||||
|
|
||||||
def main_visualize_all():
|
def main_visualize_all():
|
||||||
@ -293,7 +321,7 @@ def main_visualize_all():
|
|||||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||||
visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"])
|
visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"])
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save("all_client_prc.png")
|
visualize.plot_save(f"{args.output_prefix}_client_prc.png")
|
||||||
|
|
||||||
logger.info("plot roc curves")
|
logger.info("plot roc curves")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
@ -301,7 +329,7 @@ def main_visualize_all():
|
|||||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||||
visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"])
|
visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"])
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save("all_client_roc.png")
|
visualize.plot_save(f"{args.output_prefix}_client_roc.png")
|
||||||
|
|
||||||
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
|
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
|
||||||
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
|
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
|
||||||
@ -314,7 +342,7 @@ def main_visualize_all():
|
|||||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
||||||
visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"])
|
visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"])
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save("all_user_client_prc.png")
|
visualize.plot_save(f"{args.output_prefix}_user_client_prc.png")
|
||||||
|
|
||||||
logger.info("plot user roc curves")
|
logger.info("plot user roc curves")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
@ -324,7 +352,7 @@ def main_visualize_all():
|
|||||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
||||||
visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"])
|
visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"])
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save("all_user_client_roc.png")
|
visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")
|
||||||
|
|
||||||
|
|
||||||
def main_data():
|
def main_data():
|
||||||
|
12
run.sh
12
run.sh
@ -1,6 +1,10 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
|
||||||
|
RESDIR=$1
|
||||||
|
mkdir -p /tmp/rk/RESDIR
|
||||||
|
DATADIR=$2
|
||||||
|
|
||||||
for output in client both
|
for output in client both
|
||||||
do
|
do
|
||||||
for depth in small medium
|
for depth in small medium
|
||||||
@ -9,8 +13,8 @@ do
|
|||||||
do
|
do
|
||||||
|
|
||||||
python main.py --mode train \
|
python main.py --mode train \
|
||||||
--train /tmp/rk/currentData.csv \
|
--train ${DATADIR}/currentData.csv \
|
||||||
--model /tmp/rk/results/${output}_${depth}_${mtype} \
|
--model ${RESDIR}/${output}_${depth}_${mtype} \
|
||||||
--epochs 50 \
|
--epochs 50 \
|
||||||
--embd 64 \
|
--embd 64 \
|
||||||
--hidden_char_dims 128 \
|
--hidden_char_dims 128 \
|
||||||
@ -28,8 +32,8 @@ done
|
|||||||
for depth in small medium
|
for depth in small medium
|
||||||
do
|
do
|
||||||
python main.py --mode train \
|
python main.py --mode train \
|
||||||
--train /tmp/rk/currentData.csv \
|
--train ${DATADIR}/currentData.csv \
|
||||||
--model /tmp/rk/results/both_${depth}_inter \
|
--model ${RESDIR}/both_${depth}_inter \
|
||||||
--epochs 50 \
|
--epochs 50 \
|
||||||
--embd 64 \
|
--embd 64 \
|
||||||
--hidden_char_dims 128 \
|
--hidden_char_dims 128 \
|
||||||
|
6
test.sh
6
test.sh
@ -1,10 +1,12 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
RESDIR=$1
|
||||||
|
DATADIR=$2
|
||||||
|
|
||||||
for output in client both
|
for output in client both
|
||||||
do
|
do
|
||||||
python3 main.py --mode test --batch 1024 \
|
python3 main.py --mode test --batch 1024 \
|
||||||
--models tm/rk/${output}_* \
|
--models ${RESDIR}/${output}_*/ \
|
||||||
--test data/futureData.csv \
|
--test ${DATADIR}/futureData.csv \
|
||||||
--model_output ${output}
|
--model_output ${output}
|
||||||
done
|
done
|
||||||
|
Loading…
Reference in New Issue
Block a user