fix covariance normalization; add run_model script for multi times training
This commit is contained in:
		
							
								
								
									
										29
									
								
								run_model.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								run_model.sh
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/env bash
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					N=$1
 | 
				
			||||||
 | 
					OUTPUT=$2
 | 
				
			||||||
 | 
					DEPTH=$3
 | 
				
			||||||
 | 
					TYPE=$4
 | 
				
			||||||
 | 
					RESDIR=$5
 | 
				
			||||||
 | 
					mkdir -p /tmp/rk/${RESDIR}
 | 
				
			||||||
 | 
					DATADIR=$6
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					EPOCHS=100
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					for i in {1..$N}
 | 
				
			||||||
 | 
					do
 | 
				
			||||||
 | 
					    python main.py --mode train \
 | 
				
			||||||
 | 
					                --train ${DATADIR}/currentData.csv \
 | 
				
			||||||
 | 
					                --model ${RESDIR}/${OUTPUT}_${TYPE}_$i \
 | 
				
			||||||
 | 
					                --epochs $EPOCHS \
 | 
				
			||||||
 | 
					                --embd 128 \
 | 
				
			||||||
 | 
					                --filter_embd 256 --kernel_embd 8 --dense_embd 128 \
 | 
				
			||||||
 | 
					                --domain_embd 32 \
 | 
				
			||||||
 | 
					                --filter_main 32  --kernel_main 8 --dense_main 1024 \
 | 
				
			||||||
 | 
					                --batch 256 \
 | 
				
			||||||
 | 
					                --balanced_weights \
 | 
				
			||||||
 | 
					                --model_output ${OUTPUT} \
 | 
				
			||||||
 | 
					                --type ${TYPE} \
 | 
				
			||||||
 | 
					                --depth ${DEPTH}
 | 
				
			||||||
 | 
					done
 | 
				
			||||||
							
								
								
									
										14
									
								
								visualize.py
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								visualize.py
									
									
									
									
									
								
							@@ -104,21 +104,21 @@ def plot_confusion_matrix(y_true, y_pred, path,
 | 
				
			|||||||
    """
 | 
					    """
 | 
				
			||||||
    plt.clf()
 | 
					    plt.clf()
 | 
				
			||||||
    cm = confusion_matrix(y_true, y_pred)
 | 
					    cm = confusion_matrix(y_true, y_pred)
 | 
				
			||||||
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
 | 
					 | 
				
			||||||
    plt.title(title)
 | 
					 | 
				
			||||||
    plt.colorbar()
 | 
					 | 
				
			||||||
    tick_marks = np.arange(len(classes))
 | 
					 | 
				
			||||||
    plt.xticks(tick_marks, classes, rotation=45)
 | 
					 | 
				
			||||||
    plt.yticks(tick_marks, classes)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if normalize:
 | 
					    if normalize:
 | 
				
			||||||
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
 | 
					        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
 | 
				
			||||||
        print("Normalized confusion matrix")
 | 
					        print("Normalized confusion matrix")
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        print('Confusion matrix, without normalization')
 | 
					        print('Confusion matrix, without normalization')
 | 
				
			||||||
 | 
					 | 
				
			||||||
    print(cm)
 | 
					    print(cm)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    plt.imshow(cm, interpolation='nearest', cmap=cmap)
 | 
				
			||||||
 | 
					    plt.title(title)
 | 
				
			||||||
 | 
					    plt.colorbar()
 | 
				
			||||||
 | 
					    tick_marks = np.arange(len(classes))
 | 
				
			||||||
 | 
					    plt.xticks(tick_marks, classes, rotation=45)
 | 
				
			||||||
 | 
					    plt.yticks(tick_marks, classes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    thresh = cm.max() / 2.
 | 
					    thresh = cm.max() / 2.
 | 
				
			||||||
    for i, j in ((i, j) for i in range(cm.shape[0]) for j in range(cm.shape[1])):
 | 
					    for i, j in ((i, j) for i in range(cm.shape[0]) for j in range(cm.shape[1])):
 | 
				
			||||||
        plt.text(j, i, cm[i, j],
 | 
					        plt.text(j, i, cm[i, j],
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user