add custom class weights based on sklearn balance
This commit is contained in:
		@@ -61,6 +61,7 @@ parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
 | 
			
		||||
# parser.add_argument("--tmp", action="store_true", dest="tmp")
 | 
			
		||||
#
 | 
			
		||||
parser.add_argument("--stop_early", action="store_true", dest="stop_early")
 | 
			
		||||
parser.add_argument("--balanced_weights", action="store_true", dest="class_weights")
 | 
			
		||||
parser.add_argument("--gpu", action="store_true", dest="gpu")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										37
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										37
									
								
								main.py
									
									
									
									
									
								
							@@ -7,6 +7,7 @@ import pandas as pd
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
 | 
			
		||||
from keras.models import load_model
 | 
			
		||||
from sklearn.utils import class_weight
 | 
			
		||||
 | 
			
		||||
import arguments
 | 
			
		||||
import dataset
 | 
			
		||||
@@ -101,6 +102,17 @@ def main_hyperband():
 | 
			
		||||
    json.dump(results, open("hyperband.json"))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_custom_class_weights(client_tr, server_tr):
 | 
			
		||||
    client = client_tr.value.argmax(1)
 | 
			
		||||
    server = server_tr.value.argmax(1)
 | 
			
		||||
    client_class_weight = class_weight.compute_class_weight('balanced', np.unique(client), client)
 | 
			
		||||
    server_class_weight = class_weight.compute_class_weight('balanced', np.unique(server), server)
 | 
			
		||||
    return {
 | 
			
		||||
        "client": client_class_weight,
 | 
			
		||||
        "server": server_class_weight
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main_train(param=None):
 | 
			
		||||
    exists_or_make_path(args.model_path)
 | 
			
		||||
 | 
			
		||||
@@ -151,6 +163,14 @@ def main_train(param=None):
 | 
			
		||||
    model.compile(optimizer='adam',
 | 
			
		||||
                  loss='categorical_crossentropy',
 | 
			
		||||
                  metrics=['accuracy'] + custom_metrics)
 | 
			
		||||
 | 
			
		||||
    if args.class_weights:
 | 
			
		||||
        logger.info("class weights: compute custom weights")
 | 
			
		||||
        custom_class_weights = get_custom_class_weights(client_tr, server_tr)
 | 
			
		||||
    else:
 | 
			
		||||
        logger.info("class weights: set default")
 | 
			
		||||
        custom_class_weights = None
 | 
			
		||||
 | 
			
		||||
    logger.info("start training")
 | 
			
		||||
    model.fit([domain_tr, flow_tr],
 | 
			
		||||
              [client_tr, server_tr],
 | 
			
		||||
@@ -158,7 +178,8 @@ def main_train(param=None):
 | 
			
		||||
              epochs=args.epochs,
 | 
			
		||||
              callbacks=callbacks,
 | 
			
		||||
              shuffle=True,
 | 
			
		||||
              validation_split=0.2)
 | 
			
		||||
              validation_split=0.2,
 | 
			
		||||
              class_weight=custom_class_weights)
 | 
			
		||||
    logger.info("save embedding")
 | 
			
		||||
    embedding.save(args.embedding_model)
 | 
			
		||||
 | 
			
		||||
@@ -167,11 +188,9 @@ def main_test():
 | 
			
		||||
    domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
 | 
			
		||||
                                                                           args.domain_length, args.window)
 | 
			
		||||
    clf = load_model(args.clf_model, custom_objects=models.get_metrics())
 | 
			
		||||
    stats = clf.evaluate([domain_val, flow_val],
 | 
			
		||||
                         [client_val, server_val],
 | 
			
		||||
                         batch_size=args.batch_size)
 | 
			
		||||
    # logger.info(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
 | 
			
		||||
    logger.info(stats)
 | 
			
		||||
    # stats = clf.evaluate([domain_val, flow_val],
 | 
			
		||||
    #                      [client_val, server_val],
 | 
			
		||||
    #                      batch_size=args.batch_size)
 | 
			
		||||
    y_pred = clf.predict([domain_val, flow_val],
 | 
			
		||||
                         batch_size=args.batch_size)
 | 
			
		||||
    np.save(args.future_prediction, y_pred)
 | 
			
		||||
@@ -197,6 +216,12 @@ def main_visualization():
 | 
			
		||||
    logger.info("plot roc curve")
 | 
			
		||||
    visualize.plot_roc_curve(client_val.value, client_pred, "{}/client_roc.png".format(args.model_path))
 | 
			
		||||
    visualize.plot_roc_curve(server_val.value, server_pred, "{}/server_roc.png".format(args.model_path))
 | 
			
		||||
    visualize.plot_confusion_matrix(client_val.value.argmax(1), client_pred.argmax(1),
 | 
			
		||||
                                    "{}/client_cov.png".format(args.model_path),
 | 
			
		||||
                                    normalize=False, title="Client Confusion Matrix")
 | 
			
		||||
    visualize.plot_confusion_matrix(server_val.value.argmax(1), server_pred.argmax(1),
 | 
			
		||||
                                    "{}/server_cov.png".format(args.model_path),
 | 
			
		||||
                                    normalize=False, title="Server Confusion Matrix")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main_score():
 | 
			
		||||
 
 | 
			
		||||
@@ -90,10 +90,10 @@ def plot_roc_curve(mask, prediction, path):
 | 
			
		||||
    print("roc_auc", roc_auc)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def plot_confusion_matrix(y_true, y_pred,
 | 
			
		||||
def plot_confusion_matrix(y_true, y_pred, path,
 | 
			
		||||
                          normalize=False,
 | 
			
		||||
                          title='Confusion matrix',
 | 
			
		||||
                          cmap="Blues"):
 | 
			
		||||
                          cmap="Blues", dpi=600):
 | 
			
		||||
    """
 | 
			
		||||
    This function prints and plots the confusion matrix.
 | 
			
		||||
    Normalization can be applied by setting `normalize=True`.
 | 
			
		||||
@@ -125,6 +125,8 @@ def plot_confusion_matrix(y_true, y_pred,
 | 
			
		||||
    plt.tight_layout()
 | 
			
		||||
    plt.ylabel('True label')
 | 
			
		||||
    plt.xlabel('Predicted label')
 | 
			
		||||
    plt.savefig(path, dpi=dpi)
 | 
			
		||||
    plt.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def plot_training_curve(logs, key, path, dpi=600):
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user