fix staggered training
This commit is contained in:
parent
508667d1d0
commit
5741f8ee0e
31
main.py
31
main.py
@ -201,6 +201,7 @@ def main_train(param=None):
|
|||||||
loss_weights={"client": 0.0, "server": 1.0},
|
loss_weights={"client": 0.0, "server": 1.0},
|
||||||
metrics=['accuracy'] + custom_metrics)
|
metrics=['accuracy'] + custom_metrics)
|
||||||
|
|
||||||
|
model.summary()
|
||||||
model.fit(features, labels,
|
model.fit(features, labels,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
@ -208,6 +209,7 @@ def main_train(param=None):
|
|||||||
|
|
||||||
logger.info("fix server model")
|
logger.info("fix server model")
|
||||||
model.get_layer("domain_cnn").trainable = False
|
model.get_layer("domain_cnn").trainable = False
|
||||||
|
model.get_layer("domain_cnn").layer.trainable = False
|
||||||
model.get_layer("dense_server").trainable = False
|
model.get_layer("dense_server").trainable = False
|
||||||
model.get_layer("server").trainable = False
|
model.get_layer("server").trainable = False
|
||||||
loss_weights = {"client": 1.0, "server": 0.0}
|
loss_weights = {"client": 1.0, "server": 0.0}
|
||||||
@ -649,6 +651,33 @@ def train_server_only():
|
|||||||
callbacks=callbacks)
|
callbacks=callbacks)
|
||||||
|
|
||||||
|
|
||||||
|
def test_server_only():
|
||||||
|
logger.info("start test: load data")
|
||||||
|
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data,
|
||||||
|
args.data,
|
||||||
|
args.domain_length,
|
||||||
|
args.window)
|
||||||
|
domain_val = domain_val.value.reshape(-1, 40)
|
||||||
|
flow_val = flow_val.value.reshape(-1, 3)
|
||||||
|
domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
||||||
|
|
||||||
|
for model_args in get_model_args(args):
|
||||||
|
results = {}
|
||||||
|
logger.info(f"process model {model_args['model_path']}")
|
||||||
|
embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_custom_objects())
|
||||||
|
|
||||||
|
pred = clf_model.predict([domain_val, flow_val],
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
verbose=1)
|
||||||
|
|
||||||
|
results["server_pred"] = pred
|
||||||
|
|
||||||
|
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||||
|
results["domain_embds"] = domain_embeddings
|
||||||
|
|
||||||
|
dataset.save_predictions(model_args["model_path"], results)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if "train" == args.mode:
|
if "train" == args.mode:
|
||||||
main_train()
|
main_train()
|
||||||
@ -668,6 +697,8 @@ def main():
|
|||||||
plot_overall_result()
|
plot_overall_result()
|
||||||
if "server" == args.mode:
|
if "server" == args.mode:
|
||||||
train_server_only()
|
train_server_only()
|
||||||
|
if "server_test" == args.mode:
|
||||||
|
test_server_only()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
4
utils.py
4
utils.py
@ -37,6 +37,10 @@ def load_model(path, custom_objects=None):
|
|||||||
except Exception:
|
except Exception:
|
||||||
# in some version i forgot to specify domain_cnn
|
# in some version i forgot to specify domain_cnn
|
||||||
# this bug fix is for certain compatibility
|
# this bug fix is for certain compatibility
|
||||||
|
try:
|
||||||
embd = clf.layers[1].layer
|
embd = clf.layers[1].layer
|
||||||
|
except Exception:
|
||||||
|
embd = clf.get_layer("domain_cnn")
|
||||||
|
|
||||||
|
|
||||||
return embd, clf
|
return embd, clf
|
||||||
|
Loading…
Reference in New Issue
Block a user