# ma_cisco_malware/utils.py
# (source listing metadata: 47 lines, 1.2 KiB, Python)

import os
from operator import itemgetter
import joblib
import numpy as np
from keras.models import load_model as load_keras_model
from sklearn.utils import class_weight
def exists_or_make_path(p):
    """Ensure the directory path *p* exists, creating it and any parents if needed.

    If *p* already exists (as anything), this is a no-op.

    :param p: path of the directory to create.
    :raises OSError: if the path cannot be created for any reason other than
        it already existing (e.g. permissions, or a parent component that is
        a regular file).
    """
    try:
        os.makedirs(p)
    except FileExistsError:
        # Already present -- possibly created by another process after a
        # separate existence check would have run. Attempting the creation and
        # tolerating "exists" avoids the check-then-create race.
        pass
def get_custom_class_weights(client, server):
    """Compute 'balanced' per-class weights for the client and server labels.

    :param client: label array for the client target.
    :param server: label array for the server target.
    :return: dict with keys ``"client"`` and ``"server"``. Each value is the
        array returned by scikit-learn's ``compute_class_weight``, holding one
        weight per class in the (sorted) order produced by ``np.unique`` on
        the respective labels.
    """
    # scikit-learn >= 1.0 makes ``classes`` and ``y`` keyword-only, so the
    # positional form raises TypeError there; keywords also work on older
    # releases.
    return {
        "client": class_weight.compute_class_weight(
            class_weight="balanced", classes=np.unique(client), y=client
        ),
        "server": class_weight.compute_class_weight(
            class_weight="balanced", classes=np.unique(server), y=server
        ),
    }
def get_custom_sample_weights(client, server):
    """Return 'balanced' per-sample weights for the client and server labels.

    :param client: label array for the client target.
    :param server: label array for the server target.
    :return: dict with keys ``"client"`` and ``"server"``, each holding the
        array produced by scikit-learn's ``compute_sample_weight("balanced", y)``
        for the corresponding labels.
    """
    weights = {}
    # Client first, then server -- same evaluation and key order as a literal.
    for target_name, labels in (("client", client), ("server", server)):
        weights[target_name] = class_weight.compute_sample_weight("balanced", labels)
    return weights
def load_ordered_hyperband_results(path):
    """Load saved hyperband results with joblib and sort them by ascending loss.

    :param path: file path passed to ``joblib.load``.
    :return: a new list of the loaded results ordered by each result's
        ``"loss"`` entry, lowest first. Each result is assumed to support
        ``result["loss"]`` (e.g. a dict).
    """
    # NOTE(review): joblib.load unpickles the file -- only call on trusted paths.
    results = joblib.load(path)
    # ``sorted`` takes only the iterable positionally; the ordering function
    # must be given as ``key=`` (the positional form raises TypeError).
    return sorted(results, key=itemgetter("loss"))
def load_model(path, custom_objects=None):
    """Load a saved Keras model and extract its "domain_cnn" sub-model.

    :param path: model file path passed to ``keras.models.load_model``.
    :param custom_objects: optional value forwarded to
        ``keras.models.load_model`` as its second positional argument.
    :return: tuple ``(embd, clf)``: ``clf`` is the full loaded model and
        ``embd`` is whatever the fallback chain below resolves.
    """
    clf = load_keras_model(path, custom_objects)
    try:
        # Preferred: the layer named "domain_cnn", unwrapped through its
        # ``.layer`` attribute (presumably a wrapper such as TimeDistributed --
        # TODO confirm against the model-building code).
        embd = clf.get_layer("domain_cnn").layer
    except Exception:
        # Some older saved models were built without naming this layer
        # "domain_cnn". For compatibility with them, fall back to the layer at
        # index 1 and unwrap it the same way.
        # NOTE(review): the broad ``except Exception`` also hides unrelated
        # errors; consider narrowing once supported Keras versions are known.
        try:
            embd = clf.layers[1].layer
        except Exception:
            # Last resort: the "domain_cnn" layer itself, without unwrapping.
            # Any exception raised here propagates to the caller.
            embd = clf.get_layer("domain_cnn")
    return embd, clf