train multiple models at once

This commit is contained in:
René Knaebel 2017-11-04 17:58:21 +01:00
parent 88e3eda595
commit 14fef66a55
1 changed files with 108 additions and 107 deletions

215
main.py
View File

@ -80,8 +80,8 @@ PARAMS = {
# TODO: remove inner global params # TODO: remove inner global params
def get_param_dist(size="small"): def get_param_dist(dist_size="small"):
if dist_type == "small": if dist_size == "small":
return { return {
# static params # static params
"type": [args.model_type], "type": [args.model_type],
@ -180,11 +180,7 @@ def train(parameters, features, labels):
pass pass
def main_train(param=None): def load_data(data, domain_length, window_size, model_type):
logger.info(f"Create model path {args.model_path}")
exists_or_make_path(args.model_path)
logger.info(f"Use command line arguments: {args}")
# data preparation # data preparation
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data, domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
args.data, args.data,
@ -193,110 +189,124 @@ def main_train(param=None):
server_tr = np.max(server_windows_tr, axis=1) server_tr = np.max(server_windows_tr, axis=1)
if args.model_type in ("inter", "staggered"): if args.model_type in ("inter", "staggered"):
server_tr = np.expand_dims(server_windows_tr, 2) server_tr = np.expand_dims(server_windows_tr, 2)
return domain_tr, flow_tr, client_tr, server_tr
def main_train(param=None):
logger.info(f"Create model path {args.model_path}")
exists_or_make_path(args.model_path)
logger.info(f"Use command line arguments: {args}")
# data preparation
domain_tr, flow_tr, client_tr, server_tr = load_data(args.data, args.domain_length,
args.window, args.model_type)
# call hyperband if used # call hyperband if used
if args.hyperband_results: if args.hyperband_results:
logger.info("start hyperband parameter search") logger.info("start hyperband parameter search")
hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, 81, args.hyperband_results) hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, 81, args.hyperband_results)
param = sorted(hyper_results, key=operator.itemgetter("loss"))[0] param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]["params"]
logger.info(f"select params from result: {param}") logger.info(f"select params from result: {param}")
# define training call backs
logger.info("define callbacks")
callbacks = []
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
monitor='loss',
verbose=False,
save_best_only=True))
callbacks.append(CSVLogger(args.train_log))
logger.info(f"Use early stopping: {args.stop_early}")
if args.stop_early:
callbacks.append(EarlyStopping(monitor='val_loss',
patience=5,
verbose=False))
custom_metrics = models.get_metric_functions()
# custom class or sample weights
if args.class_weights:
logger.info("class weights: compute custom weights")
custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
logger.info(custom_class_weights)
else:
logger.info("class weights: set default")
custom_class_weights = None
if args.sample_weights:
logger.info("class weights: compute custom weights")
custom_sample_weights = get_custom_sample_weights(client_tr.value, server_tr)
logger.info(custom_class_weights)
else:
logger.info("class weights: set default")
custom_sample_weights = None
if not param: if not param:
param = PARAMS param = PARAMS
logger.info(f"Generator model with params: {param}")
embedding, model, new_model = models.get_models_by_params(param) for i in range(20):
model_path = os.path.join(args.model_path, f"clf_{i}.h5")
model = create_model(model, args.model_output) train_log_path = os.path.join(args.model_path, "train_{i}.log.csv")
new_model = create_model(new_model, args.model_output) # define training call backs
logger.info("define callbacks")
if args.model_type in ("inter", "staggered"): callbacks = []
model = new_model callbacks.append(ModelCheckpoint(filepath=model_path,
monitor='loss',
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value} verbose=False,
if args.model_output == "both": save_best_only=True))
labels = {"client": client_tr.value, "server": server_tr} callbacks.append(CSVLogger(train_log_path))
loss_weights = {"client": 1.0, "server": 1.0} logger.info(f"Use early stopping: {args.stop_early}")
elif args.model_output == "client": if args.stop_early:
labels = {"client": client_tr.value} callbacks.append(EarlyStopping(monitor='val_loss',
loss_weights = {"client": 1.0} patience=5,
elif args.model_output == "server": verbose=False))
labels = {"server": server_tr} custom_metrics = models.get_metric_functions()
loss_weights = {"server": 1.0}
else: # custom class or sample weights
raise ValueError("unknown model output") if args.class_weights:
logger.info("class weights: compute custom weights")
logger.info(f"select model: {args.model_type}") custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
if args.model_type == "staggered": logger.info(custom_class_weights)
logger.info("compile and pre-train server model") else:
logger.info("class weights: set default")
custom_class_weights = None
if args.sample_weights:
logger.info("class weights: compute custom weights")
custom_sample_weights = get_custom_sample_weights(client_tr.value, server_tr)
logger.info(custom_class_weights)
else:
logger.info("class weights: set default")
custom_sample_weights = None
logger.info(f"Generator model with params: {param}")
embedding, model, new_model = models.get_models_by_params(param)
model = create_model(model, args.model_output)
new_model = create_model(new_model, args.model_output)
if args.model_type in ("inter", "staggered"):
model = new_model
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
if args.model_output == "both":
labels = {"client": client_tr.value, "server": server_tr}
loss_weights = {"client": 1.0, "server": 1.0}
elif args.model_output == "client":
labels = {"client": client_tr.value}
loss_weights = {"client": 1.0}
elif args.model_output == "server":
labels = {"server": server_tr}
loss_weights = {"server": 1.0}
else:
raise ValueError("unknown model output")
logger.info(f"select model: {args.model_type}")
if args.model_type == "staggered":
logger.info("compile and pre-train server model")
logger.info(model.get_config())
model.compile(optimizer='adam',
loss='binary_crossentropy',
loss_weights={"client": 0.0, "server": 1.0},
metrics=['accuracy'] + custom_metrics)
model.summary()
model.fit(features, labels,
batch_size=args.batch_size,
epochs=args.epochs,
class_weight=custom_class_weights,
sample_weight=custom_sample_weights)
logger.info("fix server model")
model.get_layer("domain_cnn").trainable = False
model.get_layer("domain_cnn").layer.trainable = False
model.get_layer("dense_server").trainable = False
model.get_layer("server").trainable = False
loss_weights = {"client": 1.0, "server": 0.0}
logger.info("compile and train model")
embedding.summary()
logger.info(model.get_config()) logger.info(model.get_config())
model.compile(optimizer='adam', model.compile(optimizer='adam',
loss='binary_crossentropy', loss='binary_crossentropy',
loss_weights={"client": 0.0, "server": 1.0}, loss_weights=loss_weights,
metrics=['accuracy'] + custom_metrics) metrics=['accuracy'] + custom_metrics)
model.summary() model.summary()
model.fit(features, labels, model.fit(features, labels,
batch_size=args.batch_size, batch_size=args.batch_size,
epochs=args.epochs, epochs=args.epochs,
callbacks=callbacks,
class_weight=custom_class_weights, class_weight=custom_class_weights,
sample_weight=custom_sample_weights) sample_weight=custom_sample_weights)
logger.info("fix server model")
model.get_layer("domain_cnn").trainable = False
model.get_layer("domain_cnn").layer.trainable = False
model.get_layer("dense_server").trainable = False
model.get_layer("server").trainable = False
loss_weights = {"client": 1.0, "server": 0.0}
logger.info("compile and train model")
embedding.summary()
logger.info(model.get_config())
model.compile(optimizer='adam',
loss='binary_crossentropy',
loss_weights=loss_weights,
metrics=['accuracy'] + custom_metrics)
model.summary()
model.fit(features, labels,
batch_size=args.batch_size,
epochs=args.epochs,
callbacks=callbacks,
class_weight=custom_class_weights,
sample_weight=custom_sample_weights)
def main_retrain(): def main_retrain():
source = os.path.join(args.model_source, "clf.h5") source = os.path.join(args.model_source, "clf.h5")
@ -470,15 +480,6 @@ def main_visualization():
normalize=True, title="User Confusion Matrix") normalize=True, title="User Confusion Matrix")
# plot_embedding(args.model_path, results["domain_embds"], args.data, args.domain_length)
# def plot_embedding(model_path, domain_embedding, data, domain_length):
# logger.info("visualize embedding")
# domain_encs, labels = dataset.load_or_generate_domains(data, domain_length)
# visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(model_path), method="svd")
def main_visualize_all(): def main_visualize_all():
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data, _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
args.data, args.data,
@ -623,17 +624,17 @@ def main_beta():
val = server_val.value.max(axis=1) val = server_val.value.max(axis=1)
data["server_pred"] = server.flatten() data["server_pred"] = server.flatten()
data["server_val"] = val.flatten() data["server_val"] = val.flatten()
if res["server_pred"].flatten().shape == server_val.value.flatten().shape: if res["server_pred"].flatten().shape == server_val.value.flatten().shape:
df_server = pd.DataFrame(data={ df_server = pd.DataFrame(data={
"server_pred": res["server_pred"].flatten(), "server_pred": res["server_pred"].flatten(),
"domain": domains, "domain": domains,
"server_val": server_val.value.flatten() "server_val": server_val.value.flatten()
}) })
res = pd.DataFrame(data=data) res = pd.DataFrame(data=data)
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3) res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
return res, df_server return res, df_server
client_preds = [] client_preds = []
@ -706,6 +707,7 @@ def main_beta():
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def plot_overall_result(): def plot_overall_result():
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix)) path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
try: try:
@ -814,9 +816,8 @@ def main_stats2():
print(f"% {vis}") print(f"% {vis}")
print(df.round(4).to_latex()) print(df.round(4).to_latex())
print() print()
def main(): def main():
if "train" == args.mode: if "train" == args.mode:
main_train() main_train()