extract weighting function
This commit is contained in:
parent
9ce11e4db4
commit
4fc2f0c925
@ -263,8 +263,11 @@ def load_or_generate_domains(train_data, domain_length):
|
||||
fn = f"{train_data}_domains.gz"
|
||||
|
||||
try:
|
||||
logger.info(f"Load file {fn}.")
|
||||
user_flow_df = pd.read_csv(fn)
|
||||
logger.info(f"File successfully loaded.")
|
||||
except FileNotFoundError:
|
||||
logger.info(f"File {fn} not found, recreate.")
|
||||
user_flow_df = get_user_flow_data(train_data)
|
||||
# user_flow_df.reset_index(inplace=True)
|
||||
user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0,
|
||||
@ -279,6 +282,7 @@ def load_or_generate_domains(train_data, domain_length):
|
||||
|
||||
user_flow_df.to_csv(fn, compression="gzip")
|
||||
|
||||
logger.info(f"Extract features from domains")
|
||||
domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, domain_length))
|
||||
domain_encs = np.stack(domain_encs)
|
||||
|
||||
|
51
main.py
51
main.py
@ -197,6 +197,26 @@ def load_data(data, domain_length, window_size, model_type):
|
||||
return domain_tr, flow_tr, client_tr, server_tr
|
||||
|
||||
|
||||
def get_weighting(class_weights, sample_weights, client, server):
|
||||
if class_weights:
|
||||
logger.info("class weights: compute custom weights")
|
||||
custom_class_weights = get_custom_class_weights(client, server)
|
||||
logger.info(custom_class_weights)
|
||||
else:
|
||||
logger.info("class weights: set default")
|
||||
custom_class_weights = None
|
||||
|
||||
if sample_weights:
|
||||
logger.info("class weights: compute custom weights")
|
||||
custom_sample_weights = get_custom_sample_weights(client, server)
|
||||
logger.info(custom_sample_weights)
|
||||
else:
|
||||
logger.info("class weights: set default")
|
||||
custom_sample_weights = None
|
||||
|
||||
return custom_class_weights, custom_sample_weights
|
||||
|
||||
|
||||
def main_train(param=None):
|
||||
logger.info(f"Create model path {args.model_path}")
|
||||
exists_or_make_path(args.model_path)
|
||||
@ -220,6 +240,10 @@ def main_train(param=None):
|
||||
if not param:
|
||||
param = PARAMS
|
||||
|
||||
# custom class or sample weights
|
||||
custom_class_weights, custom_sample_weights = get_weighting(args.class_weights, args.sample_weights,
|
||||
client_tr.value, server_tr)
|
||||
|
||||
for i in range(args.runs):
|
||||
model_path = os.path.join(args.model_path, f"clf_{i}.h5")
|
||||
train_log_path = os.path.join(args.model_path, f"train_{i}.log.csv")
|
||||
@ -238,23 +262,6 @@ def main_train(param=None):
|
||||
verbose=False))
|
||||
custom_metrics = models.get_metric_functions()
|
||||
|
||||
# custom class or sample weights
|
||||
if args.class_weights:
|
||||
logger.info("class weights: compute custom weights")
|
||||
custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
|
||||
logger.info(custom_class_weights)
|
||||
else:
|
||||
logger.info("class weights: set default")
|
||||
custom_class_weights = None
|
||||
|
||||
if args.sample_weights:
|
||||
logger.info("class weights: compute custom weights")
|
||||
custom_sample_weights = get_custom_sample_weights(client_tr.value, server_tr)
|
||||
logger.info(custom_class_weights)
|
||||
else:
|
||||
logger.info("class weights: set default")
|
||||
custom_sample_weights = None
|
||||
|
||||
logger.info(f"Generator model with params: {param}")
|
||||
model = models.get_models_by_params(param)
|
||||
|
||||
@ -372,10 +379,14 @@ def main_retrain():
|
||||
|
||||
|
||||
def main_test():
|
||||
logger.info("start test: load data")
|
||||
logger.info("load test data")
|
||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
|
||||
logger.info("load test domains")
|
||||
domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
||||
|
||||
def get_dir(path):
|
||||
return os.path.split(os.path.normpath(path))
|
||||
|
||||
results = {}
|
||||
for model_path in args.model_paths:
|
||||
file = get_dir(model_path)[1]
|
||||
@ -398,8 +409,8 @@ def main_test():
|
||||
|
||||
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||
results["domain_embds"] = domain_embeddings
|
||||
|
||||
dataset.save_predictions(get_dir(model_path)[0], results)
|
||||
# store results every round - safety first!
|
||||
dataset.save_predictions(get_dir(model_path)[0], results)
|
||||
|
||||
|
||||
def main_visualization():
|
||||
|
@ -1,8 +1,17 @@
|
||||
from collections import namedtuple
|
||||
|
||||
from keras.models import Model
|
||||
|
||||
from . import networks
|
||||
from .metrics import *
|
||||
|
||||
NetworkParameters = namedtuple("NetworkParameters", [
|
||||
"type", "flow_features", "window_size", "domain_length", "output",
|
||||
"embedding_size",
|
||||
"domain_filter", "domain_kernel", "domain_dense", "domain_dropout",
|
||||
"main_filter", "main_kernel", "main_dense", "main_dropout",
|
||||
])
|
||||
|
||||
|
||||
def create_model(model, output_type):
|
||||
if output_type == "both":
|
||||
@ -14,6 +23,7 @@ def create_model(model, output_type):
|
||||
|
||||
|
||||
def get_models_by_params(params: dict):
|
||||
K.clear_session()
|
||||
# decomposing param section
|
||||
# mainly embedding model
|
||||
network_type = params.get("type")
|
||||
|
Loading…
Reference in New Issue
Block a user