diff --git a/arguments.py b/arguments.py index cf01292..c336487 100644 --- a/arguments.py +++ b/arguments.py @@ -6,14 +6,14 @@ parser = argparse.ArgumentParser() parser.add_argument("--mode", action="store", dest="mode", default="") -parser.add_argument("--train", action="store", dest="train_data", +# parser.add_argument("--train", action="store", dest="train_data", +# default="data/full_dataset.csv.tar.gz") + +parser.add_argument("--data", action="store", dest="data", default="data/full_dataset.csv.tar.gz") -parser.add_argument("--data", action="store", dest="train_data", - default="data/full_dataset.csv.tar.gz") - -parser.add_argument("--test", action="store", dest="test_data", - default="data/full_future_dataset.csv.tar.gz") +# parser.add_argument("--test", action="store", dest="test_data", +# default="data/full_future_dataset.csv.tar.gz") parser.add_argument("--hyper_result", action="store", dest="hyperband_results", default="") @@ -117,9 +117,9 @@ def get_model_args(args): "embedding_model": os.path.join(model_path, "embd.h5"), "clf_model": os.path.join(model_path, "clf.h5"), "train_log": os.path.join(model_path, "train.log.csv"), - "train_h5data": args.train_data, - "test_h5data": args.test_data, - "future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred") + # "train_h5data": args.train_data, + # "test_h5data": args.test_data, + "future_prediction": os.path.join(model_path, f"{os.path.basename(args.data)}_pred") } for model_path in args.model_paths] @@ -130,7 +130,7 @@ def parse(): args.embedding_model = os.path.join(args.model_path, "embd.h5") args.clf_model = os.path.join(args.model_path, "clf.h5") args.train_log = os.path.join(args.model_path, "train.log.csv") - args.train_h5data = args.train_data - args.test_h5data = args.test_data - args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred") + # args.train_h5data = args.train_data + # args.test_h5data = args.test_data + args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.data)}_pred") return args diff --git a/fancy.sh b/fancy.sh index af8665f..d40d7cb 100644 --- a/fancy.sh +++ b/fancy.sh @@ -3,22 +3,22 @@ RESDIR=$1 DATADIR=$2 -python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_final --test ${DATADIR}/futureData.csv --model_output both -python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_inter --test ${DATADIR}/futureData.csv --model_output both -python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_staggered --test ${DATADIR}/futureData.csv --model_output both -python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_final --test ${DATADIR}/futureData.csv --model_output client -#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_inter --test ${DATADIR}/futureData.csv --model_output client +python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_final --data ${DATADIR}/futureData.csv --model_output both +python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_inter --data ${DATADIR}/futureData.csv --model_output both +#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_staggered --data ${DATADIR}/futureData.csv --model_output both +python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_final --data ${DATADIR}/futureData.csv --model_output client +#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_inter --data ${DATADIR}/futureData.csv --model_output client -#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_final --test ${DATADIR}/futureData.csv --model_output both -#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_inter --test ${DATADIR}/futureData.csv --model_output both -#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_final --test ${DATADIR}/futureData.csv --model_output client -#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_inter --test ${DATADIR}/futureData.csv --model_output client +#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_final --data ${DATADIR}/futureData.csv --model_output both +#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_inter --data ${DATADIR}/futureData.csv --model_output both +#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_final --data ${DATADIR}/futureData.csv --model_output client +#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_inter --data ${DATADIR}/futureData.csv --model_output client -#python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \ +#python3 main.py --mode all_fancy --batch 256 --data ${DATADIR}/futureData.csv \ # --models ${RESDIR}/*_small_*/ --out-prefix ${RESDIR}/small -#python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \ +#python3 main.py --mode all_fancy --batch 256 --data ${DATADIR}/futureData.csv \ # --models ${RESDIR}/*_medium_*/ --out-prefix ${RESDIR}/medium -python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \ +python3 main.py --mode all_fancy --batch 256 --data ${DATADIR}/futureData.csv \ --models ${RESDIR}/*/ --out-prefix ${RESDIR}/all diff --git a/main.py b/main.py index 06d038c..12d1136 100644 --- a/main.py +++ b/main.py @@ -589,6 +589,12 @@ def plot_overall_result(): visualize.plot_legend() visualize.plot_save(f"{path}/{vis}_all.png") + for cat, models in results.items(): + visualize.plot_clf() + visualize.plot_error_bars(models) + visualize.plot_legend() + visualize.plot_save(f"{path}/error_bars_{cat}.png") + def main(): if "train" == args.mode: @@ -605,6 +611,8 @@ def main(): main_visualize_all() if "beta" == args.mode: main_beta() + if "all_beta" == args.mode: + plot_overall_result() if __name__ == "__main__": diff --git a/visualize.py b/visualize.py index 39ad3a1..3e35c05 100644 --- a/visualize.py +++ b/visualize.py @@ -2,6 +2,7 @@ import os import matplotlib.pyplot as plt import numpy as np +import pandas as pd from scipy import interpolate from sklearn.decomposition import TruncatedSVD from sklearn.manifold import TSNE @@ -211,6 +212,18 @@ def plot_training_curve(logs, key, path, dpi=600): plt.close() +def plot_error_bars(results): + rates = [] + for m, r in results.items(): + if m == "all": continue + rates.append((r / r.sum(axis=0, keepdims=True)).flatten()) + rates = pd.DataFrame(np.vstack(rates), columns=("TN", "FP", "FN", "TP")) + + ax = rates.mean().plot.bar(yerr=rates.std()) + for p in ax.patches: + ax.annotate(str(np.round(p.get_height(), 4)), (p.get_x(), 0.5)) + + def plot_embedding(domain_embedding, labels, path, dpi=600, method="svd"): if method == "svd": red = TruncatedSVD(n_components=2)