add error bar vis, comment unused parameters from parser
This commit is contained in:
parent
b24fa770f9
commit
345afbaef5
24
arguments.py
24
arguments.py
@ -6,14 +6,14 @@ parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--mode", action="store", dest="mode",
|
||||
default="")
|
||||
|
||||
parser.add_argument("--train", action="store", dest="train_data",
|
||||
# parser.add_argument("--train", action="store", dest="train_data",
|
||||
# default="data/full_dataset.csv.tar.gz")
|
||||
|
||||
parser.add_argument("--data", action="store", dest="data",
|
||||
default="data/full_dataset.csv.tar.gz")
|
||||
|
||||
parser.add_argument("--data", action="store", dest="train_data",
|
||||
default="data/full_dataset.csv.tar.gz")
|
||||
|
||||
parser.add_argument("--test", action="store", dest="test_data",
|
||||
default="data/full_future_dataset.csv.tar.gz")
|
||||
# parser.add_argument("--test", action="store", dest="test_data",
|
||||
# default="data/full_future_dataset.csv.tar.gz")
|
||||
|
||||
parser.add_argument("--hyper_result", action="store", dest="hyperband_results",
|
||||
default="")
|
||||
@ -117,9 +117,9 @@ def get_model_args(args):
|
||||
"embedding_model": os.path.join(model_path, "embd.h5"),
|
||||
"clf_model": os.path.join(model_path, "clf.h5"),
|
||||
"train_log": os.path.join(model_path, "train.log.csv"),
|
||||
"train_h5data": args.train_data,
|
||||
"test_h5data": args.test_data,
|
||||
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred")
|
||||
# "train_h5data": args.train_data,
|
||||
# "test_h5data": args.test_data,
|
||||
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.data)}_pred")
|
||||
} for model_path in args.model_paths]
|
||||
|
||||
|
||||
@ -130,7 +130,7 @@ def parse():
|
||||
args.embedding_model = os.path.join(args.model_path, "embd.h5")
|
||||
args.clf_model = os.path.join(args.model_path, "clf.h5")
|
||||
args.train_log = os.path.join(args.model_path, "train.log.csv")
|
||||
args.train_h5data = args.train_data
|
||||
args.test_h5data = args.test_data
|
||||
args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred")
|
||||
# args.train_h5data = args.train_data
|
||||
# args.test_h5data = args.test_data
|
||||
args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.data)}_pred")
|
||||
return args
|
||||
|
24
fancy.sh
24
fancy.sh
@ -3,22 +3,22 @@
|
||||
RESDIR=$1
|
||||
DATADIR=$2
|
||||
|
||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_final --test ${DATADIR}/futureData.csv --model_output both
|
||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_inter --test ${DATADIR}/futureData.csv --model_output both
|
||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_staggered --test ${DATADIR}/futureData.csv --model_output both
|
||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_final --test ${DATADIR}/futureData.csv --model_output client
|
||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_inter --test ${DATADIR}/futureData.csv --model_output client
|
||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_final --data ${DATADIR}/futureData.csv --model_output both
|
||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_inter --data ${DATADIR}/futureData.csv --model_output both
|
||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_staggered --data ${DATADIR}/futureData.csv --model_output both
|
||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_final --data ${DATADIR}/futureData.csv --model_output client
|
||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_inter --data ${DATADIR}/futureData.csv --model_output client
|
||||
|
||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_final --test ${DATADIR}/futureData.csv --model_output both
|
||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_inter --test ${DATADIR}/futureData.csv --model_output both
|
||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_final --test ${DATADIR}/futureData.csv --model_output client
|
||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_inter --test ${DATADIR}/futureData.csv --model_output client
|
||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_final --data ${DATADIR}/futureData.csv --model_output both
|
||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_inter --data ${DATADIR}/futureData.csv --model_output both
|
||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_final --data ${DATADIR}/futureData.csv --model_output client
|
||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_inter --data ${DATADIR}/futureData.csv --model_output client
|
||||
|
||||
#python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \
|
||||
#python3 main.py --mode all_fancy --batch 256 --data ${DATADIR}/futureData.csv \
|
||||
# --models ${RESDIR}/*_small_*/ --out-prefix ${RESDIR}/small
|
||||
|
||||
#python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \
|
||||
#python3 main.py --mode all_fancy --batch 256 --data ${DATADIR}/futureData.csv \
|
||||
# --models ${RESDIR}/*_medium_*/ --out-prefix ${RESDIR}/medium
|
||||
|
||||
python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \
|
||||
python3 main.py --mode all_fancy --batch 256 --data ${DATADIR}/futureData.csv \
|
||||
--models ${RESDIR}/*/ --out-prefix ${RESDIR}/all
|
||||
|
8
main.py
8
main.py
@ -589,6 +589,12 @@ def plot_overall_result():
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{path}/{vis}_all.png")
|
||||
|
||||
for cat, models in results.items():
|
||||
visualize.plot_clf()
|
||||
visualize.plot_error_bars(models)
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{path}/error_bars_{cat}.png")
|
||||
|
||||
|
||||
def main():
|
||||
if "train" == args.mode:
|
||||
@ -605,6 +611,8 @@ def main():
|
||||
main_visualize_all()
|
||||
if "beta" == args.mode:
|
||||
main_beta()
|
||||
if "all_beta" == args.mode:
|
||||
plot_overall_result()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
13
visualize.py
13
visualize.py
@ -2,6 +2,7 @@ import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy import interpolate
|
||||
from sklearn.decomposition import TruncatedSVD
|
||||
from sklearn.manifold import TSNE
|
||||
@ -211,6 +212,18 @@ def plot_training_curve(logs, key, path, dpi=600):
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_error_bars(results):
|
||||
rates = []
|
||||
for m, r in results.items():
|
||||
if m == "all": continue
|
||||
rates.append((r / r.sum(axis=0, keepdims=True)).flatten())
|
||||
rates = pd.DataFrame(np.vstack(rates), columns=("TN", "FP", "FN", "TP"))
|
||||
|
||||
ax = rates.mean().plot.bar(yerr=rates.std())
|
||||
for p in ax.patches:
|
||||
ax.annotate(str(np.round(p.get_height(), 4)), (p.get_x(), 0.5))
|
||||
|
||||
|
||||
def plot_embedding(domain_embedding, labels, path, dpi=600, method="svd"):
|
||||
if method == "svd":
|
||||
red = TruncatedSVD(n_components=2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user