extract weighting function
This commit is contained in:
parent
9ce11e4db4
commit
4fc2f0c925
@ -263,8 +263,11 @@ def load_or_generate_domains(train_data, domain_length):
|
|||||||
fn = f"{train_data}_domains.gz"
|
fn = f"{train_data}_domains.gz"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info(f"Load file {fn}.")
|
||||||
user_flow_df = pd.read_csv(fn)
|
user_flow_df = pd.read_csv(fn)
|
||||||
|
logger.info(f"File successfully loaded.")
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
|
logger.info(f"File {fn} not found, recreate.")
|
||||||
user_flow_df = get_user_flow_data(train_data)
|
user_flow_df = get_user_flow_data(train_data)
|
||||||
# user_flow_df.reset_index(inplace=True)
|
# user_flow_df.reset_index(inplace=True)
|
||||||
user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0,
|
user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0,
|
||||||
@ -279,6 +282,7 @@ def load_or_generate_domains(train_data, domain_length):
|
|||||||
|
|
||||||
user_flow_df.to_csv(fn, compression="gzip")
|
user_flow_df.to_csv(fn, compression="gzip")
|
||||||
|
|
||||||
|
logger.info(f"Extract features from domains")
|
||||||
domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, domain_length))
|
domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, domain_length))
|
||||||
domain_encs = np.stack(domain_encs)
|
domain_encs = np.stack(domain_encs)
|
||||||
|
|
||||||
|
51
main.py
51
main.py
@ -197,6 +197,26 @@ def load_data(data, domain_length, window_size, model_type):
|
|||||||
return domain_tr, flow_tr, client_tr, server_tr
|
return domain_tr, flow_tr, client_tr, server_tr
|
||||||
|
|
||||||
|
|
||||||
|
def get_weighting(class_weights, sample_weights, client, server):
|
||||||
|
if class_weights:
|
||||||
|
logger.info("class weights: compute custom weights")
|
||||||
|
custom_class_weights = get_custom_class_weights(client, server)
|
||||||
|
logger.info(custom_class_weights)
|
||||||
|
else:
|
||||||
|
logger.info("class weights: set default")
|
||||||
|
custom_class_weights = None
|
||||||
|
|
||||||
|
if sample_weights:
|
||||||
|
logger.info("class weights: compute custom weights")
|
||||||
|
custom_sample_weights = get_custom_sample_weights(client, server)
|
||||||
|
logger.info(custom_sample_weights)
|
||||||
|
else:
|
||||||
|
logger.info("class weights: set default")
|
||||||
|
custom_sample_weights = None
|
||||||
|
|
||||||
|
return custom_class_weights, custom_sample_weights
|
||||||
|
|
||||||
|
|
||||||
def main_train(param=None):
|
def main_train(param=None):
|
||||||
logger.info(f"Create model path {args.model_path}")
|
logger.info(f"Create model path {args.model_path}")
|
||||||
exists_or_make_path(args.model_path)
|
exists_or_make_path(args.model_path)
|
||||||
@ -220,6 +240,10 @@ def main_train(param=None):
|
|||||||
if not param:
|
if not param:
|
||||||
param = PARAMS
|
param = PARAMS
|
||||||
|
|
||||||
|
# custom class or sample weights
|
||||||
|
custom_class_weights, custom_sample_weights = get_weighting(args.class_weights, args.sample_weights,
|
||||||
|
client_tr.value, server_tr)
|
||||||
|
|
||||||
for i in range(args.runs):
|
for i in range(args.runs):
|
||||||
model_path = os.path.join(args.model_path, f"clf_{i}.h5")
|
model_path = os.path.join(args.model_path, f"clf_{i}.h5")
|
||||||
train_log_path = os.path.join(args.model_path, f"train_{i}.log.csv")
|
train_log_path = os.path.join(args.model_path, f"train_{i}.log.csv")
|
||||||
@ -238,23 +262,6 @@ def main_train(param=None):
|
|||||||
verbose=False))
|
verbose=False))
|
||||||
custom_metrics = models.get_metric_functions()
|
custom_metrics = models.get_metric_functions()
|
||||||
|
|
||||||
# custom class or sample weights
|
|
||||||
if args.class_weights:
|
|
||||||
logger.info("class weights: compute custom weights")
|
|
||||||
custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
|
|
||||||
logger.info(custom_class_weights)
|
|
||||||
else:
|
|
||||||
logger.info("class weights: set default")
|
|
||||||
custom_class_weights = None
|
|
||||||
|
|
||||||
if args.sample_weights:
|
|
||||||
logger.info("class weights: compute custom weights")
|
|
||||||
custom_sample_weights = get_custom_sample_weights(client_tr.value, server_tr)
|
|
||||||
logger.info(custom_class_weights)
|
|
||||||
else:
|
|
||||||
logger.info("class weights: set default")
|
|
||||||
custom_sample_weights = None
|
|
||||||
|
|
||||||
logger.info(f"Generator model with params: {param}")
|
logger.info(f"Generator model with params: {param}")
|
||||||
model = models.get_models_by_params(param)
|
model = models.get_models_by_params(param)
|
||||||
|
|
||||||
@ -372,10 +379,14 @@ def main_retrain():
|
|||||||
|
|
||||||
|
|
||||||
def main_test():
|
def main_test():
|
||||||
logger.info("start test: load data")
|
logger.info("load test data")
|
||||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
|
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
|
||||||
|
logger.info("load test domains")
|
||||||
domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
||||||
|
|
||||||
|
def get_dir(path):
|
||||||
|
return os.path.split(os.path.normpath(path))
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
for model_path in args.model_paths:
|
for model_path in args.model_paths:
|
||||||
file = get_dir(model_path)[1]
|
file = get_dir(model_path)[1]
|
||||||
@ -398,8 +409,8 @@ def main_test():
|
|||||||
|
|
||||||
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||||
results["domain_embds"] = domain_embeddings
|
results["domain_embds"] = domain_embeddings
|
||||||
|
# store results every round - safety first!
|
||||||
dataset.save_predictions(get_dir(model_path)[0], results)
|
dataset.save_predictions(get_dir(model_path)[0], results)
|
||||||
|
|
||||||
|
|
||||||
def main_visualization():
|
def main_visualization():
|
||||||
|
@ -1,8 +1,17 @@
|
|||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
from keras.models import Model
|
from keras.models import Model
|
||||||
|
|
||||||
from . import networks
|
from . import networks
|
||||||
from .metrics import *
|
from .metrics import *
|
||||||
|
|
||||||
|
NetworkParameters = namedtuple("NetworkParameters", [
|
||||||
|
"type", "flow_features", "window_size", "domain_length", "output",
|
||||||
|
"embedding_size",
|
||||||
|
"domain_filter", "domain_kernel", "domain_dense", "domain_dropout",
|
||||||
|
"main_filter", "main_kernel", "main_dense", "main_dropout",
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
def create_model(model, output_type):
|
def create_model(model, output_type):
|
||||||
if output_type == "both":
|
if output_type == "both":
|
||||||
@ -14,6 +23,7 @@ def create_model(model, output_type):
|
|||||||
|
|
||||||
|
|
||||||
def get_models_by_params(params: dict):
|
def get_models_by_params(params: dict):
|
||||||
|
K.clear_session()
|
||||||
# decomposing param section
|
# decomposing param section
|
||||||
# mainly embedding model
|
# mainly embedding model
|
||||||
network_type = params.get("type")
|
network_type = params.get("type")
|
||||||
|
Loading…
Reference in New Issue
Block a user