diff --git a/utils.py b/utils.py index 187e9d3..636e3a7 100644 --- a/utils.py +++ b/utils.py @@ -16,3 +16,7 @@ def get_custom_class_weights(client, server): "client": client_class_weight, "server": server_class_weight } + + +def get_custom_sample_weights(client, server): + return class_weight.compute_sample_weight("balanced", np.vstack((client, server)).T)