47 lines
1.3 KiB
Python
47 lines
1.3 KiB
Python
import os
|
|
from operator import itemgetter
|
|
|
|
import joblib
|
|
import numpy as np
|
|
from keras.models import load_model as load_keras_model
|
|
from sklearn.utils import class_weight
|
|
|
|
|
|
def exists_or_make_path(p):
    """Make sure the path *p* exists, creating it as a directory if it is missing.

    Missing parent directories are created too (``os.makedirs``). If *p*
    already exists, whatever kind of entry it is, the call is a no-op.

    Parameters
    ----------
    p : str or os.PathLike
        Path to ensure.

    Raises
    ------
    OSError
        Any ``os.makedirs`` failure other than *p* already existing
        (e.g. permission denied, or an empty path).
    """
    # Attempt creation and tolerate "already exists" instead of checking
    # first. The check-then-create form is racy: another process may create
    # *p* between the check and ``os.makedirs``, which then raises.
    # ``exist_ok=True`` is not used because it raises when *p* exists as a
    # non-directory, which the original contract tolerated.
    try:
        os.makedirs(p)
    except FileExistsError:
        pass
|
|
|
|
|
|
def get_custom_class_weights(client, server):
    """Compute "balanced" class weights separately for client and server labels.

    Parameters
    ----------
    client, server : array-like
        Label arrays. Each one is passed as ``y`` to
        ``sklearn.utils.class_weight.compute_class_weight``, with its own
        ``np.unique`` values as ``classes``.

    Returns
    -------
    dict
        ``{"client": ..., "server": ...}``. Each value is the array returned
        by ``compute_class_weight``, holding one weight per class in
        ``np.unique`` (sorted) order.
    """

    def _balanced(labels):
        # ``classes`` and ``y`` are keyword-only in scikit-learn >= 1.0.
        # Passing them positionally raises TypeError there.
        return class_weight.compute_class_weight(
            "balanced", classes=np.unique(labels), y=labels
        )

    return {
        "client": _balanced(client),
        "server": _balanced(server),
    }
|
|
|
|
|
|
def get_custom_sample_weights(client, server):
    """Return "balanced" per-sample weights over the combined client/server labels.

    The two label arrays are stacked with ``np.vstack``, transposed, and
    passed as ``y`` to ``sklearn.utils.class_weight.compute_sample_weight``.
    This assumes *client* and *server* are 1-D label arrays of equal length
    (TODO confirm against callers). Under that assumption, ``y`` is a
    two-column multi-output target. scikit-learn then multiplies the
    per-column balanced weights for each sample.
    """
    stacked_labels = np.vstack((client, server))
    # Transpose so that, for 1-D inputs, rows are samples and the two columns
    # are the client and server labels.
    multi_output_y = stacked_labels.T
    return class_weight.compute_sample_weight("balanced", multi_output_y)
|
|
|
|
|
|
def load_ordered_hyperband_results(path):
    """Load hyperband results from *path* and return them sorted by ascending loss.

    Parameters
    ----------
    path : str or os.PathLike
        File readable by ``joblib.load``. Its contents must be an iterable of
        mappings that each have a ``"loss"`` key.

    Returns
    -------
    list
        The loaded results, ordered from lowest to highest ``"loss"``.

    Raises
    ------
    KeyError
        If a result has no ``"loss"`` entry.
    """
    # NOTE: joblib.load unpickles the file. Only load files from trusted sources.
    results = joblib.load(path)
    # ``key`` is keyword-only for ``sorted``. Passing the key function
    # positionally raises TypeError on Python 3.
    return sorted(results, key=itemgetter("loss"))
|
|
|
|
|
|
def load_model(path, custom_objects=None):
    """Load a saved Keras model and extract its embedding component.

    The embedding is looked up in this order of preference:

    1. the ``.layer`` attribute of the layer named ``"domain_cnn"``;
    2. the ``.layer`` attribute of the model's second layer (``layers[1]``),
       for saved model versions where the layer was not named
       ``"domain_cnn"``;
    3. the layer named ``"domain_cnn"`` itself.

    Any exception raised by steps 1 or 2 moves on to the next step. An
    exception raised by step 3 propagates to the caller.

    Parameters
    ----------
    path : str
        Location of the saved model, passed to ``keras.models.load_model``.
    custom_objects : dict or None
        Passed through to ``keras.models.load_model``.

    Returns
    -------
    tuple
        ``(embd, clf)``: the extracted embedding and the full loaded model.
    """
    clf = load_keras_model(path, custom_objects=custom_objects)

    # ``.layer`` is presumably the inner model of a wrapper layer
    # (e.g. TimeDistributed). TODO confirm against the model definition.
    candidate_lookups = (
        lambda: clf.get_layer("domain_cnn").layer,
        # Compatibility: some saved model versions lack the "domain_cnn" name.
        lambda: clf.layers[1].layer,
    )
    for lookup in candidate_lookups:
        try:
            return lookup(), clf
        except Exception:
            continue

    # Last resort: the named layer is present but has no wrapped ``.layer``.
    return clf.get_layer("domain_cnn"), clf
|