refactor training - separate staggered training; make differences as small as possible

This commit is contained in:
René Knaebel 2017-09-12 08:36:23 +02:00
parent 6ce8fb464f
commit 7f49021a63
2 changed files with 50 additions and 69 deletions

116
main.py
View File

@ -156,19 +156,37 @@ def main_train(param=None):
logger.info("class weights: set default")
custom_class_weights = None
if not param:
param = PARAMS
logger.info(f"Generator model with params: {param}")
embedding, model, new_model = models.get_models_by_params(param)
callbacks.append(LambdaCallback(
on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
)
model = create_model(model, args.model_output)
new_model = create_model(new_model, args.model_output)
if args.model_type in ("inter", "staggered"):
server_tr = np.expand_dims(server_windows_tr, 2)
model = new_model
if args.model_output == "both":
labels = {"client": client_tr, "server": server_tr}
loss_weights = {"client": 1.0, "server": 1.0}
elif args.model_output == "client":
labels = {"client": client_tr}
loss_weights = {"client": 1.0}
elif args.model_output == "server":
labels = {"server": server_tr}
loss_weights = {"server": 1.0}
else:
raise ValueError("unknown model output")
logger.info(f"select model: {args.model_type}")
if args.model_type == "staggered":
if not param:
param = PARAMS
logger.info(f"Generator model with params: {param}")
embedding, model, new_model = models.get_models_by_params(param)
model = create_model(new_model, args.model_output)
server_tr = np.expand_dims(server_windows_tr, 2)
logger.info("compile and train model")
embedding.summary()
model.summary()
logger.info("compile and pre-train server model")
logger.info(model.get_config())
model.compile(optimizer='adam',
@ -183,66 +201,30 @@ def main_train(param=None):
shuffle=True,
validation_split=0.2,
class_weight=custom_class_weights)
logger.info("fix server model")
model.get_layer("domain_cnn").trainable = False
model.get_layer("dense_server").trainable = False
model.get_layer("server").trainable = False
model.compile(optimizer='adam',
loss='binary_crossentropy',
loss_weights={"client": 1.0, "server": 0.0},
metrics=['accuracy'] + custom_metrics)
loss_weights = {"client": 1.0, "server": 0.0}
model.summary()
callbacks.append(LambdaCallback(
on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
)
model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
{"client": client_tr, "server": server_tr},
batch_size=args.batch_size,
epochs=args.epochs,
callbacks=callbacks,
shuffle=True,
class_weight=custom_class_weights)
logger.info("compile and train model")
embedding.summary()
logger.info(model.get_config())
model.compile(optimizer='adam',
loss='binary_crossentropy',
loss_weights=loss_weights,
metrics=['accuracy'] + custom_metrics)
else:
if not param:
param = PARAMS
logger.info(f"Generator model with params: {param}")
embedding, model, new_model = models.get_models_by_params(param)
model = create_model(model, args.model_output)
new_model = create_model(new_model, args.model_output)
if args.model_type == "inter":
server_tr = np.expand_dims(server_windows_tr, 2)
model = new_model
logger.info("compile and train model")
embedding.summary()
model.summary()
logger.info(model.get_config())
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'] + custom_metrics)
if args.model_output == "both":
labels = [client_tr, server_tr]
elif args.model_output == "client":
labels = [client_tr]
elif args.model_output == "server":
labels = [server_tr]
else:
raise ValueError("unknown model output")
callbacks.append(LambdaCallback(
on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
)
model.fit([domain_tr, flow_tr],
labels,
batch_size=args.batch_size,
epochs=args.epochs,
callbacks=callbacks,
shuffle=True,
validation_split=0.3,
class_weight=custom_class_weights)
model.summary()
model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
labels,
batch_size=args.batch_size,
epochs=args.epochs,
callbacks=callbacks,
shuffle=True,
validation_split=0.2,
class_weight=custom_class_weights)
def main_test():

View File

@ -3,7 +3,6 @@ from collections import namedtuple
import keras
from keras.engine import Input, Model as KerasModel
from keras.layers import Activation, Conv1D, Dense, Dropout, Embedding, GlobalMaxPooling1D, TimeDistributed
from keras.regularizers import l2
import dataset
@ -58,7 +57,7 @@ def get_model(cnnDropout, flow_features, domain_features, window_size, domain_le
# remove temporal dimension by global max pooling
y = GlobalMaxPooling1D()(y)
y = Dropout(cnnDropout)(y)
y = Dense(dense_dim, kernel_regularizer=l2(0.1), activation='relu')(y)
y = Dense(dense_dim, activation='relu')(y)
out_client = Dense(1, activation='sigmoid', name="client")(y)
out_server = Dense(1, activation='sigmoid', name="server")(y)