diff --git a/main.py b/main.py index 76da15f..42cbe5c 100644 --- a/main.py +++ b/main.py @@ -66,13 +66,12 @@ PARAMS = { 'dropout': 0.5, 'domain_features': args.domain_embedding, 'embedding_size': args.embedding, - 'filter_main': 64, 'flow_features': 3, - # 'dense_main': 512, - 'dense_main': 64, 'filter_embedding': args.hidden_char_dims, 'hidden_embedding': args.domain_embedding, 'kernel_embedding': 3, + 'filter_main': 128, + 'dense_main': 128, 'kernels_main': 3, 'input_length': 40, 'model_output': args.model_output @@ -154,34 +153,63 @@ def main_train(param=None): custom_class_weights = None logger.info(f"select model: {args.model_type}") - if args.model_type == "inter": + if args.model_type == "staggered": server_tr = np.expand_dims(server_windows_tr, 2) model = new_model - logger.info("compile and train model") - embedding.summary() - model.summary() - logger.info(model.get_config()) - model.compile(optimizer='adam', - loss='binary_crossentropy', - metrics=['accuracy'] + custom_metrics) + logger.info("compile and train model") + embedding.summary() + model.summary() + logger.info(model.get_config()) + + model.outputs + + model.compile(optimizer='adam', + loss='binary_crossentropy', + metrics=['accuracy'] + custom_metrics) + + if args.model_output == "both": + labels = [client_tr, server_tr] + else: + raise ValueError("unknown model output") + + model.fit([domain_tr, flow_tr], + labels, + batch_size=args.batch_size, + epochs=args.epochs, + callbacks=callbacks, + shuffle=True, + validation_split=0.2, + class_weight=custom_class_weights) - if args.model_output == "both": - labels = [client_tr, server_tr] - elif args.model_output == "client": - labels = [client_tr] - elif args.model_output == "server": - labels = [server_tr] else: - raise ValueError("unknown model output") + if args.model_type == "inter": + server_tr = np.expand_dims(server_windows_tr, 2) + model = new_model + logger.info("compile and train model") + embedding.summary() + model.summary() + 
logger.info(model.get_config()) + model.compile(optimizer='adam', + loss='binary_crossentropy', + metrics=['accuracy'] + custom_metrics) - model.fit([domain_tr, flow_tr], - labels, - batch_size=args.batch_size, - epochs=args.epochs, - callbacks=callbacks, - shuffle=True, - validation_split=0.2, - class_weight=custom_class_weights) + if args.model_output == "both": + labels = [client_tr, server_tr] + elif args.model_output == "client": + labels = [client_tr] + elif args.model_output == "server": + labels = [server_tr] + else: + raise ValueError("unknown model output") + + model.fit([domain_tr, flow_tr], + labels, + batch_size=args.batch_size, + epochs=args.epochs, + callbacks=callbacks, + shuffle=True, + validation_split=0.2, + class_weight=custom_class_weights) logger.info("save embedding") embedding.save(args.embedding_model) @@ -225,9 +253,9 @@ def main_visualization(): # client_val, server_val = client_val.value, server_val.value client_val = client_val.value - # logger.info("plot model") - # model = load_model(model_args.clf_model, custom_objects=models.get_metrics()) - # visualize.plot_model(model, os.path.join(args.model_path, "model.png")) + logger.info("plot model") + model = load_model(args.clf_model, custom_objects=models.get_metrics()) + visualize.plot_model_as(model, os.path.join(args.model_path, "model.png")) try: logger.info("plot training curve") @@ -276,10 +304,10 @@ def main_visualization(): # visualize.plot_confusion_matrix(server_val.argmax(1), server_pred.argmax(1), # "{}/server_cov.png".format(args.model_path), # normalize=False, title="Server Confusion Matrix") - # logger.info("visualize embedding") - # domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length) - # domain_embedding = np.load(args.model_path + "/domain_embds.npy") - # visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path)) + logger.info("visualize embedding") + domain_encs, labels = 
dataset.load_or_generate_domains(args.test_data, args.domain_length) + domain_embedding = np.load(args.model_path + "/domain_embds.npy") + visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path)) def main_visualize_all(): @@ -293,7 +321,7 @@ def main_visualize_all(): client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"]) visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"]) visualize.plot_legend() - visualize.plot_save("all_client_prc.png") + visualize.plot_save(f"{args.output_prefix}_client_prc.png") logger.info("plot roc curves") visualize.plot_clf() @@ -301,7 +329,7 @@ def main_visualize_all(): client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"]) visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"]) visualize.plot_legend() - visualize.plot_save("all_client_roc.png") + visualize.plot_save(f"{args.output_prefix}_client_roc.png") df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val}) user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float) @@ -314,7 +342,7 @@ def main_visualize_all(): user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float) visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"]) visualize.plot_legend() - visualize.plot_save("all_user_client_prc.png") + visualize.plot_save(f"{args.output_prefix}_user_client_prc.png") logger.info("plot user roc curves") visualize.plot_clf() @@ -324,7 +352,7 @@ def main_visualize_all(): user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float) visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"]) visualize.plot_legend() - visualize.plot_save("all_user_client_roc.png") + visualize.plot_save(f"{args.output_prefix}_user_client_roc.png") def main_data(): diff --git a/run.sh b/run.sh index 
b2b50b7..94cdf88 100644 --- a/run.sh +++ b/run.sh @@ -1,6 +1,10 @@ #!/usr/bin/env bash +RESDIR=$1 +mkdir -p ${RESDIR} +DATADIR=$2 + for output in client both do for depth in small medium @@ -9,8 +13,8 @@ do do python main.py --mode train \ - --train /tmp/rk/currentData.csv \ - --model /tmp/rk/results/${output}_${depth}_${mtype} \ + --train ${DATADIR}/currentData.csv \ + --model ${RESDIR}/${output}_${depth}_${mtype} \ --epochs 50 \ --embd 64 \ --hidden_char_dims 128 \ @@ -28,8 +32,8 @@ done for depth in small medium do python main.py --mode train \ - --train /tmp/rk/currentData.csv \ - --model /tmp/rk/results/both_${depth}_inter \ + --train ${DATADIR}/currentData.csv \ + --model ${RESDIR}/both_${depth}_inter \ --epochs 50 \ --embd 64 \ --hidden_char_dims 128 \ diff --git a/test.sh b/test.sh index 378782e..97ae031 100644 --- a/test.sh +++ b/test.sh @@ -1,10 +1,12 @@ #!/usr/bin/env bash +RESDIR=$1 +DATADIR=$2 for output in client both do python3 main.py --mode test --batch 1024 \ - --models tm/rk/${output}_* \ - --test data/futureData.csv \ + --models ${RESDIR}/${output}_*/ \ + --test ${DATADIR}/futureData.csv \ --model_output ${output} done