fix covariance normalization; add run_model script for multi times training
This commit is contained in:
		
							
								
								
									
										29
									
								
								run_model.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								run_model.sh
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
			
		||||
#!/usr/bin/env bash
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
N=$1
 | 
			
		||||
OUTPUT=$2
 | 
			
		||||
DEPTH=$3
 | 
			
		||||
TYPE=$4
 | 
			
		||||
RESDIR=$5
 | 
			
		||||
mkdir -p /tmp/rk/${RESDIR}
 | 
			
		||||
DATADIR=$6
 | 
			
		||||
 | 
			
		||||
EPOCHS=100
 | 
			
		||||
 | 
			
		||||
for i in {1..$N}
 | 
			
		||||
do
 | 
			
		||||
    python main.py --mode train \
 | 
			
		||||
                --train ${DATADIR}/currentData.csv \
 | 
			
		||||
                --model ${RESDIR}/${OUTPUT}_${TYPE}_$i \
 | 
			
		||||
                --epochs $EPOCHS \
 | 
			
		||||
                --embd 128 \
 | 
			
		||||
                --filter_embd 256 --kernel_embd 8 --dense_embd 128 \
 | 
			
		||||
                --domain_embd 32 \
 | 
			
		||||
                --filter_main 32  --kernel_main 8 --dense_main 1024 \
 | 
			
		||||
                --batch 256 \
 | 
			
		||||
                --balanced_weights \
 | 
			
		||||
                --model_output ${OUTPUT} \
 | 
			
		||||
                --type ${TYPE} \
 | 
			
		||||
                --depth ${DEPTH}
 | 
			
		||||
done
 | 
			
		||||
							
								
								
									
										14
									
								
								visualize.py
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								visualize.py
									
									
									
									
									
								
							@@ -104,21 +104,21 @@ def plot_confusion_matrix(y_true, y_pred, path,
 | 
			
		||||
    """
 | 
			
		||||
    plt.clf()
 | 
			
		||||
    cm = confusion_matrix(y_true, y_pred)
 | 
			
		||||
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
 | 
			
		||||
    plt.title(title)
 | 
			
		||||
    plt.colorbar()
 | 
			
		||||
    tick_marks = np.arange(len(classes))
 | 
			
		||||
    plt.xticks(tick_marks, classes, rotation=45)
 | 
			
		||||
    plt.yticks(tick_marks, classes)
 | 
			
		||||
 | 
			
		||||
    if normalize:
 | 
			
		||||
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
 | 
			
		||||
        print("Normalized confusion matrix")
 | 
			
		||||
    else:
 | 
			
		||||
        print('Confusion matrix, without normalization')
 | 
			
		||||
 | 
			
		||||
    print(cm)
 | 
			
		||||
 | 
			
		||||
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
 | 
			
		||||
    plt.title(title)
 | 
			
		||||
    plt.colorbar()
 | 
			
		||||
    tick_marks = np.arange(len(classes))
 | 
			
		||||
    plt.xticks(tick_marks, classes, rotation=45)
 | 
			
		||||
    plt.yticks(tick_marks, classes)
 | 
			
		||||
 | 
			
		||||
    thresh = cm.max() / 2.
 | 
			
		||||
    for i, j in ((i, j) for i in range(cm.shape[0]) for j in range(cm.shape[1])):
 | 
			
		||||
        plt.text(j, i, cm[i, j],
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user