add sample weight metrics to fit function

This commit is contained in:
René Knaebel 2017-10-08 22:09:09 +02:00
parent e8473048cb
commit 33063f3081

19
main.py
View File

@ -15,7 +15,7 @@ import models
# create logger # create logger
import visualize import visualize
from arguments import get_model_args from arguments import get_model_args
from utils import exists_or_make_path, get_custom_class_weights, load_model from utils import exists_or_make_path, get_custom_class_weights, get_custom_sample_weights, load_model
logger = logging.getLogger('logger') logger = logging.getLogger('logger')
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@ -166,6 +166,14 @@ def main_train(param=None):
logger.info("class weights: set default") logger.info("class weights: set default")
custom_class_weights = None custom_class_weights = None
if args.sample_weights:
logger.info("class weights: compute custom weights")
custom_sample_weights = get_custom_sample_weights(client_tr.value, server_tr)
logger.info(custom_class_weights)
else:
logger.info("class weights: set default")
custom_sample_weights = None
if not param: if not param:
param = PARAMS param = PARAMS
logger.info(f"Generator model with params: {param}") logger.info(f"Generator model with params: {param}")
@ -205,7 +213,8 @@ def main_train(param=None):
model.fit(features, labels, model.fit(features, labels,
batch_size=args.batch_size, batch_size=args.batch_size,
epochs=args.epochs, epochs=args.epochs,
class_weight=custom_class_weights) class_weight=custom_class_weights,
sample_weight=custom_sample_weights)
logger.info("fix server model") logger.info("fix server model")
model.get_layer("domain_cnn").trainable = False model.get_layer("domain_cnn").trainable = False
@ -227,7 +236,8 @@ def main_train(param=None):
batch_size=args.batch_size, batch_size=args.batch_size,
epochs=args.epochs, epochs=args.epochs,
callbacks=callbacks, callbacks=callbacks,
class_weight=custom_class_weights) class_weight=custom_class_weights,
sample_weight=custom_sample_weights)
def main_retrain(): def main_retrain():
@ -406,7 +416,7 @@ def main_visualization():
def plot_embedding(model_path, domain_embedding, data, domain_length): def plot_embedding(model_path, domain_embedding, data, domain_length):
logger.info("visualize embedding") logger.info("visualize embedding")
domain_encs, labels = dataset.load_or_generate_domains(data, domain_length) domain_encs, labels = dataset.load_or_generate_domains(data, domain_length)
visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.pdf".format(model_path), method="svd") visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(model_path), method="svd")
def main_visualize_all(): def main_visualize_all():
@ -641,7 +651,6 @@ def train_server_only():
model.fit(features, labels, model.fit(features, labels,
batch_size=args.batch_size, batch_size=args.batch_size,
epochs=args.epochs, epochs=args.epochs,
validation_split=0.2,
callbacks=callbacks) callbacks=callbacks)