diff --git a/visualize.py b/visualize.py index 90db73e..812f171 100644 --- a/visualize.py +++ b/visualize.py @@ -3,7 +3,7 @@ import os import matplotlib.pyplot as plt import numpy as np from keras.utils import plot_model -from sklearn.decomposition import PCA +from sklearn.decomposition import TruncatedSVD from sklearn.metrics import ( auc, classification_report, confusion_matrix, fbeta_score, precision_recall_curve, roc_auc_score, roc_curve @@ -146,13 +146,11 @@ def plot_training_curve(logs, key, path, dpi=600): def plot_embedding(domain_embedding, labels, path, dpi=600): - pca = PCA(n_components=2) - domain_reduced = pca.fit_transform(domain_embedding) - print(pca.explained_variance_ratio_) - + svd = TruncatedSVD(n_components=2) + domain_reduced = svd.fit_transform(domain_embedding) + print(svd.explained_variance_ratio_) # use if draw subset of predictions # idx = np.random.choice(np.arange(len(domain_reduced)), 10000) - plt.scatter(domain_reduced[:, 0], domain_reduced[:, 1], c=(labels * (1, 2)).sum(1).astype(int),