add custom class weights based on sklearn balance

René Knaebel 2017-07-14 15:57:52 +02:00
parent b35f23e518
commit 6b787792db
3 changed files with 36 additions and 8 deletions


@@ -61,6 +61,7 @@ parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
 # parser.add_argument("--tmp", action="store_true", dest="tmp")
 #
 parser.add_argument("--stop_early", action="store_true", dest="stop_early")
+parser.add_argument("--balanced_weights", action="store_true", dest="class_weights")
 parser.add_argument("--gpu", action="store_true", dest="gpu")

main.py (37 changed lines)

@@ -7,6 +7,7 @@ import pandas as pd
 import tensorflow as tf
 from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
 from keras.models import load_model
+from sklearn.utils import class_weight
 import arguments
 import dataset
@@ -101,6 +102,17 @@ def main_hyperband():
     json.dump(results, open("hyperband.json"))
 
 
+def get_custom_class_weights(client_tr, server_tr):
+    client = client_tr.value.argmax(1)
+    server = server_tr.value.argmax(1)
+    client_class_weight = class_weight.compute_class_weight('balanced', np.unique(client), client)
+    server_class_weight = class_weight.compute_class_weight('balanced', np.unique(server), server)
+    return {
+        "client": client_class_weight,
+        "server": server_class_weight
+    }
+
+
 def main_train(param=None):
     exists_or_make_path(args.model_path)
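
Note: scikit-learn's 'balanced' mode weights each class by n_samples / (n_classes * bincount(y)), so under-represented classes receive proportionally larger weights. A small self-contained sketch with made-up labels; it uses the keyword form, which works both with old releases and with recent scikit-learn where classes and y are keyword-only (the positional call in the hunk above matches the versions of that era):

import numpy as np
from sklearn.utils import class_weight

# toy labels: class 0 occurs four times, class 1 once
y = np.array([0, 0, 0, 0, 1])

weights = class_weight.compute_class_weight('balanced', classes=np.unique(y), y=y)

# 'balanced' weights each class by n_samples / (n_classes * np.bincount(y))
# -> [5 / (2 * 4), 5 / (2 * 1)] = [0.625, 2.5]
print(weights)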
@@ -151,6 +163,14 @@ def main_train(param=None):
     model.compile(optimizer='adam',
                   loss='categorical_crossentropy',
                   metrics=['accuracy'] + custom_metrics)
+
+    if args.class_weights:
+        logger.info("class weights: compute custom weights")
+        custom_class_weights = get_custom_class_weights(client_tr, server_tr)
+    else:
+        logger.info("class weights: set default")
+        custom_class_weights = None
+
     logger.info("start training")
     model.fit([domain_tr, flow_tr],
               [client_tr, server_tr],
@@ -158,7 +178,8 @@ def main_train(param=None):
               epochs=args.epochs,
               callbacks=callbacks,
               shuffle=True,
-              validation_split=0.2)
+              validation_split=0.2,
+              class_weight=custom_class_weights)
 
     logger.info("save embedding")
     embedding.save(args.embedding_model)
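
Note: Keras documents class_weight as a mapping from class indices to weights, and Keras versions of this era also accept a dict keyed by output name for multi-output models, which is the shape passed here. compute_class_weight returns plain NumPy arrays rather than such mappings, so a conservative variant converts them first. A sketch of that conversion (the keys "client" and "server" mirror the dict built in get_custom_class_weights above; the helper and the example arrays are illustrative, not part of the commit):

import numpy as np

def to_index_weight_dict(weights):
    # class index -> weight, the mapping format described in the Keras docs
    return {index: float(weight) for index, weight in enumerate(weights)}

# stand-ins for the arrays returned by get_custom_class_weights(...)
client_class_weight = np.array([0.62, 2.54])
server_class_weight = np.array([0.91, 1.11])

custom_class_weights = {
    "client": to_index_weight_dict(client_class_weight),
    "server": to_index_weight_dict(server_class_weight),
}
print(custom_class_weights)
# handed to model.fit(..., class_weight=custom_class_weights) as in the hunk above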
@@ -167,11 +188,9 @@ def main_test():
     domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
                                                                            args.domain_length, args.window)
     clf = load_model(args.clf_model, custom_objects=models.get_metrics())
-    stats = clf.evaluate([domain_val, flow_val],
-                         [client_val, server_val],
-                         batch_size=args.batch_size)
-    # logger.info(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
-    logger.info(stats)
+    # stats = clf.evaluate([domain_val, flow_val],
+    #                      [client_val, server_val],
+    #                      batch_size=args.batch_size)
     y_pred = clf.predict([domain_val, flow_val],
                          batch_size=args.batch_size)
     np.save(args.future_prediction, y_pred)
@@ -197,6 +216,12 @@ def main_visualization():
     logger.info("plot roc curve")
     visualize.plot_roc_curve(client_val.value, client_pred, "{}/client_roc.png".format(args.model_path))
     visualize.plot_roc_curve(server_val.value, server_pred, "{}/server_roc.png".format(args.model_path))
+    visualize.plot_confusion_matrix(client_val.value.argmax(1), client_pred.argmax(1),
+                                    "{}/client_cov.png".format(args.model_path),
+                                    normalize=False, title="Client Confusion Matrix")
+    visualize.plot_confusion_matrix(server_val.value.argmax(1), server_pred.argmax(1),
+                                    "{}/server_cov.png".format(args.model_path),
+                                    normalize=False, title="Server Confusion Matrix")
 
 
 def main_score():
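
Note: the plotting helper is fed hard labels obtained by argmax over the one-hot targets and the softmax outputs. A minimal sketch of that label handling with scikit-learn's confusion_matrix (the arrays are invented: two classes, four samples):

import numpy as np
from sklearn.metrics import confusion_matrix

# one-hot ground truth and softmax-like predictions
client_val = np.array([[1, 0], [0, 1], [1, 0], [0, 1]])
client_pred = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.3, 0.7]])

y_true = client_val.argmax(1)   # [0, 1, 0, 1]
y_pred = client_pred.argmax(1)  # [0, 1, 1, 1]

# rows are true classes, columns are predicted classes
print(confusion_matrix(y_true, y_pred))
# [[1 1]
#  [0 2]]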


@@ -90,10 +90,10 @@ def plot_roc_curve(mask, prediction, path):
     print("roc_auc", roc_auc)
 
 
-def plot_confusion_matrix(y_true, y_pred,
+def plot_confusion_matrix(y_true, y_pred, path,
                           normalize=False,
                           title='Confusion matrix',
-                          cmap="Blues"):
+                          cmap="Blues", dpi=600):
     """
     This function prints and plots the confusion matrix.
     Normalization can be applied by setting `normalize=True`.
@@ -125,6 +125,8 @@ def plot_confusion_matrix(y_true, y_pred,
     plt.tight_layout()
     plt.ylabel('True label')
     plt.xlabel('Predicted label')
+    plt.savefig(path, dpi=dpi)
+    plt.close()
 
 
 def plot_training_curve(logs, key, path, dpi=600):
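
Note: only the changed signature and the new savefig/close calls are visible in this diff. A hedged sketch of a compatible implementation, modeled on the widely used scikit-learn confusion-matrix plotting recipe; the body between imshow and the axis labels is an assumption, not the repository's actual code:

import itertools
import matplotlib
matplotlib.use("Agg")  # render off-screen so savefig works without a display
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(y_true, y_pred, path,
                          normalize=False,
                          title='Confusion matrix',
                          cmap="Blues", dpi=600):
    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig(path, dpi=dpi)  # new in this commit: write the figure to disk
    plt.close()                 # start the next call from a clean pyplot figure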