ma_cisco_malware/utils.py

30 lines
801 B
Python

import os
from operator import itemgetter
import joblib
import numpy as np
from sklearn.utils import class_weight
def exists_or_make_path(p):
    """Ensure that path ``p`` exists, creating it (and missing parents) if needed.

    If ``p`` already exists, nothing is done.

    Args:
        p: Filesystem path of the directory to create.
    """
    # Try to create directly instead of checking ``os.path.exists`` first:
    # with a separate check, another process can create the same path in
    # between, and ``os.makedirs`` would then raise FileExistsError.
    try:
        os.makedirs(p)
    except FileExistsError:
        # The path already exists; like the original existence check, leave it be.
        pass
def get_custom_class_weights(client, server):
    """Compute "balanced" class weights separately for the client and server labels.

    Args:
        client: Label array for the client target.
        server: Label array for the server target.

    Returns:
        dict with keys ``"client"`` and ``"server"``. Each value is the array
        returned by ``class_weight.compute_class_weight`` for that label array,
        with one weight per class in ``np.unique`` order.
    """
    # ``classes`` and ``y`` are passed by keyword: they are keyword-only in
    # scikit-learn >= 1.0, where the positional form raises TypeError.
    client_class_weight = class_weight.compute_class_weight(
        'balanced', classes=np.unique(client), y=client
    )
    server_class_weight = class_weight.compute_class_weight(
        'balanced', classes=np.unique(server), y=server
    )
    # NOTE(review): values are arrays, not {class: weight} dicts; confirm this
    # is the form the consumer of this function expects.
    return {
        "client": client_class_weight,
        "server": server_class_weight,
    }
def get_custom_sample_weights(client, server):
    """Return sample weights that balance the combined client/server labels.

    The two label arrays are stacked and transposed into a single
    multi-output target. That target is passed to scikit-learn's
    "balanced" sample weighting.

    Args:
        client: Label array for the client target.
        server: Label array for the server target.

    Returns:
        The weight array produced by ``class_weight.compute_sample_weight``.
    """
    # For 1-D inputs of length n, vstack yields shape (2, n). Transposing it
    # gives one row per sample holding its (client, server) label pair.
    rows_by_target = np.vstack((client, server))
    per_sample_targets = rows_by_target.T
    return class_weight.compute_sample_weight("balanced", per_sample_targets)
def load_ordered_hyperband_results(path):
    """Load saved hyperband results and return them sorted by ascending loss.

    Args:
        path: Path to a file readable by ``joblib.load``.

    Returns:
        A new list of the loaded results, ordered by each item's ``"loss"``
        entry with the lowest first. Each item must support ``item["loss"]``
        (for example, a dict).
    """
    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only load result files from trusted sources.
    results = joblib.load(path)
    # ``key`` is keyword-only for sorted(). Passing it positionally raises
    # TypeError in Python 3.
    return sorted(results, key=itemgetter("loss"))