From d4507c2683cd0fcec7ee302ad1a4e430b61cd6cd Mon Sep 17 00:00:00 2001 From: quillotm Date: Sat, 14 Aug 2021 12:42:28 +0200 Subject: [PATCH] We do not update a cluster if it is not associated to any data point. --- volia/clustering_modules/kmeans_mahalanobis.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/volia/clustering_modules/kmeans_mahalanobis.py b/volia/clustering_modules/kmeans_mahalanobis.py index 0779cf8..6848cdd 100644 --- a/volia/clustering_modules/kmeans_mahalanobis.py +++ b/volia/clustering_modules/kmeans_mahalanobis.py @@ -27,7 +27,6 @@ class kmeansMahalanobis(): for n in range(N): for k in range(self.K): distances[n][k] = self._dist(features[n], self.C[k], self.L[k]) - print(distances) closest_cluster = np.argmin(distances, axis=1) return closest_cluster @@ -119,7 +118,6 @@ class kmeansMahalanobis(): for n in range(N): for k in range(self.K): distances[n][k] = self._dist(X[n], self.C[k], self.L[k]) - print(distances) closest_cluster = np.argmin(distances, axis=1) if i % 1 == 0: # -- Debug tool ---------------------- @@ -140,6 +138,8 @@ class kmeansMahalanobis(): # Find subset of X with values closed to the centroid c_k. X_sub = np.where(closest_cluster == k) X_sub = np.take(X, X_sub[0], axis=0) + if X_sub.shape[0] == 0: + continue np.mean(X_sub, axis=0) C_new = np.mean(X_sub, axis=0) @@ -150,7 +150,7 @@ class kmeansMahalanobis(): c_tmp = np.reshape(C_new, (-1, 1)) K_new = K_new + (x - c_tmp).dot((x - c_tmp).transpose()) K_new = K_new / X_sub.shape[0] - K_new = np.linalg.inv(K_new) + K_new = np.linalg.pinv(K_new) if end_algo and (not (self.C[k] == C_new).all()): # If the same stop end_algo = False -- 1.8.2.3