store domain embeddings while test main

This commit is contained in:
René Knaebel 2017-08-03 09:08:24 +02:00
parent 452f9e0456
commit 7f1d13658f
3 changed files with 18 additions and 12 deletions

View File

@ -1,16 +1,16 @@
run:
python3 main.py --modes train --train data/rk_mini.csv.gz --model results/test --epochs 10 \
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test --epochs 10 \
--hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights
run_new:
python3 main.py --modes train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 \
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 \
--hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights --new_model
test:
python3 main.py --modes test --batch 128 --model results/test --test data/rk_mini.csv.gz
python3 main.py --mode test --batch 128 --model results/test --test data/rk_mini.csv.gz
fancy:
python3 main.py --modes fancy --batch 128 --model results/test --test data/rk_mini.csv.gz
python3 main.py --mode fancy --batch 128 --model results/test --test data/rk_mini.csv.gz
hyper:
python3 main.py --modes hyperband --batch 64 --train data/rk_data.csv.gz
python3 main.py --mode hyperband --batch 64 --train data/rk_data.csv.gz

View File

@ -248,13 +248,15 @@ def load_or_generate_domains(train_data, domain_length):
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix()
def save_predictions(path, c_pred, s_pred):
def save_predictions(path, c_pred, s_pred, embd, labels):
f = h5py.File(path, "w")
f.create_dataset("client", data=c_pred)
f.create_dataset("server", data=s_pred)
f.create_dataset("embedding", data=embd)
f.create_dataset("labels", data=labels)
f.close()
def load_predictions(path):
f = h5py.File(path, "r")
return f["client"], f["server"]
return f["client"], f["server"], f["embedding"], f["labels"]

14
main.py
View File

@ -194,7 +194,13 @@ def main_test():
else:
c_pred = np.zeros(0)
s_pred = pred
dataset.save_predictions(args.future_prediction, c_pred, s_pred)
model = load_model(args.embedding_model)
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
dataset.save_predictions(args.future_prediction, c_pred, s_pred, domain_embedding, labels)
def main_visualization():
@ -212,7 +218,7 @@ def main_visualization():
except Exception as e:
logger.warning(f"could not generate training curves: {e}")
client_pred, server_pred = dataset.load_predictions(args.future_prediction)
client_pred, server_pred, domain_embedding, labels = dataset.load_predictions(args.future_prediction)
client_pred, server_pred = client_pred.value, server_pred.value
logger.info("plot pr curve")
visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path))
@ -229,9 +235,7 @@ def main_visualization():
# "{}/server_cov.png".format(args.model_path),
# normalize=False, title="Server Confusion Matrix")
logger.info("visualize embedding")
model = load_model(args.embedding_model)
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))