fix covariance normalization; add run_model script for multi times training
This commit is contained in:
parent
6121eac448
commit
6d8d7b19f3
29
run_model.sh
Normal file
29
run_model.sh
Normal file
@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
|
||||
N=$1
|
||||
OUTPUT=$2
|
||||
DEPTH=$3
|
||||
TYPE=$4
|
||||
RESDIR=$5
|
||||
mkdir -p /tmp/rk/${RESDIR}
|
||||
DATADIR=$6
|
||||
|
||||
EPOCHS=100
|
||||
|
||||
for i in {1..$N}
|
||||
do
|
||||
python main.py --mode train \
|
||||
--train ${DATADIR}/currentData.csv \
|
||||
--model ${RESDIR}/${OUTPUT}_${TYPE}_$i \
|
||||
--epochs $EPOCHS \
|
||||
--embd 128 \
|
||||
--filter_embd 256 --kernel_embd 8 --dense_embd 128 \
|
||||
--domain_embd 32 \
|
||||
--filter_main 32 --kernel_main 8 --dense_main 1024 \
|
||||
--batch 256 \
|
||||
--balanced_weights \
|
||||
--model_output ${OUTPUT} \
|
||||
--type ${TYPE} \
|
||||
--depth ${DEPTH}
|
||||
done
|
14
visualize.py
14
visualize.py
@ -104,21 +104,21 @@ def plot_confusion_matrix(y_true, y_pred, path,
|
||||
"""
|
||||
plt.clf()
|
||||
cm = confusion_matrix(y_true, y_pred)
|
||||
plt.imshow(cm, interpolation='nearest', cmap=cmap)
|
||||
plt.title(title)
|
||||
plt.colorbar()
|
||||
tick_marks = np.arange(len(classes))
|
||||
plt.xticks(tick_marks, classes, rotation=45)
|
||||
plt.yticks(tick_marks, classes)
|
||||
|
||||
if normalize:
|
||||
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
||||
print("Normalized confusion matrix")
|
||||
else:
|
||||
print('Confusion matrix, without normalization')
|
||||
|
||||
print(cm)
|
||||
|
||||
plt.imshow(cm, interpolation='nearest', cmap=cmap)
|
||||
plt.title(title)
|
||||
plt.colorbar()
|
||||
tick_marks = np.arange(len(classes))
|
||||
plt.xticks(tick_marks, classes, rotation=45)
|
||||
plt.yticks(tick_marks, classes)
|
||||
|
||||
thresh = cm.max() / 2.
|
||||
for i, j in ((i, j) for i in range(cm.shape[0]) for j in range(cm.shape[1])):
|
||||
plt.text(j, i, cm[i, j],
|
||||
|
Loading…
Reference in New Issue
Block a user