diff --git a/main.py b/main.py index 5851314..ce27efc 100644 --- a/main.py +++ b/main.py @@ -136,7 +136,6 @@ def main_train(param=None): callbacks = [] callbacks.append(ModelCheckpoint(filepath=args.clf_model, monitor='loss', - # monitor='val_loss', verbose=False, save_best_only=True)) callbacks.append(CSVLogger(args.train_log)) @@ -173,11 +172,12 @@ def main_train(param=None): server_tr = np.expand_dims(server_windows_tr, 2) model = new_model + features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value} if args.model_output == "both": - labels = {"client": client_tr, "server": server_tr} + labels = {"client": client_tr.value, "server": server_tr} loss_weights = {"client": 1.0, "server": 1.0} elif args.model_output == "client": - labels = {"client": client_tr} + labels = {"client": client_tr.value} loss_weights = {"client": 1.0} elif args.model_output == "server": labels = {"server": server_tr} @@ -195,12 +195,9 @@ def main_train(param=None): loss_weights={"client": 0.0, "server": 1.0}, metrics=['accuracy'] + custom_metrics) - model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr}, - {"client": client_tr, "server": server_tr}, + model.fit(features, labels, batch_size=args.batch_size, epochs=args.epochs, - shuffle=True, - # validation_split=0.2, class_weight=custom_class_weights) logger.info("fix server model") @@ -218,13 +215,10 @@ def main_train(param=None): metrics=['accuracy'] + custom_metrics) model.summary() - model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr}, - labels, + model.fit(features, labels, batch_size=args.batch_size, epochs=args.epochs, callbacks=callbacks, - shuffle=True, - # validation_split=0.2, class_weight=custom_class_weights)