add server classification model

This commit is contained in:
René Knaebel 2017-10-05 15:26:53 +02:00
parent 345afbaef5
commit 508667d1d0
4 changed files with 104 additions and 4 deletions

55
main.py
View File

@ -596,6 +596,59 @@ def plot_overall_result():
visualize.plot_save(f"{path}/error_bars_{cat}.png")
def train_server_only():
logger.info(f"Create model path {args.model_path}")
exists_or_make_path(args.model_path)
logger.info(f"Use command line arguments: {args}")
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
args.data,
args.domain_length,
args.window)
domain_tr = domain_tr.value.reshape(-1, 40)
flow_tr = flow_tr.value.reshape(-1, 3)
server_tr = server_windows_tr.value.reshape(-1)
logger.info("define callbacks")
callbacks = []
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
monitor='loss',
verbose=False,
save_best_only=True))
callbacks.append(CSVLogger(args.train_log))
logger.info(f"Use early stopping: {args.stop_early}")
if args.stop_early:
callbacks.append(EarlyStopping(monitor='val_loss',
patience=5,
verbose=False))
custom_metrics = models.get_metric_functions()
model = models.get_server_model_by_params(params=PARAMS)
features = {"ipt_domains": domain_tr, "ipt_flows": flow_tr}
if args.model_output == "both":
labels = {"client": client_tr, "server": server_tr}
elif args.model_output == "client":
labels = {"client": client_tr}
elif args.model_output == "server":
labels = {"server": server_tr}
else:
raise ValueError("unknown model output")
logger.info("compile and train model")
logger.info(model.get_config())
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'] + custom_metrics)
model.summary()
model.fit(features, labels,
batch_size=args.batch_size,
epochs=args.epochs,
validation_split=0.2,
callbacks=callbacks)
def main():
if "train" == args.mode:
main_train()
@ -613,6 +666,8 @@ def main():
main_beta()
if "all_beta" == args.mode:
plot_overall_result()
if "server" == args.mode:
train_server_only()
if __name__ == "__main__":

View File

@ -7,14 +7,14 @@ from . import flat_2, pauls_networks, renes_networks
def get_models_by_params(params: dict):
# decomposing param section
# mainly embedding model
network_type = params.get("type")
# network_type = params.get("type")
network_depth = params.get("depth")
embedding_size = params.get("embedding")
input_length = params.get("input_length")
filter_embedding = params.get("filter_embedding")
kernel_embedding = params.get("kernel_embedding")
hidden_embedding = params.get("dense_embedding")
dropout = params.get("dropout")
# dropout = params.get("dropout")
# mainly prediction model
flow_features = params.get("flow_features")
window_size = params.get("window_size")
@ -44,7 +44,36 @@ def get_models_by_params(params: dict):
return embedding_model, old_model, new_model
def get_metrics():
def get_server_model_by_params(params: dict):
# decomposing param section
# mainly embedding model
network_depth = params.get("depth")
embedding_size = params.get("embedding")
input_length = params.get("input_length")
filter_embedding = params.get("filter_embedding")
kernel_embedding = params.get("kernel_embedding")
hidden_embedding = params.get("dense_embedding")
# mainly prediction model
flow_features = params.get("flow_features")
domain_length = params.get("domain_length")
dense_dim = params.get("dense_main")
# create models
if network_depth == "flat1":
networks = pauls_networks
elif network_depth == "flat2":
networks = flat_2
elif network_depth == "deep1":
networks = renes_networks
else:
raise Exception("network not found")
embedding_model = networks.get_embedding(embedding_size, input_length, filter_embedding, kernel_embedding,
hidden_embedding, 0.5)
return networks.get_server_model(flow_features, domain_length, dense_dim, embedding_model)
def get_custom_objects():
return dict([
("precision", precision),
("recall", recall),

View File

@ -3,7 +3,8 @@ from collections import namedtuple
import keras
from keras.activations import elu
from keras.engine import Input, Model as KerasModel
from keras.layers import Conv1D, Dense, Dropout, Embedding, GlobalAveragePooling1D, GlobalMaxPooling1D, TimeDistributed
from keras.layers import BatchNormalization, Conv1D, Dense, Dropout, Embedding, GlobalAveragePooling1D, \
GlobalMaxPooling1D, TimeDistributed
import dataset
@ -40,6 +41,8 @@ def get_model(cnnDropout, flow_features, domain_features, window_size, domain_le
ipt_domains = Input(shape=(window_size, domain_length), name="ipt_domains")
encoded = TimeDistributed(cnn, name="domain_cnn")(ipt_domains)
ipt_flows = Input(shape=(window_size, flow_features), name="ipt_flows")
ipt_flows = BatchNormalization()(ipt_flows)
ipt_flows = Dense(dense_dim, activation=selu)(ipt_flows)
merged = keras.layers.concatenate([encoded, ipt_flows], -1)
# CNN processing a small slides of flow windows
y = Conv1D(cnn_dims,

View File

@ -89,3 +89,16 @@ def get_new_model(dropout, flow_features, domain_features, window_size, domain_l
out_client = Dense(1, activation='sigmoid', name="client")(y)
return Model(ipt_domains, ipt_flows, out_client, out_server)
def get_server_model(flow_features, domain_length, dense_dim, cnn):
ipt_domains = Input(shape=(domain_length,), name="ipt_domains")
ipt_flows = Input(shape=(flow_features,), name="ipt_flows")
encoded = cnn(ipt_domains)
merged = keras.layers.concatenate([encoded, ipt_flows], -1)
y = Dense(dense_dim,
activation="relu",
name="dense_server")(merged)
out_server = Dense(1, activation="sigmoid", name="server")(y)
return KerasModel(inputs=[ipt_domains, ipt_flows], outputs=out_server)