diff --git a/run_model.sh b/run_model.sh new file mode 100644 index 0000000..0eb5477 --- /dev/null +++ b/run_model.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + + +N=$1 +OUTPUT=$2 +DEPTH=$3 +TYPE=$4 +RESDIR=$5 +mkdir -p /tmp/rk/${RESDIR} +DATADIR=$6 + +EPOCHS=100 + +for i in {1..$N} +do + python main.py --mode train \ + --train ${DATADIR}/currentData.csv \ + --model ${RESDIR}/${OUTPUT}_${TYPE}_$i \ + --epochs $EPOCHS \ + --embd 128 \ + --filter_embd 256 --kernel_embd 8 --dense_embd 128 \ + --domain_embd 32 \ + --filter_main 32 --kernel_main 8 --dense_main 1024 \ + --batch 256 \ + --balanced_weights \ + --model_output ${OUTPUT} \ + --type ${TYPE} \ + --depth ${DEPTH} +done \ No newline at end of file diff --git a/visualize.py b/visualize.py index 28f7a63..bd330e6 100644 --- a/visualize.py +++ b/visualize.py @@ -104,21 +104,21 @@ def plot_confusion_matrix(y_true, y_pred, path, """ plt.clf() cm = confusion_matrix(y_true, y_pred) - plt.imshow(cm, interpolation='nearest', cmap=cmap) - plt.title(title) - plt.colorbar() - tick_marks = np.arange(len(classes)) - plt.xticks(tick_marks, classes, rotation=45) - plt.yticks(tick_marks, classes) if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] print("Normalized confusion matrix") else: print('Confusion matrix, without normalization') - print(cm) + plt.imshow(cm, interpolation='nearest', cmap=cmap) + plt.title(title) + plt.colorbar() + tick_marks = np.arange(len(classes)) + plt.xticks(tick_marks, classes, rotation=45) + plt.yticks(tick_marks, classes) + thresh = cm.max() / 2. for i, j in ((i, j) for i in range(cm.shape[0]) for j in range(cm.shape[1])): plt.text(j, i, cm[i, j],