Authored by quillotm
1 parent 8df8f5b238
Exists in

### By default, k-means Mahalanobis is constrained

Showing 1 changed file with 0 additions and 6 deletions

volia/clustering_modules/kmeans_mahalanobis.py

 1 1 2 2 3 from sklearn.cluster import KMeans 3 from sklearn.cluster import KMeans 4 import pickle 4 import pickle 5 import numpy as np 5 import numpy as np 6 import matplotlib.pyplot as plt 6 import matplotlib.pyplot as plt 7 from sklearn.manifold import TSNE 7 from sklearn.manifold import TSNE 8 from abstract_clustering import AbstractClustering 8 from abstract_clustering import AbstractClustering 9 9 10 class kmeansMahalanobis(): 10 class kmeansMahalanobis(): 11 def __init__(self, constrained: bool = False): 11 def __init__(self, constrained: bool = False): 12 """ 12 """ 13 13 14 """ 14 """ 15 self.C = None 15 self.C = None 16 self.L = None 16 self.L = None 17 self.K = None 17 self.K = None 18 self.constrained = constrained 18 self.constrained = constrained 19 19 20 def predict(self, features): 20 def predict(self, features): 21 """ 21 """ 22 22 23 @param features: 23 @param features: 24 @return: 24 @return: 25 """ 25 """ 26 N = features.shape[0] 26 N = features.shape[0] 27 distances = np.zeros((N, self.K)) 27 distances = np.zeros((N, self.K)) 28 for n in range(N): 28 for n in range(N): 29 for k in range(self.K): 29 for k in range(self.K): 30 distances[n][k] = self._dist(features[n], self.C[k], self.L[k]) 30 distances[n][k] = self._dist(features[n], self.C[k], self.L[k]) 31 closest_cluster = np.argmin(distances, axis=1) 31 closest_cluster = np.argmin(distances, axis=1) 32 return closest_cluster 32 return closest_cluster 33 33 34 def load(self, model_path): 34 def load(self, model_path): 35 """ 35 """ 36 36 37 @param model_path: 37 @param model_path: 38 @return: 38 @return: 39 """ 39 """ 40 data = None 40 data = None 41 with open(model_path, "rb") as f: 41 with open(model_path, "rb") as f: 42 data = pickle.load(f) 42 data = pickle.load(f) 43 if data is None: 43 if data is None: 44 raise Exception("Le modèle n'a pas pu être chargé") 44 raise Exception("Le modèle n'a pas pu être chargé") 45 else: 45 else: 46 self.C = data["C"] 46 self.C = data["C"] 47 self.L = 
data["L"] 47 self.L = data["L"] 48 self.K = data["K"] 48 self.K = data["K"] 49 self.constrained = data["constrained"] 49 self.constrained = data["constrained"] 50 50 51 def save(self, modelpath: str): 51 def save(self, modelpath: str): 52 """ 52 """ 53 53 54 @param modelpath: 54 @param modelpath: 55 @return: 55 @return: 56 """ 56 """ 57 data = { 57 data = { 58 "C": self.C, 58 "C": self.C, 59 "L": self.L, 59 "L": self.L, 60 "K": self.K, 60 "K": self.K, 61 "constrained": self.constrained 61 "constrained": self.constrained 62 } 62 } 63 with open(modelpath, "wb") as f: 63 with open(modelpath, "wb") as f: 64 pickle.dump(data, f) 64 pickle.dump(data, f) 65 65 66 def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False): 66 def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False): 67 results = [] 67 results = [] 68 for i in range(ninit): 68 for i in range(ninit): 69 results.append(self._train(features, k, tol, maxiter, debug)) 69 results.append(self._train(features, k, tol, maxiter, debug)) 70 losses = [v["loss"] for v in results] 70 losses = [v["loss"] for v in results] 71 best = results[losses.index(min(losses))] 71 best = results[losses.index(min(losses))] 72 if debug: 72 if debug: 73 print(f"best: {best['loss']} loss") 73 print(f"best: {best['loss']} loss") 74 self.C = best["C"] 74 self.C = best["C"] 75 self.L = best["L"] 75 self.L = best["L"] 76 self.K = best["K"] 76 self.K = best["K"] 77 77 78 def _initialize_model(self, X, number_clusters): 78 def _initialize_model(self, X, number_clusters): 79 d = X.shape[1] 79 d = X.shape[1] 80 C = X[np.random.choice(X.shape[0], number_clusters)] 80 C = X[np.random.choice(X.shape[0], number_clusters)] 81 L = np.zeros((number_clusters, d, d)) 81 L = np.zeros((number_clusters, d, d)) 82 for k in range(number_clusters): 82 for k in range(number_clusters): 83 L[k] = np.identity(d) 83 L[k] = np.identity(d) 84 return C, L 84 return C, L 85 85 86 def _dist(self, a, 
b, l): 86 def _dist(self, a, b, l): 87 ''' 87 ''' 88 Distance euclidienne with mahalanobis 88 Distance euclidienne with mahalanobis 89 ''' 89 ''' 90 a = np.reshape(a, (-1, 1)) 90 a = np.reshape(a, (-1, 1)) 91 b = np.reshape(b, (-1, 1)) 91 b = np.reshape(b, (-1, 1)) 92 result = np.transpose(a - b).dot(l).dot(a - b)[0][0] 92 result = np.transpose(a - b).dot(l).dot(a - b)[0][0] 93 return result 93 return result 94 94 95 def _plot_iteration(self, iteration, points, clusters, centers): 95 def _plot_iteration(self, iteration, points, clusters, centers): 96 fig = plt.figure() 96 fig = plt.figure() 97 ax = fig.add_subplot(111) 97 ax = fig.add_subplot(111) 98 scatter = ax.scatter(points[:, 0], points[:, 1], c=clusters, s=50) 98 scatter = ax.scatter(points[:, 0], points[:, 1], c=clusters, s=50) 99 100 #for center in centers: 101 # ax.scatter(center[0], center[1], s=50, c='red', marker='+') 102 ax.scatter(centers[:, 0], centers[:, 1], s=50, c='red', marker='+') 99 ax.scatter(centers[:, 0], centers[:, 1], s=50, c='red', marker='+') 103 104 ax.set_xlabel('x') 100 ax.set_xlabel('x') 105 ax.set_ylabel('y') 101 ax.set_ylabel('y') 106 plt.colorbar(scatter) 102 plt.colorbar(scatter) 107 #plt.ylim(0, 1) 108 #plt.xlim(0, 1) 109 plt.savefig("test_" + str(iteration) + ".pdf") 103 plt.savefig("test_" + str(iteration) + ".pdf") 110 104 111 def _train(self, features, K: int, tol: float, maxiter: int, debug: bool=False): 105 def _train(self, features, K: int, tol: float, maxiter: int, debug: bool=False): 112 X = features 106 X = features 113 N = X.shape[0] 107 N = X.shape[0] 114 d = X.shape[1] 108 d = X.shape[1] 115 109 116 C, L = self._initialize_model(X, K) 110 C, L = self._initialize_model(X, K) 117 self.C = C 111 self.C = C 118 self.L = L 112 self.L = L 119 self.K = K 113 self.K = K 120 114 121 end_algo = False 115 end_algo = False 122 i = 0 116 i = 0 123 while not end_algo: 117 while not end_algo: 124 if debug: 118 if debug: 125 print("Iteration: ", i) 119 print("Iteration: ", i) 126 
120 127 # Calcul matrix distance 121 # Calcul matrix distance 128 distances = np.zeros((N, self.K)) 122 distances = np.zeros((N, self.K)) 129 123 130 for n in range(N): 124 for n in range(N): 131 for k in range(self.K): 125 for k in range(self.K): 132 distances[n][k] = self._dist(X[n], self.C[k], self.L[k]) 126 distances[n][k] = self._dist(X[n], self.C[k], self.L[k]) 133 127 134 closest_cluster = np.argmin(distances, axis=1) 128 closest_cluster = np.argmin(distances, axis=1) 135 129 136 loss = np.sum(distances[np.arange(len(distances)), closest_cluster]) 130 loss = np.sum(distances[np.arange(len(distances)), closest_cluster]) 137 if debug: 131 if debug: 138 print(f"loss {loss}") 132 print(f"loss {loss}") 139 133 140 134 141 # -- Debug tool ---------------------- 135 # -- Debug tool ---------------------- 142 if debug and i % 1 == 0: 136 if debug and i % 1 == 0: 143 # TSNE if needed 137 # TSNE if needed 144 X_embedded = np.concatenate((X, self.C), axis=0) 138 X_embedded = np.concatenate((X, self.C), axis=0) 145 if d > 2: 139 if d > 2: 146 X_embedded = TSNE(n_components=2).fit_transform(np.concatenate((X, self.C), axis=0)) 140 X_embedded = TSNE(n_components=2).fit_transform(np.concatenate((X, self.C), axis=0)) 147 141 148 # Then plot 142 # Then plot 149 self._plot_iteration( 143 self._plot_iteration( 150 i, 144 i, 151 X_embedded[:X.shape[0]], 145 X_embedded[:X.shape[0]], 152 closest_cluster, 146 closest_cluster, 153 X_embedded[X.shape[0]:] 147 X_embedded[X.shape[0]:] 154 ) 148 ) 155 # ------------------------------------ 149 # ------------------------------------ 156 150 157 old_c = self.C.copy() 151 old_c = self.C.copy() 158 for k in range(self.K): 152 for k in range(self.K): 159 # Find subset of X with values closed to the centroid c_k. 153 # Find subset of X with values closed to the centroid c_k. 
160 X_sub = np.where(closest_cluster == k) 154 X_sub = np.where(closest_cluster == k) 161 X_sub = np.take(X, X_sub[0], axis=0) 155 X_sub = np.take(X, X_sub[0], axis=0) 162 if X_sub.shape[0] == 0: 156 if X_sub.shape[0] == 0: 163 continue 157 continue 164 158 165 C_new = np.mean(X_sub, axis=0) 159 C_new = np.mean(X_sub, axis=0) 166 160 167 # -- COMPUTE NEW LAMBDA (here named K) -- 161 # -- COMPUTE NEW LAMBDA (here named K) -- 168 K_new = np.zeros((self.L.shape[1], self.L.shape[2])) 162 K_new = np.zeros((self.L.shape[1], self.L.shape[2])) 169 tmp = np.zeros((self.L.shape[1], self.L.shape[2])) 163 tmp = np.zeros((self.L.shape[1], self.L.shape[2])) 170 for x in X_sub: 164 for x in X_sub: 171 x = np.reshape(x, (-1, 1)) 165 x = np.reshape(x, (-1, 1)) 172 c_tmp = np.reshape(C_new, (-1, 1)) 166 c_tmp = np.reshape(C_new, (-1, 1)) 173 #K_new = K_new + (x - c_tmp).dot((x - c_tmp).transpose()) 167 #K_new = K_new + (x - c_tmp).dot((x - c_tmp).transpose()) 174 168 175 tmp = tmp + (x - c_tmp).dot((x - c_tmp).transpose()) 169 tmp = tmp + (x - c_tmp).dot((x - c_tmp).transpose()) 176 if self.constrained: 170 if self.constrained: 177 K_new = (tmp / X_sub.shape[0]) / np.power(np.linalg.det((tmp / X_sub.shape[0])), 1/d) 171 K_new = (tmp / X_sub.shape[0]) / np.power(np.linalg.det((tmp / X_sub.shape[0])), 1/d) 178 else: 172 else: 179 K_new = tmp / X_sub.shape[0] 173 K_new = tmp / X_sub.shape[0] 180 K_new = np.linalg.pinv(K_new) 174 K_new = np.linalg.pinv(K_new) 181 175 182 #if end_algo and (not (self.C[k] == C_new).all()): # If the same stop 176 #if end_algo and (not (self.C[k] == C_new).all()): # If the same stop 183 # end_algo = False 177 # end_algo = False 184 self.C[k] = C_new 178 self.C[k] = C_new 185 self.L[k] = K_new 179 self.L[k] = K_new 186 180 187 181 188 diff = np.sum(np.absolute((self.C - old_c) / old_c * 100)) 182 diff = np.sum(np.absolute((self.C - old_c) / old_c * 100)) 189 if diff > tol: 183 if diff > tol: 190 end_algo = False 184 end_algo = False 191 if debug: 185 if 
debug: 192 print(f"{diff}") 186 print(f"{diff}") 193 else: 187 else: 194 if debug: 188 if debug: 195 print(f"Tolerance threshold {tol} reached with diff {diff}") 189 print(f"Tolerance threshold {tol} reached with diff {diff}") 196 end_algo = True 190 end_algo = True 197 191 198 i = i + 1 192 i = i + 1 199 if i > maxiter: 193 if i > maxiter: 200 end_algo = True 194 end_algo = True 201 if debug: 195 if debug: 202 print(f"Iteration {maxiter} reached") 196 print(f"Iteration {maxiter} reached") 203 return { 197 return { 204 "loss": loss, 198 "loss": loss, 205 "C": self.C, 199 "C": self.C, 206 "K": self.K, 200 "K": self.K, 207 "L": self.L 201 "L": self.L 208 } 202 }