fix covariance normalization; add run_model script for multi times training

This commit is contained in:
René Knaebel 2017-09-11 12:42:44 +02:00
parent 6121eac448
commit 6d8d7b19f3
2 changed files with 36 additions and 7 deletions

29
run_model.sh Normal file
View File

@ -0,0 +1,29 @@
#!/usr/bin/env bash
N=$1
OUTPUT=$2
DEPTH=$3
TYPE=$4
RESDIR=$5
mkdir -p /tmp/rk/${RESDIR}
DATADIR=$6
EPOCHS=100
for i in {1..$N}
do
python main.py --mode train \
--train ${DATADIR}/currentData.csv \
--model ${RESDIR}/${OUTPUT}_${TYPE}_$i \
--epochs $EPOCHS \
--embd 128 \
--filter_embd 256 --kernel_embd 8 --dense_embd 128 \
--domain_embd 32 \
--filter_main 32 --kernel_main 8 --dense_main 1024 \
--batch 256 \
--balanced_weights \
--model_output ${OUTPUT} \
--type ${TYPE} \
--depth ${DEPTH}
done

View File

@ -104,21 +104,21 @@ def plot_confusion_matrix(y_true, y_pred, path,
"""
plt.clf()
cm = confusion_matrix(y_true, y_pred)
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')
print(cm)
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
thresh = cm.max() / 2.
for i, j in ((i, j) for i in range(cm.shape[0]) for j in range(cm.shape[1])):
plt.text(j, i, cm[i, j],