fix validation split removal by loading h5data into memory
This commit is contained in:
parent
ec5a1101be
commit
f2845e635e
16
main.py
16
main.py
@ -136,7 +136,6 @@ def main_train(param=None):
|
|||||||
callbacks = []
|
callbacks = []
|
||||||
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
|
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
|
||||||
monitor='loss',
|
monitor='loss',
|
||||||
# monitor='val_loss',
|
|
||||||
verbose=False,
|
verbose=False,
|
||||||
save_best_only=True))
|
save_best_only=True))
|
||||||
callbacks.append(CSVLogger(args.train_log))
|
callbacks.append(CSVLogger(args.train_log))
|
||||||
@ -173,11 +172,12 @@ def main_train(param=None):
|
|||||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||||
model = new_model
|
model = new_model
|
||||||
|
|
||||||
|
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
|
||||||
if args.model_output == "both":
|
if args.model_output == "both":
|
||||||
labels = {"client": client_tr, "server": server_tr}
|
labels = {"client": client_tr.value, "server": server_tr}
|
||||||
loss_weights = {"client": 1.0, "server": 1.0}
|
loss_weights = {"client": 1.0, "server": 1.0}
|
||||||
elif args.model_output == "client":
|
elif args.model_output == "client":
|
||||||
labels = {"client": client_tr}
|
labels = {"client": client_tr.value}
|
||||||
loss_weights = {"client": 1.0}
|
loss_weights = {"client": 1.0}
|
||||||
elif args.model_output == "server":
|
elif args.model_output == "server":
|
||||||
labels = {"server": server_tr}
|
labels = {"server": server_tr}
|
||||||
@ -195,12 +195,9 @@ def main_train(param=None):
|
|||||||
loss_weights={"client": 0.0, "server": 1.0},
|
loss_weights={"client": 0.0, "server": 1.0},
|
||||||
metrics=['accuracy'] + custom_metrics)
|
metrics=['accuracy'] + custom_metrics)
|
||||||
|
|
||||||
model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
|
model.fit(features, labels,
|
||||||
{"client": client_tr, "server": server_tr},
|
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
shuffle=True,
|
|
||||||
# validation_split=0.2,
|
|
||||||
class_weight=custom_class_weights)
|
class_weight=custom_class_weights)
|
||||||
|
|
||||||
logger.info("fix server model")
|
logger.info("fix server model")
|
||||||
@ -218,13 +215,10 @@ def main_train(param=None):
|
|||||||
metrics=['accuracy'] + custom_metrics)
|
metrics=['accuracy'] + custom_metrics)
|
||||||
|
|
||||||
model.summary()
|
model.summary()
|
||||||
model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
|
model.fit(features, labels,
|
||||||
labels,
|
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
shuffle=True,
|
|
||||||
# validation_split=0.2,
|
|
||||||
class_weight=custom_class_weights)
|
class_weight=custom_class_weights)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user