refactor cmd argument to have single value for mode
This commit is contained in:
parent
e24f596f40
commit
f4da147688
7
Makefile
7
Makefile
@ -1,5 +1,10 @@
|
||||
run:
|
||||
python3 main.py --modes train --batch 128 --model results/test --train data/rk_mini.csv.gz --epochs 10
|
||||
python3 main.py --modes train --train data/rk_mini.csv.gz --model results/test --epochs 10 \
|
||||
--hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights
|
||||
|
||||
run_new:
|
||||
python3 main.py --modes train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 \
|
||||
--hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights --new_model
|
||||
|
||||
test:
|
||||
python3 main.py --modes test --batch 128 --model results/test --test data/rk_mini.csv.gz
|
||||
|
10
arguments.py
10
arguments.py
@ -3,15 +3,19 @@ import os
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--modes", action="store", dest="modes", nargs="+",
|
||||
default=[])
|
||||
parser.add_argument("--mode", action="store", dest="mode",
|
||||
default="")
|
||||
|
||||
parser.add_argument("--train", action="store", dest="train_data",
|
||||
default="data/full_dataset.csv.tar.gz")
|
||||
|
||||
parser.add_argument("--data", action="store", dest="train_data",
|
||||
default="data/full_dataset.csv.tar.gz")
|
||||
|
||||
parser.add_argument("--test", action="store", dest="test_data",
|
||||
default="data/full_future_dataset.csv.tar.gz")
|
||||
|
||||
|
||||
parser.add_argument("--model", action="store", dest="model_path",
|
||||
default="results/model_x")
|
||||
|
||||
@ -74,5 +78,5 @@ def parse():
|
||||
args.train_log = os.path.join(args.model_path, "train.log.csv")
|
||||
args.train_h5data = args.train_data + ".h5"
|
||||
args.test_h5data = args.test_data + ".h5"
|
||||
args.future_prediction = os.path.join(args.model_path, "future_predict.npy")
|
||||
args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred.h5")
|
||||
return args
|
||||
|
12
dataset.py
12
dataset.py
@ -246,3 +246,15 @@ def load_or_generate_domains(train_data, domain_length):
|
||||
user_flow_df.groupby(user_flow_df.domain).mean()
|
||||
|
||||
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix()
|
||||
|
||||
|
||||
def save_predictions(path, c_pred, s_pred):
|
||||
f = h5py.File(path, "w")
|
||||
f.create_dataset("client", data=c_pred)
|
||||
f.create_dataset("server", data=s_pred)
|
||||
f.close()
|
||||
|
||||
|
||||
def load_predictions(path):
|
||||
f = h5py.File(path, "r")
|
||||
return f["client"], f["server"]
|
||||
|
80
main.py
80
main.py
@ -175,41 +175,16 @@ def main_test():
|
||||
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
||||
args.domain_length, args.window)
|
||||
clf = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||
y_pred = clf.predict([domain_val, flow_val],
|
||||
c_pred, s_pred = clf.predict([domain_val, flow_val],
|
||||
batch_size=args.batch_size,
|
||||
verbose=1)
|
||||
np.save(args.future_prediction, y_pred)
|
||||
|
||||
# char_dict = dataset.get_character_dict()
|
||||
# user_flow_df = dataset.get_user_flow_data(args.test_data)
|
||||
# domains = user_flow_df.domain.unique()[:-1]
|
||||
#
|
||||
# def get_domain_features_reduced(d):
|
||||
# return dataset.get_domain_features(d[0], char_dict, args.domain_length)
|
||||
#
|
||||
# domain_features = []
|
||||
# for ds in domains:
|
||||
# domain_features.append(np.apply_along_axis(get_domain_features_reduced, 2, np.atleast_3d(ds)))
|
||||
#
|
||||
# model = load_model(args.embedding_model)
|
||||
# domain_features = np.stack(domain_features).reshape((-1, 40))
|
||||
# pred = model.predict(domain_features, batch_size=args.batch_size, verbose=1)
|
||||
#
|
||||
# np.save("/tmp/rk/domains.npy", domains)
|
||||
# np.save("/tmp/rk/domain_features.npy", domain_features)
|
||||
# np.save("/tmp/rk/domain_embd.npy", pred)
|
||||
|
||||
|
||||
def main_embedding():
|
||||
model = load_model(args.embedding_model)
|
||||
domain_encs, labels = dataset.load_or_generate_domains(args.train_data, args.domain_length)
|
||||
domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||
visualize.plot_embedding(domain_embedding, labels, path="results/pp3/embd.png")
|
||||
dataset.save_predictions(args.future_prediction, c_pred, s_pred)
|
||||
|
||||
|
||||
def main_visualization():
|
||||
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
||||
args.domain_length, args.window)
|
||||
client_val, server_val = client_val.value, server_val.value
|
||||
logger.info("plot model")
|
||||
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||
visualize.plot_model(model, os.path.join(args.model_path, "model.png"))
|
||||
@ -221,28 +196,27 @@ def main_visualization():
|
||||
except Exception as e:
|
||||
logger.warning(f"could not generate training curves: {e}")
|
||||
|
||||
client_pred, server_pred = np.load(args.future_prediction)
|
||||
client_pred, server_pred = dataset.load_predictions(args.future_prediction)
|
||||
client_pred, server_pred = client_pred.value, server_pred.value
|
||||
logger.info("plot pr curve")
|
||||
visualize.plot_precision_recall(client_val.value, client_pred, "{}/client_prc.png".format(args.model_path))
|
||||
visualize.plot_precision_recall(server_val.value, server_pred, "{}/server_prc.png".format(args.model_path))
|
||||
visualize.plot_precision_recall_curves(client_val.value, client_pred, "{}/client_prc2.png".format(args.model_path))
|
||||
visualize.plot_precision_recall_curves(server_val.value, server_pred, "{}/server_prc2.png".format(args.model_path))
|
||||
visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path))
|
||||
# visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path))
|
||||
# visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path))
|
||||
# visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path))
|
||||
logger.info("plot roc curve")
|
||||
visualize.plot_roc_curve(client_val.value, client_pred, "{}/client_roc.png".format(args.model_path))
|
||||
visualize.plot_roc_curve(server_val.value, server_pred, "{}/server_roc.png".format(args.model_path))
|
||||
visualize.plot_confusion_matrix(client_val.value.argmax(1), client_pred.argmax(1),
|
||||
visualize.plot_roc_curve(client_val, client_pred.flatten(), "{}/client_roc.png".format(args.model_path))
|
||||
# visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path))
|
||||
visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(),
|
||||
"{}/client_cov.png".format(args.model_path),
|
||||
normalize=False, title="Client Confusion Matrix")
|
||||
visualize.plot_confusion_matrix(server_val.value.argmax(1), server_pred.argmax(1),
|
||||
"{}/server_cov.png".format(args.model_path),
|
||||
normalize=False, title="Server Confusion Matrix")
|
||||
|
||||
|
||||
def main_score():
|
||||
# mask = dataset.load_mask_eval(args.data, args.test_image)
|
||||
# pred = np.load(args.pred)
|
||||
# visualize.score_model(mask, pred)
|
||||
pass
|
||||
# visualize.plot_confusion_matrix(server_val.argmax(1), server_pred.argmax(1),
|
||||
# "{}/server_cov.png".format(args.model_path),
|
||||
# normalize=False, title="Server Confusion Matrix")
|
||||
logger.info("visualize embedding")
|
||||
model = load_model(args.embedding_model)
|
||||
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||
domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||
visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
|
||||
|
||||
|
||||
def main_data():
|
||||
@ -259,19 +233,17 @@ def main_data():
|
||||
|
||||
|
||||
def main():
|
||||
if "train" in args.modes:
|
||||
if "train" == args.mode:
|
||||
main_train()
|
||||
if "hyperband" in args.modes:
|
||||
if "hyperband" == args.mode:
|
||||
main_hyperband()
|
||||
if "test" in args.modes:
|
||||
if "test" == args.mode:
|
||||
main_test()
|
||||
if "fancy" in args.modes:
|
||||
if "fancy" == args.mode:
|
||||
main_visualization()
|
||||
if "score" in args.modes:
|
||||
main_score()
|
||||
if "paul" in args.modes:
|
||||
if "paul" == args.mode:
|
||||
main_paul_best()
|
||||
if "data" in args.modes:
|
||||
if "data" == args.mode:
|
||||
main_data()
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user