fix staggered training

This commit is contained in:
René Knaebel 2017-10-06 10:38:00 +02:00
parent 508667d1d0
commit 5741f8ee0e
2 changed files with 36 additions and 1 deletions

31
main.py

@ -201,6 +201,7 @@ def main_train(param=None):
loss_weights={"client": 0.0, "server": 1.0},
metrics=['accuracy'] + custom_metrics)
model.summary()
model.fit(features, labels,
batch_size=args.batch_size,
epochs=args.epochs,
@ -208,6 +209,7 @@ def main_train(param=None):
logger.info("fix server model")
model.get_layer("domain_cnn").trainable = False
model.get_layer("domain_cnn").layer.trainable = False
model.get_layer("dense_server").trainable = False
model.get_layer("server").trainable = False
loss_weights = {"client": 1.0, "server": 0.0}
@ -649,6 +651,33 @@ def train_server_only():
callbacks=callbacks)
def test_server_only():
logger.info("start test: load data")
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data,
args.data,
args.domain_length,
args.window)
domain_val = domain_val.value.reshape(-1, 40)
flow_val = flow_val.value.reshape(-1, 3)
domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
for model_args in get_model_args(args):
results = {}
logger.info(f"process model {model_args['model_path']}")
embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_custom_objects())
pred = clf_model.predict([domain_val, flow_val],
batch_size=args.batch_size,
verbose=1)
results["server_pred"] = pred
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
results["domain_embds"] = domain_embeddings
dataset.save_predictions(model_args["model_path"], results)
def main():
if "train" == args.mode:
main_train()
@ -668,6 +697,8 @@ def main():
plot_overall_result()
if "server" == args.mode:
train_server_only()
if "server_test" == args.mode:
test_server_only()
if __name__ == "__main__":

@ -37,6 +37,10 @@ def load_model(path, custom_objects=None):
except Exception:
# in some version i forgot to specify domain_cnn
# this bug fix is for certain compatibility
embd = clf.layers[1].layer
try:
embd = clf.layers[1].layer
except Exception:
embd = clf.get_layer("domain_cnn")
return embd, clf