extract weighting function

This commit is contained in:
René Knaebel 2017-11-10 10:18:13 +01:00
parent 9ce11e4db4
commit 4fc2f0c925
3 changed files with 45 additions and 20 deletions

View File

@ -263,8 +263,11 @@ def load_or_generate_domains(train_data, domain_length):
fn = f"{train_data}_domains.gz"
try:
logger.info(f"Load file {fn}.")
user_flow_df = pd.read_csv(fn)
logger.info(f"File successfully loaded.")
except FileNotFoundError:
logger.info(f"File {fn} not found, recreate.")
user_flow_df = get_user_flow_data(train_data)
# user_flow_df.reset_index(inplace=True)
user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0,
@ -279,6 +282,7 @@ def load_or_generate_domains(train_data, domain_length):
user_flow_df.to_csv(fn, compression="gzip")
logger.info(f"Extract features from domains")
domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, domain_length))
domain_encs = np.stack(domain_encs)

51
main.py
View File

@ -197,6 +197,26 @@ def load_data(data, domain_length, window_size, model_type):
return domain_tr, flow_tr, client_tr, server_tr
def get_weighting(class_weights, sample_weights, client, server):
if class_weights:
logger.info("class weights: compute custom weights")
custom_class_weights = get_custom_class_weights(client, server)
logger.info(custom_class_weights)
else:
logger.info("class weights: set default")
custom_class_weights = None
if sample_weights:
logger.info("class weights: compute custom weights")
custom_sample_weights = get_custom_sample_weights(client, server)
logger.info(custom_sample_weights)
else:
logger.info("class weights: set default")
custom_sample_weights = None
return custom_class_weights, custom_sample_weights
def main_train(param=None):
logger.info(f"Create model path {args.model_path}")
exists_or_make_path(args.model_path)
@ -220,6 +240,10 @@ def main_train(param=None):
if not param:
param = PARAMS
# custom class or sample weights
custom_class_weights, custom_sample_weights = get_weighting(args.class_weights, args.sample_weights,
client_tr.value, server_tr)
for i in range(args.runs):
model_path = os.path.join(args.model_path, f"clf_{i}.h5")
train_log_path = os.path.join(args.model_path, f"train_{i}.log.csv")
@ -238,23 +262,6 @@ def main_train(param=None):
verbose=False))
custom_metrics = models.get_metric_functions()
# custom class or sample weights
if args.class_weights:
logger.info("class weights: compute custom weights")
custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
logger.info(custom_class_weights)
else:
logger.info("class weights: set default")
custom_class_weights = None
if args.sample_weights:
logger.info("class weights: compute custom weights")
custom_sample_weights = get_custom_sample_weights(client_tr.value, server_tr)
logger.info(custom_class_weights)
else:
logger.info("class weights: set default")
custom_sample_weights = None
logger.info(f"Generator model with params: {param}")
model = models.get_models_by_params(param)
@ -372,10 +379,14 @@ def main_retrain():
def main_test():
logger.info("start test: load data")
logger.info("load test data")
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
logger.info("load test domains")
domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
def get_dir(path):
return os.path.split(os.path.normpath(path))
results = {}
for model_path in args.model_paths:
file = get_dir(model_path)[1]
@ -398,8 +409,8 @@ def main_test():
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
results["domain_embds"] = domain_embeddings
dataset.save_predictions(get_dir(model_path)[0], results)
# store results every round - safety first!
dataset.save_predictions(get_dir(model_path)[0], results)
def main_visualization():

View File

@ -1,8 +1,17 @@
from collections import namedtuple
from keras.models import Model
from . import networks
from .metrics import *
NetworkParameters = namedtuple("NetworkParameters", [
"type", "flow_features", "window_size", "domain_length", "output",
"embedding_size",
"domain_filter", "domain_kernel", "domain_dense", "domain_dropout",
"main_filter", "main_kernel", "main_dense", "main_dropout",
])
def create_model(model, output_type):
if output_type == "both":
@ -14,6 +23,7 @@ def create_model(model, output_type):
def get_models_by_params(params: dict):
K.clear_session()
# decomposing param section
# mainly embedding model
network_type = params.get("type")