fix validation split removal by loading h5data into memory

This commit is contained in:
René Knaebel 2017-09-17 09:56:18 +02:00
parent ec5a1101be
commit f2845e635e

16
main.py
View File

@ -136,7 +136,6 @@ def main_train(param=None):
callbacks = [] callbacks = []
callbacks.append(ModelCheckpoint(filepath=args.clf_model, callbacks.append(ModelCheckpoint(filepath=args.clf_model,
monitor='loss', monitor='loss',
# monitor='val_loss',
verbose=False, verbose=False,
save_best_only=True)) save_best_only=True))
callbacks.append(CSVLogger(args.train_log)) callbacks.append(CSVLogger(args.train_log))
@ -173,11 +172,12 @@ def main_train(param=None):
server_tr = np.expand_dims(server_windows_tr, 2) server_tr = np.expand_dims(server_windows_tr, 2)
model = new_model model = new_model
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
if args.model_output == "both": if args.model_output == "both":
labels = {"client": client_tr, "server": server_tr} labels = {"client": client_tr.value, "server": server_tr}
loss_weights = {"client": 1.0, "server": 1.0} loss_weights = {"client": 1.0, "server": 1.0}
elif args.model_output == "client": elif args.model_output == "client":
labels = {"client": client_tr} labels = {"client": client_tr.value}
loss_weights = {"client": 1.0} loss_weights = {"client": 1.0}
elif args.model_output == "server": elif args.model_output == "server":
labels = {"server": server_tr} labels = {"server": server_tr}
@ -195,12 +195,9 @@ def main_train(param=None):
loss_weights={"client": 0.0, "server": 1.0}, loss_weights={"client": 0.0, "server": 1.0},
metrics=['accuracy'] + custom_metrics) metrics=['accuracy'] + custom_metrics)
model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr}, model.fit(features, labels,
{"client": client_tr, "server": server_tr},
batch_size=args.batch_size, batch_size=args.batch_size,
epochs=args.epochs, epochs=args.epochs,
shuffle=True,
# validation_split=0.2,
class_weight=custom_class_weights) class_weight=custom_class_weights)
logger.info("fix server model") logger.info("fix server model")
@ -218,13 +215,10 @@ def main_train(param=None):
metrics=['accuracy'] + custom_metrics) metrics=['accuracy'] + custom_metrics)
model.summary() model.summary()
model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr}, model.fit(features, labels,
labels,
batch_size=args.batch_size, batch_size=args.batch_size,
epochs=args.epochs, epochs=args.epochs,
callbacks=callbacks, callbacks=callbacks,
shuffle=True,
# validation_split=0.2,
class_weight=custom_class_weights) class_weight=custom_class_weights)