Add custom class weights based on scikit-learn's 'balanced' class-weight heuristic
This commit is contained in:
parent
b35f23e518
commit
6b787792db
|
@ -61,6 +61,7 @@ parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
|
||||||
# parser.add_argument("--tmp", action="store_true", dest="tmp")
|
# parser.add_argument("--tmp", action="store_true", dest="tmp")
|
||||||
#
|
#
|
||||||
parser.add_argument("--stop_early", action="store_true", dest="stop_early")
|
parser.add_argument("--stop_early", action="store_true", dest="stop_early")
|
||||||
|
parser.add_argument("--balanced_weights", action="store_true", dest="class_weights")
|
||||||
parser.add_argument("--gpu", action="store_true", dest="gpu")
|
parser.add_argument("--gpu", action="store_true", dest="gpu")
|
||||||
|
|
||||||
|
|
||||||
|
|
37
main.py
37
main.py
|
@ -7,6 +7,7 @@ import pandas as pd
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
|
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
|
||||||
from keras.models import load_model
|
from keras.models import load_model
|
||||||
|
from sklearn.utils import class_weight
|
||||||
|
|
||||||
import arguments
|
import arguments
|
||||||
import dataset
|
import dataset
|
||||||
|
@ -101,6 +102,17 @@ def main_hyperband():
|
||||||
json.dump(results, open("hyperband.json"))
|
json.dump(results, open("hyperband.json"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_custom_class_weights(client_tr, server_tr):
    """Compute 'balanced' class weights for the client and server labels.

    Each label set is reduced to one class index per row with ``argmax(1)``
    and handed to scikit-learn's ``compute_class_weight`` using the
    ``'balanced'`` heuristic.

    Args:
        client_tr: Client training labels. Must expose ``.value`` returning
            a 2-D array (presumably one-hot / per-class scores -- confirm
            against the data loader).
        server_tr: Server training labels, same layout as ``client_tr``.

    Returns:
        dict: ``{"client": weights, "server": weights}`` where each value is
        the array returned by ``compute_class_weight``; entry ``i`` belongs
        to the i-th class of ``np.unique(labels)``.
    """

    def _balanced_weights(labels_tr):
        # NOTE(review): `.value` looks like the h5py Dataset accessor, which
        # was removed in h5py 3.0 (use `dataset[()]`) -- confirm the type of
        # the inputs before upgrading h5py.
        labels = labels_tr.value.argmax(1)
        # `classes` and `y` are keyword-only in current scikit-learn;
        # passing them positionally raises TypeError. The keyword form is
        # accepted by older releases as well.
        return class_weight.compute_class_weight(
            'balanced', classes=np.unique(labels), y=labels)

    return {
        "client": _balanced_weights(client_tr),
        "server": _balanced_weights(server_tr),
    }
|
||||||
|
|
||||||
|
|
||||||
def main_train(param=None):
|
def main_train(param=None):
|
||||||
exists_or_make_path(args.model_path)
|
exists_or_make_path(args.model_path)
|
||||||
|
|
||||||
|
@ -151,6 +163,14 @@ def main_train(param=None):
|
||||||
model.compile(optimizer='adam',
|
model.compile(optimizer='adam',
|
||||||
loss='categorical_crossentropy',
|
loss='categorical_crossentropy',
|
||||||
metrics=['accuracy'] + custom_metrics)
|
metrics=['accuracy'] + custom_metrics)
|
||||||
|
|
||||||
|
if args.class_weights:
|
||||||
|
logger.info("class weights: compute custom weights")
|
||||||
|
custom_class_weights = get_custom_class_weights(client_tr, server_tr)
|
||||||
|
else:
|
||||||
|
logger.info("class weights: set default")
|
||||||
|
custom_class_weights = None
|
||||||
|
|
||||||
logger.info("start training")
|
logger.info("start training")
|
||||||
model.fit([domain_tr, flow_tr],
|
model.fit([domain_tr, flow_tr],
|
||||||
[client_tr, server_tr],
|
[client_tr, server_tr],
|
||||||
|
@ -158,7 +178,8 @@ def main_train(param=None):
|
||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
validation_split=0.2)
|
validation_split=0.2,
|
||||||
|
class_weight=custom_class_weights)
|
||||||
logger.info("save embedding")
|
logger.info("save embedding")
|
||||||
embedding.save(args.embedding_model)
|
embedding.save(args.embedding_model)
|
||||||
|
|
||||||
|
@ -167,11 +188,9 @@ def main_test():
|
||||||
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
||||||
args.domain_length, args.window)
|
args.domain_length, args.window)
|
||||||
clf = load_model(args.clf_model, custom_objects=models.get_metrics())
|
clf = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||||
stats = clf.evaluate([domain_val, flow_val],
|
# stats = clf.evaluate([domain_val, flow_val],
|
||||||
[client_val, server_val],
|
# [client_val, server_val],
|
||||||
batch_size=args.batch_size)
|
# batch_size=args.batch_size)
|
||||||
# logger.info(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
|
|
||||||
logger.info(stats)
|
|
||||||
y_pred = clf.predict([domain_val, flow_val],
|
y_pred = clf.predict([domain_val, flow_val],
|
||||||
batch_size=args.batch_size)
|
batch_size=args.batch_size)
|
||||||
np.save(args.future_prediction, y_pred)
|
np.save(args.future_prediction, y_pred)
|
||||||
|
@ -197,6 +216,12 @@ def main_visualization():
|
||||||
logger.info("plot roc curve")
|
logger.info("plot roc curve")
|
||||||
visualize.plot_roc_curve(client_val.value, client_pred, "{}/client_roc.png".format(args.model_path))
|
visualize.plot_roc_curve(client_val.value, client_pred, "{}/client_roc.png".format(args.model_path))
|
||||||
visualize.plot_roc_curve(server_val.value, server_pred, "{}/server_roc.png".format(args.model_path))
|
visualize.plot_roc_curve(server_val.value, server_pred, "{}/server_roc.png".format(args.model_path))
|
||||||
|
visualize.plot_confusion_matrix(client_val.value.argmax(1), client_pred.argmax(1),
|
||||||
|
"{}/client_cov.png".format(args.model_path),
|
||||||
|
normalize=False, title="Client Confusion Matrix")
|
||||||
|
visualize.plot_confusion_matrix(server_val.value.argmax(1), server_pred.argmax(1),
|
||||||
|
"{}/server_cov.png".format(args.model_path),
|
||||||
|
normalize=False, title="Server Confusion Matrix")
|
||||||
|
|
||||||
|
|
||||||
def main_score():
|
def main_score():
|
||||||
|
|
|
@ -90,10 +90,10 @@ def plot_roc_curve(mask, prediction, path):
|
||||||
print("roc_auc", roc_auc)
|
print("roc_auc", roc_auc)
|
||||||
|
|
||||||
|
|
||||||
def plot_confusion_matrix(y_true, y_pred,
|
def plot_confusion_matrix(y_true, y_pred, path,
|
||||||
normalize=False,
|
normalize=False,
|
||||||
title='Confusion matrix',
|
title='Confusion matrix',
|
||||||
cmap="Blues"):
|
cmap="Blues", dpi=600):
|
||||||
"""
|
"""
|
||||||
This function prints and plots the confusion matrix.
|
This function prints and plots the confusion matrix.
|
||||||
Normalization can be applied by setting `normalize=True`.
|
Normalization can be applied by setting `normalize=True`.
|
||||||
|
@ -125,6 +125,8 @@ def plot_confusion_matrix(y_true, y_pred,
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.ylabel('True label')
|
plt.ylabel('True label')
|
||||||
plt.xlabel('Predicted label')
|
plt.xlabel('Predicted label')
|
||||||
|
plt.savefig(path, dpi=dpi)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
def plot_training_curve(logs, key, path, dpi=600):
|
def plot_training_curve(logs, key, path, dpi=600):
|
||||||
|
|
Loading…
Reference in New Issue