fix staggered training
This commit is contained in:
parent
508667d1d0
commit
5741f8ee0e
31
main.py
31
main.py
@ -201,6 +201,7 @@ def main_train(param=None):
|
||||
loss_weights={"client": 0.0, "server": 1.0},
|
||||
metrics=['accuracy'] + custom_metrics)
|
||||
|
||||
model.summary()
|
||||
model.fit(features, labels,
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
@ -208,6 +209,7 @@ def main_train(param=None):
|
||||
|
||||
logger.info("fix server model")
|
||||
model.get_layer("domain_cnn").trainable = False
|
||||
model.get_layer("domain_cnn").layer.trainable = False
|
||||
model.get_layer("dense_server").trainable = False
|
||||
model.get_layer("server").trainable = False
|
||||
loss_weights = {"client": 1.0, "server": 0.0}
|
||||
@ -649,6 +651,33 @@ def train_server_only():
|
||||
callbacks=callbacks)
|
||||
|
||||
|
||||
def test_server_only():
|
||||
logger.info("start test: load data")
|
||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data,
|
||||
args.data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
domain_val = domain_val.value.reshape(-1, 40)
|
||||
flow_val = flow_val.value.reshape(-1, 3)
|
||||
domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
||||
|
||||
for model_args in get_model_args(args):
|
||||
results = {}
|
||||
logger.info(f"process model {model_args['model_path']}")
|
||||
embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_custom_objects())
|
||||
|
||||
pred = clf_model.predict([domain_val, flow_val],
|
||||
batch_size=args.batch_size,
|
||||
verbose=1)
|
||||
|
||||
results["server_pred"] = pred
|
||||
|
||||
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||
results["domain_embds"] = domain_embeddings
|
||||
|
||||
dataset.save_predictions(model_args["model_path"], results)
|
||||
|
||||
|
||||
def main():
|
||||
if "train" == args.mode:
|
||||
main_train()
|
||||
@ -668,6 +697,8 @@ def main():
|
||||
plot_overall_result()
|
||||
if "server" == args.mode:
|
||||
train_server_only()
|
||||
if "server_test" == args.mode:
|
||||
test_server_only()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user