add sample weight metrics to fit function
This commit is contained in:
parent
e8473048cb
commit
33063f3081
19
main.py
19
main.py
@ -15,7 +15,7 @@ import models
|
||||
# create logger
|
||||
import visualize
|
||||
from arguments import get_model_args
|
||||
from utils import exists_or_make_path, get_custom_class_weights, load_model
|
||||
from utils import exists_or_make_path, get_custom_class_weights, get_custom_sample_weights, load_model
|
||||
|
||||
logger = logging.getLogger('logger')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
@ -166,6 +166,14 @@ def main_train(param=None):
|
||||
logger.info("class weights: set default")
|
||||
custom_class_weights = None
|
||||
|
||||
if args.sample_weights:
|
||||
logger.info("class weights: compute custom weights")
|
||||
custom_sample_weights = get_custom_sample_weights(client_tr.value, server_tr)
|
||||
logger.info(custom_class_weights)
|
||||
else:
|
||||
logger.info("class weights: set default")
|
||||
custom_sample_weights = None
|
||||
|
||||
if not param:
|
||||
param = PARAMS
|
||||
logger.info(f"Generator model with params: {param}")
|
||||
@ -205,7 +213,8 @@ def main_train(param=None):
|
||||
model.fit(features, labels,
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
class_weight=custom_class_weights)
|
||||
class_weight=custom_class_weights,
|
||||
sample_weight=custom_sample_weights)
|
||||
|
||||
logger.info("fix server model")
|
||||
model.get_layer("domain_cnn").trainable = False
|
||||
@ -227,7 +236,8 @@ def main_train(param=None):
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
callbacks=callbacks,
|
||||
class_weight=custom_class_weights)
|
||||
class_weight=custom_class_weights,
|
||||
sample_weight=custom_sample_weights)
|
||||
|
||||
|
||||
def main_retrain():
|
||||
@ -406,7 +416,7 @@ def main_visualization():
|
||||
def plot_embedding(model_path, domain_embedding, data, domain_length):
|
||||
logger.info("visualize embedding")
|
||||
domain_encs, labels = dataset.load_or_generate_domains(data, domain_length)
|
||||
visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.pdf".format(model_path), method="svd")
|
||||
visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(model_path), method="svd")
|
||||
|
||||
|
||||
def main_visualize_all():
|
||||
@ -641,7 +651,6 @@ def train_server_only():
|
||||
model.fit(features, labels,
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
validation_split=0.2,
|
||||
callbacks=callbacks)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user