def get_custom_class_weights(client_tr, server_tr):
    """Compute 'balanced' class weights for the client and server targets.

    Each argument must expose a ``.value`` attribute whose ``argmax(1)``
    yields one class index per sample (presumably one-hot label matrices
    loaded from the h5 data file -- TODO confirm against the loader).

    Returns:
        A dict with the keys ``"client"`` and ``"server"``. Each value is the
        array returned by ``class_weight.compute_class_weight('balanced', ...)``
        for the corresponding labels: one weight per distinct class index
        that occurs in the labels, in the order produced by ``np.unique``.
    """
    def _balanced_weights(target):
        # Collapse the per-sample score/indicator rows to class indices.
        labels = target.value.argmax(1)
        # `classes` and `y` are passed by keyword on purpose: current
        # scikit-learn releases declare them keyword-only and reject the
        # positional form, while older releases accept the same names.
        return class_weight.compute_class_weight('balanced',
                                                 classes=np.unique(labels),
                                                 y=labels)

    # NOTE(review): the result is handed to `model.fit(class_weight=...)`.
    # Keras documents class weights as a mapping from class index to weight;
    # these values are plain arrays -- verify that the Keras version in use
    # really applies them, and that "client"/"server" match the output names.
    return {
        "client": _balanced_weights(client_tr),
        "server": _balanced_weights(server_tr)
    }
validation_split=0.2, + class_weight=custom_class_weights) logger.info("save embedding") embedding.save(args.embedding_model) @@ -167,11 +188,9 @@ def main_test(): domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data, args.domain_length, args.window) clf = load_model(args.clf_model, custom_objects=models.get_metrics()) - stats = clf.evaluate([domain_val, flow_val], - [client_val, server_val], - batch_size=args.batch_size) - # logger.info(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}") - logger.info(stats) + # stats = clf.evaluate([domain_val, flow_val], + # [client_val, server_val], + # batch_size=args.batch_size) y_pred = clf.predict([domain_val, flow_val], batch_size=args.batch_size) np.save(args.future_prediction, y_pred) @@ -197,6 +216,12 @@ def main_visualization(): logger.info("plot roc curve") visualize.plot_roc_curve(client_val.value, client_pred, "{}/client_roc.png".format(args.model_path)) visualize.plot_roc_curve(server_val.value, server_pred, "{}/server_roc.png".format(args.model_path)) + visualize.plot_confusion_matrix(client_val.value.argmax(1), client_pred.argmax(1), + "{}/client_cov.png".format(args.model_path), + normalize=False, title="Client Confusion Matrix") + visualize.plot_confusion_matrix(server_val.value.argmax(1), server_pred.argmax(1), + "{}/server_cov.png".format(args.model_path), + normalize=False, title="Server Confusion Matrix") def main_score(): diff --git a/visualize.py b/visualize.py index 229e35e..fbe69e2 100644 --- a/visualize.py +++ b/visualize.py @@ -90,10 +90,10 @@ def plot_roc_curve(mask, prediction, path): print("roc_auc", roc_auc) -def plot_confusion_matrix(y_true, y_pred, +def plot_confusion_matrix(y_true, y_pred, path, normalize=False, title='Confusion matrix', - cmap="Blues"): + cmap="Blues", dpi=600): """ This function prints and plots the confusion matrix. Normalization can be applied by setting `normalize=True`. 
@@ -125,6 +125,8 @@ def plot_confusion_matrix(y_true, y_pred, plt.tight_layout() plt.ylabel('True label') plt.xlabel('Predicted label') + plt.savefig(path, dpi=dpi) + plt.close() def plot_training_curve(logs, key, path, dpi=600):