refactor cmd argument to have single value for mode

This commit is contained in:
René Knaebel 2017-07-30 15:49:37 +02:00
parent e24f596f40
commit f4da147688
4 changed files with 55 additions and 62 deletions

View File

@ -1,11 +1,16 @@
run: run:
python3 main.py --modes train --batch 128 --model results/test --train data/rk_mini.csv.gz --epochs 10 python3 main.py --modes train --train data/rk_mini.csv.gz --model results/test --epochs 10 \
--hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights
run_new:
python3 main.py --modes train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 \
--hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights --new_model
test: test:
python3 main.py --modes test --batch 128 --model results/test --test data/rk_mini.csv.gz python3 main.py --modes test --batch 128 --model results/test --test data/rk_mini.csv.gz
fancy: fancy:
python3 main.py --modes fancy --batch 128 --model results/test --test data/rk_mini.csv.gz python3 main.py --modes fancy --batch 128 --model results/test --test data/rk_mini.csv.gz
hyper: hyper:
python3 main.py --modes hyperband --batch 64 --train data/rk_data.csv.gz python3 main.py --modes hyperband --batch 64 --train data/rk_data.csv.gz

View File

@ -3,15 +3,19 @@ import os
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--modes", action="store", dest="modes", nargs="+", parser.add_argument("--mode", action="store", dest="mode",
default=[]) default="")
parser.add_argument("--train", action="store", dest="train_data", parser.add_argument("--train", action="store", dest="train_data",
default="data/full_dataset.csv.tar.gz") default="data/full_dataset.csv.tar.gz")
parser.add_argument("--data", action="store", dest="train_data",
default="data/full_dataset.csv.tar.gz")
parser.add_argument("--test", action="store", dest="test_data", parser.add_argument("--test", action="store", dest="test_data",
default="data/full_future_dataset.csv.tar.gz") default="data/full_future_dataset.csv.tar.gz")
parser.add_argument("--model", action="store", dest="model_path", parser.add_argument("--model", action="store", dest="model_path",
default="results/model_x") default="results/model_x")
@ -74,5 +78,5 @@ def parse():
args.train_log = os.path.join(args.model_path, "train.log.csv") args.train_log = os.path.join(args.model_path, "train.log.csv")
args.train_h5data = args.train_data + ".h5" args.train_h5data = args.train_data + ".h5"
args.test_h5data = args.test_data + ".h5" args.test_h5data = args.test_data + ".h5"
args.future_prediction = os.path.join(args.model_path, "future_predict.npy") args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred.h5")
return args return args

View File

@ -246,3 +246,15 @@ def load_or_generate_domains(train_data, domain_length):
user_flow_df.groupby(user_flow_df.domain).mean() user_flow_df.groupby(user_flow_df.domain).mean()
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix() return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix()
def save_predictions(path, c_pred, s_pred):
f = h5py.File(path, "w")
f.create_dataset("client", data=c_pred)
f.create_dataset("server", data=s_pred)
f.close()
def load_predictions(path):
f = h5py.File(path, "r")
return f["client"], f["server"]

84
main.py
View File

@ -175,41 +175,16 @@ def main_test():
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data, domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
args.domain_length, args.window) args.domain_length, args.window)
clf = load_model(args.clf_model, custom_objects=models.get_metrics()) clf = load_model(args.clf_model, custom_objects=models.get_metrics())
y_pred = clf.predict([domain_val, flow_val], c_pred, s_pred = clf.predict([domain_val, flow_val],
batch_size=args.batch_size, batch_size=args.batch_size,
verbose=1) verbose=1)
np.save(args.future_prediction, y_pred) dataset.save_predictions(args.future_prediction, c_pred, s_pred)
# char_dict = dataset.get_character_dict()
# user_flow_df = dataset.get_user_flow_data(args.test_data)
# domains = user_flow_df.domain.unique()[:-1]
#
# def get_domain_features_reduced(d):
# return dataset.get_domain_features(d[0], char_dict, args.domain_length)
#
# domain_features = []
# for ds in domains:
# domain_features.append(np.apply_along_axis(get_domain_features_reduced, 2, np.atleast_3d(ds)))
#
# model = load_model(args.embedding_model)
# domain_features = np.stack(domain_features).reshape((-1, 40))
# pred = model.predict(domain_features, batch_size=args.batch_size, verbose=1)
#
# np.save("/tmp/rk/domains.npy", domains)
# np.save("/tmp/rk/domain_features.npy", domain_features)
# np.save("/tmp/rk/domain_embd.npy", pred)
def main_embedding():
model = load_model(args.embedding_model)
domain_encs, labels = dataset.load_or_generate_domains(args.train_data, args.domain_length)
domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
visualize.plot_embedding(domain_embedding, labels, path="results/pp3/embd.png")
def main_visualization(): def main_visualization():
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data, domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
args.domain_length, args.window) args.domain_length, args.window)
client_val, server_val = client_val.value, server_val.value
logger.info("plot model") logger.info("plot model")
model = load_model(args.clf_model, custom_objects=models.get_metrics()) model = load_model(args.clf_model, custom_objects=models.get_metrics())
visualize.plot_model(model, os.path.join(args.model_path, "model.png")) visualize.plot_model(model, os.path.join(args.model_path, "model.png"))
@ -221,28 +196,27 @@ def main_visualization():
except Exception as e: except Exception as e:
logger.warning(f"could not generate training curves: {e}") logger.warning(f"could not generate training curves: {e}")
client_pred, server_pred = np.load(args.future_prediction) client_pred, server_pred = dataset.load_predictions(args.future_prediction)
client_pred, server_pred = client_pred.value, server_pred.value
logger.info("plot pr curve") logger.info("plot pr curve")
visualize.plot_precision_recall(client_val.value, client_pred, "{}/client_prc.png".format(args.model_path)) visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path))
visualize.plot_precision_recall(server_val.value, server_pred, "{}/server_prc.png".format(args.model_path)) # visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path))
visualize.plot_precision_recall_curves(client_val.value, client_pred, "{}/client_prc2.png".format(args.model_path)) # visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path))
visualize.plot_precision_recall_curves(server_val.value, server_pred, "{}/server_prc2.png".format(args.model_path)) # visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path))
logger.info("plot roc curve") logger.info("plot roc curve")
visualize.plot_roc_curve(client_val.value, client_pred, "{}/client_roc.png".format(args.model_path)) visualize.plot_roc_curve(client_val, client_pred.flatten(), "{}/client_roc.png".format(args.model_path))
visualize.plot_roc_curve(server_val.value, server_pred, "{}/server_roc.png".format(args.model_path)) # visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path))
visualize.plot_confusion_matrix(client_val.value.argmax(1), client_pred.argmax(1), visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(),
"{}/client_cov.png".format(args.model_path), "{}/client_cov.png".format(args.model_path),
normalize=False, title="Client Confusion Matrix") normalize=False, title="Client Confusion Matrix")
visualize.plot_confusion_matrix(server_val.value.argmax(1), server_pred.argmax(1), # visualize.plot_confusion_matrix(server_val.argmax(1), server_pred.argmax(1),
"{}/server_cov.png".format(args.model_path), # "{}/server_cov.png".format(args.model_path),
normalize=False, title="Server Confusion Matrix") # normalize=False, title="Server Confusion Matrix")
logger.info("visualize embedding")
model = load_model(args.embedding_model)
def main_score(): domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
# mask = dataset.load_mask_eval(args.data, args.test_image) domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
# pred = np.load(args.pred) visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
# visualize.score_model(mask, pred)
pass
def main_data(): def main_data():
@ -259,19 +233,17 @@ def main_data():
def main(): def main():
if "train" in args.modes: if "train" == args.mode:
main_train() main_train()
if "hyperband" in args.modes: if "hyperband" == args.mode:
main_hyperband() main_hyperband()
if "test" in args.modes: if "test" == args.mode:
main_test() main_test()
if "fancy" in args.modes: if "fancy" == args.mode:
main_visualization() main_visualization()
if "score" in args.modes: if "paul" == args.mode:
main_score()
if "paul" in args.modes:
main_paul_best() main_paul_best()
if "data" in args.modes: if "data" == args.mode:
main_data() main_data()