diff --git a/main.py b/main.py index 084c441..9e8eb09 100644 --- a/main.py +++ b/main.py @@ -15,7 +15,7 @@ import models # create logger import visualize from arguments import get_model_args -from utils import exists_or_make_path, get_custom_class_weights, load_model +from utils import exists_or_make_path, get_custom_class_weights, get_custom_sample_weights, load_model logger = logging.getLogger('logger') logger.setLevel(logging.DEBUG) @@ -166,6 +166,14 @@ def main_train(param=None): logger.info("class weights: set default") custom_class_weights = None + if args.sample_weights: + logger.info("class weights: compute custom weights") + custom_sample_weights = get_custom_sample_weights(client_tr.value, server_tr) + logger.info(custom_class_weights) + else: + logger.info("class weights: set default") + custom_sample_weights = None + if not param: param = PARAMS logger.info(f"Generator model with params: {param}") @@ -205,7 +213,8 @@ def main_train(param=None): model.fit(features, labels, batch_size=args.batch_size, epochs=args.epochs, - class_weight=custom_class_weights) + class_weight=custom_class_weights, + sample_weight=custom_sample_weights) logger.info("fix server model") model.get_layer("domain_cnn").trainable = False @@ -227,7 +236,8 @@ def main_train(param=None): batch_size=args.batch_size, epochs=args.epochs, callbacks=callbacks, - class_weight=custom_class_weights) + class_weight=custom_class_weights, + sample_weight=custom_sample_weights) def main_retrain(): @@ -406,7 +416,7 @@ def main_visualization(): def plot_embedding(model_path, domain_embedding, data, domain_length): logger.info("visualize embedding") domain_encs, labels = dataset.load_or_generate_domains(data, domain_length) - visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.pdf".format(model_path), method="svd") + visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(model_path), method="svd") def main_visualize_all(): @@ -641,7 +651,6 @@ def train_server_only(): model.fit(features, labels, batch_size=args.batch_size, epochs=args.epochs, - validation_split=0.2, callbacks=callbacks)