2017-07-29 19:47:02 +02:00
|
|
|
import os
|
|
|
|
|
2017-07-30 13:47:11 +02:00
|
|
|
import numpy as np
|
|
|
|
from sklearn.utils import class_weight
|
|
|
|
|
2017-07-29 19:47:02 +02:00
|
|
|
|
|
|
|
def exists_or_make_path(p):
    """Ensure that the path ``p`` exists, creating directories if needed.

    If ``p`` already exists (as a directory or any other file system entry),
    nothing is done. Otherwise ``p`` and any missing parent directories are
    created with ``os.makedirs``.

    A plain "check, then create" sequence is racy: another process or thread
    may create ``p`` between the existence check and ``os.makedirs``, which
    then raises. Here an ``OSError`` from ``os.makedirs`` is tolerated as long
    as the path exists afterwards.

    :param p: Path to ensure exists.
    :raises OSError: If ``p`` does not exist and could not be created
        (e.g. permission denied, empty path).
    """
    if os.path.exists(p):
        return
    try:
        os.makedirs(p)
    except OSError:
        # The path may have been created concurrently; only propagate the
        # error if it is still missing (a genuine failure).
        if not os.path.exists(p):
            raise
|
2017-07-30 13:47:11 +02:00
|
|
|
|
|
|
|
|
|
|
|
def get_custom_class_weights(client, server):
    """Compute 'balanced' class weights for the client and server label sets.

    Each label array is passed to scikit-learn's ``compute_class_weight``
    with ``class_weight='balanced'`` and ``classes=np.unique(labels)``.
    Each returned weight array is therefore aligned with the sorted unique
    labels of its input.

    :param client: Array-like of class labels for the "client" target.
    :param server: Array-like of class labels for the "server" target.
    :return: Dict with keys ``"client"`` and ``"server"``, each mapping to a
        numpy array of per-class weights ordered like ``np.unique`` of the
        respective labels.
    """

    def _balanced(labels):
        # Compute balanced weights for one label array. ``classes`` and ``y``
        # must be passed by keyword: they are keyword-only since
        # scikit-learn 1.0 (a positional call raises TypeError), and the
        # same keyword names are accepted by older releases.
        return class_weight.compute_class_weight(
            class_weight='balanced', classes=np.unique(labels), y=labels
        )

    return {
        "client": _balanced(client),
        "server": _balanced(server)
    }
|