diff --git a/main.py b/main.py index e8cf011..0a8bb50 100644 --- a/main.py +++ b/main.py @@ -80,8 +80,8 @@ PARAMS = { # TODO: remove inner global params -def get_param_dist(size="small"): - if dist_type == "small": +def get_param_dist(dist_size="small"): + if dist_size == "small": return { # static params "type": [args.model_type], @@ -180,11 +180,7 @@ def train(parameters, features, labels): pass -def main_train(param=None): - logger.info(f"Create model path {args.model_path}") - exists_or_make_path(args.model_path) - logger.info(f"Use command line arguments: {args}") - +def load_data(data, domain_length, window_size, model_type): # data preparation domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data, args.data, @@ -193,110 +189,124 @@ def main_train(param=None): server_tr = np.max(server_windows_tr, axis=1) if args.model_type in ("inter", "staggered"): server_tr = np.expand_dims(server_windows_tr, 2) + return domain_tr, flow_tr, client_tr, server_tr + +def main_train(param=None): + logger.info(f"Create model path {args.model_path}") + exists_or_make_path(args.model_path) + logger.info(f"Use command line arguments: {args}") + + # data preparation + domain_tr, flow_tr, client_tr, server_tr = load_data(args.data, args.domain_length, + args.window, args.model_type) + # call hyperband if used if args.hyperband_results: logger.info("start hyperband parameter search") hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, 81, args.hyperband_results) - param = sorted(hyper_results, key=operator.itemgetter("loss"))[0] + param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]["params"] logger.info(f"select params from result: {param}") - - # define training call backs - logger.info("define callbacks") - callbacks = [] - callbacks.append(ModelCheckpoint(filepath=args.clf_model, - monitor='loss', - verbose=False, - save_best_only=True)) - callbacks.append(CSVLogger(args.train_log)) - logger.info(f"Use early stopping: {args.stop_early}") - if args.stop_early: - callbacks.append(EarlyStopping(monitor='val_loss', - patience=5, - verbose=False)) - custom_metrics = models.get_metric_functions() - - # custom class or sample weights - if args.class_weights: - logger.info("class weights: compute custom weights") - custom_class_weights = get_custom_class_weights(client_tr.value, server_tr) - logger.info(custom_class_weights) - else: - logger.info("class weights: set default") - custom_class_weights = None - - if args.sample_weights: - logger.info("class weights: compute custom weights") - custom_sample_weights = get_custom_sample_weights(client_tr.value, server_tr) - logger.info(custom_class_weights) - else: - logger.info("class weights: set default") - custom_sample_weights = None - if not param: param = PARAMS - logger.info(f"Generator model with params: {param}") - embedding, model, new_model = models.get_models_by_params(param) - - model = create_model(model, args.model_output) - new_model = create_model(new_model, args.model_output) - - if args.model_type in ("inter", "staggered"): - model = new_model - - features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value} - if args.model_output == "both": - labels = {"client": client_tr.value, "server": server_tr} - loss_weights = {"client": 1.0, "server": 1.0} - elif args.model_output == "client": - labels = {"client": client_tr.value} - loss_weights = {"client": 1.0} - elif args.model_output == "server": - labels = {"server": server_tr} - loss_weights = {"server": 1.0} - else: - raise ValueError("unknown model output") - - logger.info(f"select model: {args.model_type}") - if args.model_type == "staggered": - logger.info("compile and pre-train server model") + + for i in range(20): + model_path = os.path.join(args.model_path, f"clf_{i}.h5") + train_log_path = os.path.join(args.model_path, "train_{i}.log.csv") + # define training call backs + logger.info("define callbacks") + callbacks = [] + callbacks.append(ModelCheckpoint(filepath=model_path, + monitor='loss', + verbose=False, + save_best_only=True)) + callbacks.append(CSVLogger(train_log_path)) + logger.info(f"Use early stopping: {args.stop_early}") + if args.stop_early: + callbacks.append(EarlyStopping(monitor='val_loss', + patience=5, + verbose=False)) + custom_metrics = models.get_metric_functions() + + # custom class or sample weights + if args.class_weights: + logger.info("class weights: compute custom weights") + custom_class_weights = get_custom_class_weights(client_tr.value, server_tr) + logger.info(custom_class_weights) + else: + logger.info("class weights: set default") + custom_class_weights = None + + if args.sample_weights: + logger.info("class weights: compute custom weights") + custom_sample_weights = get_custom_sample_weights(client_tr.value, server_tr) + logger.info(custom_class_weights) + else: + logger.info("class weights: set default") + custom_sample_weights = None + + logger.info(f"Generator model with params: {param}") + embedding, model, new_model = models.get_models_by_params(param) + + model = create_model(model, args.model_output) + new_model = create_model(new_model, args.model_output) + + if args.model_type in ("inter", "staggered"): + model = new_model + + features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value} + if args.model_output == "both": + labels = {"client": client_tr.value, "server": server_tr} + loss_weights = {"client": 1.0, "server": 1.0} + elif args.model_output == "client": + labels = {"client": client_tr.value} + loss_weights = {"client": 1.0} + elif args.model_output == "server": + labels = {"server": server_tr} + loss_weights = {"server": 1.0} + else: + raise ValueError("unknown model output") + + logger.info(f"select model: {args.model_type}") + if args.model_type == "staggered": + logger.info("compile and pre-train server model") + logger.info(model.get_config()) + + model.compile(optimizer='adam', + loss='binary_crossentropy', + loss_weights={"client": 0.0, "server": 1.0}, + metrics=['accuracy'] + custom_metrics) + + model.summary() + model.fit(features, labels, + batch_size=args.batch_size, + epochs=args.epochs, + class_weight=custom_class_weights, + sample_weight=custom_sample_weights) + + logger.info("fix server model") + model.get_layer("domain_cnn").trainable = False + model.get_layer("domain_cnn").layer.trainable = False + model.get_layer("dense_server").trainable = False + model.get_layer("server").trainable = False + loss_weights = {"client": 1.0, "server": 0.0} + + logger.info("compile and train model") + embedding.summary() logger.info(model.get_config()) - model.compile(optimizer='adam', loss='binary_crossentropy', - loss_weights={"client": 0.0, "server": 1.0}, + loss_weights=loss_weights, metrics=['accuracy'] + custom_metrics) - + model.summary() model.fit(features, labels, batch_size=args.batch_size, epochs=args.epochs, + callbacks=callbacks, class_weight=custom_class_weights, sample_weight=custom_sample_weights) - logger.info("fix server model") - model.get_layer("domain_cnn").trainable = False - model.get_layer("domain_cnn").layer.trainable = False - model.get_layer("dense_server").trainable = False - model.get_layer("server").trainable = False - loss_weights = {"client": 1.0, "server": 0.0} - - logger.info("compile and train model") - embedding.summary() - logger.info(model.get_config()) - model.compile(optimizer='adam', - loss='binary_crossentropy', - loss_weights=loss_weights, - metrics=['accuracy'] + custom_metrics) - - model.summary() - model.fit(features, labels, - batch_size=args.batch_size, - epochs=args.epochs, - callbacks=callbacks, - class_weight=custom_class_weights, - sample_weight=custom_sample_weights) - def main_retrain(): source = os.path.join(args.model_source, "clf.h5") @@ -470,15 +480,6 @@ def main_visualization(): normalize=True, title="User Confusion Matrix") -# plot_embedding(args.model_path, results["domain_embds"], args.data, args.domain_length) - - -# def plot_embedding(model_path, domain_embedding, data, domain_length): -# logger.info("visualize embedding") -# domain_encs, labels = dataset.load_or_generate_domains(data, domain_length) -# visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(model_path), method="svd") - - def main_visualize_all(): _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data, args.data, @@ -623,17 +624,17 @@ def main_beta(): val = server_val.value.max(axis=1) data["server_pred"] = server.flatten() data["server_val"] = val.flatten() - + if res["server_pred"].flatten().shape == server_val.value.flatten().shape: df_server = pd.DataFrame(data={ "server_pred": res["server_pred"].flatten(), "domain": domains, "server_val": server_val.value.flatten() }) - + res = pd.DataFrame(data=data) res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3) - + return res, df_server client_preds = [] @@ -706,6 +707,7 @@ def main_beta(): import matplotlib.pyplot as plt + def plot_overall_result(): path, model_prefix = os.path.split(os.path.normpath(args.output_prefix)) try: @@ -814,9 +816,8 @@ def main_stats2(): print(f"% {vis}") print(df.round(4).to_latex()) print() - - - + + def main(): if "train" == args.mode: main_train()