From 7f1d13658fad00acc40bac2c64746da04935de68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Thu, 3 Aug 2017 09:08:24 +0200 Subject: [PATCH] store domain embeddings while test main --- Makefile | 10 +++++----- dataset.py | 6 ++++-- main.py | 14 +++++++++----- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/Makefile b/Makefile index cf52d24..cdf029f 100644 --- a/Makefile +++ b/Makefile @@ -1,16 +1,16 @@ run: - python3 main.py --modes train --train data/rk_mini.csv.gz --model results/test --epochs 10 \ + python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test --epochs 10 \ --hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights run_new: - python3 main.py --modes train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 \ + python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 \ --hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights --new_model test: - python3 main.py --modes test --batch 128 --model results/test --test data/rk_mini.csv.gz + python3 main.py --mode test --batch 128 --model results/test --test data/rk_mini.csv.gz fancy: - python3 main.py --modes fancy --batch 128 --model results/test --test data/rk_mini.csv.gz + python3 main.py --mode fancy --batch 128 --model results/test --test data/rk_mini.csv.gz hyper: - python3 main.py --modes hyperband --batch 64 --train data/rk_data.csv.gz + python3 main.py --mode hyperband --batch 64 --train data/rk_data.csv.gz diff --git a/dataset.py b/dataset.py index ccebf93..2fe5fd9 100644 --- a/dataset.py +++ b/dataset.py @@ -248,13 +248,15 @@ def load_or_generate_domains(train_data, domain_length): return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix() -def save_predictions(path, c_pred, s_pred): +def save_predictions(path, c_pred, s_pred, embd, labels): f = h5py.File(path, "w") f.create_dataset("client", data=c_pred) f.create_dataset("server", data=s_pred) + f.create_dataset("embedding", data=embd) + f.create_dataset("labels", data=labels) f.close() def load_predictions(path): f = h5py.File(path, "r") - return f["client"], f["server"] + return f["client"], f["server"], f["embedding"], f["labels"] diff --git a/main.py b/main.py index d280d9e..7374dc8 100644 --- a/main.py +++ b/main.py @@ -194,7 +194,13 @@ def main_test(): else: c_pred = np.zeros(0) s_pred = pred - dataset.save_predictions(args.future_prediction, c_pred, s_pred) + + model = load_model(args.embedding_model) + domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length) + domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1) + + dataset.save_predictions(args.future_prediction, c_pred, s_pred, domain_embedding, labels) + def main_visualization(): @@ -212,7 +218,7 @@ def main_visualization(): except Exception as e: logger.warning(f"could not generate training curves: {e}") - client_pred, server_pred = dataset.load_predictions(args.future_prediction) + client_pred, server_pred, domain_embedding, labels = dataset.load_predictions(args.future_prediction) client_pred, server_pred = client_pred.value, server_pred.value logger.info("plot pr curve") visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path)) @@ -229,9 +235,7 @@ def main_visualization(): # "{}/server_cov.png".format(args.model_path), # normalize=False, title="Server Confusion Matrix") logger.info("visualize embedding") - model = load_model(args.embedding_model) - domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length) - domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1) + visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))