Commit 55bcf758f37969188cf737d1a6ff23fe803ebddd

Authored by quillotm
1 parent 8df8f5b238
Exists in master

By default, kmeans mahalanobis is with constraints

Showing 1 changed file with 0 additions and 6 deletions Inline Diff

volia/clustering_modules/kmeans_mahalanobis.py
1 1
2 2
3 from sklearn.cluster import KMeans 3 from sklearn.cluster import KMeans
4 import pickle 4 import pickle
5 import numpy as np 5 import numpy as np
6 import matplotlib.pyplot as plt 6 import matplotlib.pyplot as plt
7 from sklearn.manifold import TSNE 7 from sklearn.manifold import TSNE
8 from abstract_clustering import AbstractClustering 8 from abstract_clustering import AbstractClustering
9 9
class kmeansMahalanobis():
    """K-means clustering in which every cluster has its own Mahalanobis metric.

    Attributes (populated by ``fit`` or ``load``):
        C: array of shape (K, d) holding the centroids.
        L: array of shape (K, d, d) holding one metric matrix per cluster,
            i.e. the pseudo-inverse of that cluster's covariance.
        K: number of clusters.
        constrained: when true, each covariance is rescaled to unit
            determinant before being inverted.
    """

    def __init__(self, constrained: bool = False):
        """Create an unfitted model.

        @param constrained: normalise every cluster covariance to unit
            determinant before inverting it.
        """
        self.C = None
        self.L = None
        self.K = None
        self.constrained = constrained

    def predict(self, features):
        """Assign every sample to its closest centroid.

        @param features: array of shape (N, d).
        @return: integer array of shape (N,) with the index of the closest
            cluster for each sample.
        @raise RuntimeError: if the model has been neither fitted nor loaded.
        """
        if self.C is None or self.L is None or self.K is None:
            raise RuntimeError("The model must be fitted or loaded before calling predict")
        X = np.asarray(features, dtype=float)
        distances = self._distance_matrix(X, np.asarray(self.C), np.asarray(self.L))
        closest_cluster = np.argmin(distances, axis=1)
        return closest_cluster

    def load(self, model_path):
        """Restore a model previously written by ``save``.

        @param model_path: path of the pickle file to read.
        @return: None; ``C``, ``L``, ``K`` and ``constrained`` are overwritten.
        """
        data = None
        # SECURITY: pickle can execute arbitrary code while loading; only
        # open model files that come from a trusted source.
        with open(model_path, "rb") as f:
            data = pickle.load(f)
        if data is None:
            raise Exception("Le modèle n'a pas pu être chargé")
        self.C = data["C"]
        self.L = data["L"]
        self.K = data["K"]
        self.constrained = data["constrained"]

    def save(self, modelpath: str):
        """Write the model parameters to a pickle file.

        @param modelpath: path of the file to create or overwrite.
        @return: None
        """
        data = {
            "C": self.C,
            "L": self.L,
            "K": self.K,
            "constrained": self.constrained
        }
        with open(modelpath, "wb") as f:
            pickle.dump(data, f)

    def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False):
        """Train ``ninit`` times from random seeds and keep the lowest-loss run.

        @param features: array of shape (N, d).
        @param k: number of clusters.
        @param tol: convergence threshold on the summed relative centroid
            shift, expressed in percent.
        @param ninit: number of random restarts (at least 1).
        @param maxiter: maximum number of iterations per restart.
        @param debug: print progress and save one plot per iteration.
        @raise ValueError: if ``ninit`` is smaller than 1.
        """
        if ninit < 1:
            raise ValueError("ninit must be at least 1")
        results = [self._train(features, k, tol, maxiter, debug) for _ in range(ninit)]
        # min() keeps the first of several equally good runs.
        best = min(results, key=lambda run: run["loss"])
        if debug:
            print(f"best: {best['loss']} loss")
        self.C = best["C"]
        self.L = best["L"]
        self.K = best["K"]

    def _initialize_model(self, X, number_clusters):
        """Pick random samples as centroids and identity matrices as metrics.

        @param X: array of shape (N, d).
        @param number_clusters: number of clusters to seed.
        @return: tuple ``(C, L)`` of shapes (K, d) and (K, d, d).
        """
        n_samples, d = X.shape
        # Draw distinct samples so that two clusters never start on the same
        # point; fall back to drawing with replacement only when more
        # clusters than samples are requested.
        replace = number_clusters > n_samples
        picked = np.random.choice(n_samples, number_clusters, replace=replace)
        # Float copy: centroids are later overwritten in place with means,
        # which an integer array would truncate.
        C = np.array(X[picked], dtype=float)
        L = np.tile(np.identity(d), (number_clusters, 1, 1))
        return C, L

    def _dist(self, a, b, l):
        """Return the squared Mahalanobis distance ``(a - b)^T l (a - b)``.

        @param a: first point, d values.
        @param b: second point, d values.
        @param l: metric matrix of shape (d, d).
        @return: scalar distance.
        """
        delta = np.reshape(a, (-1, 1)) - np.reshape(b, (-1, 1))
        return np.transpose(delta).dot(l).dot(delta)[0][0]

    def _distance_matrix(self, X, C, L):
        """Squared Mahalanobis distance from every sample to every centroid.

        @param X: array of shape (N, d).
        @param C: centroids, shape (K, d).
        @param L: metric matrices, shape (K, d, d).
        @return: array of shape (N, K).
        """
        distances = np.empty((X.shape[0], C.shape[0]))
        for k in range(C.shape[0]):
            delta = X - C[k]
            # Row-wise quadratic form delta_n^T L_k delta_n.
            distances[:, k] = np.sum(delta.dot(L[k]) * delta, axis=1)
        return distances

    def _plot_iteration(self, iteration, points, clusters, centers):
        """Save a 2-D scatter plot of one iteration to ``test_<iteration>.pdf``.

        @param iteration: iteration index, used in the file name.
        @param points: array of shape (N, 2) with the samples.
        @param clusters: cluster index of every sample, used as colour.
        @param centers: array of shape (K, 2) with the centroids.
        """
        fig = plt.figure()
        try:
            ax = fig.add_subplot(111)
            scatter = ax.scatter(points[:, 0], points[:, 1], c=clusters, s=50)
            ax.scatter(centers[:, 0], centers[:, 1], s=50, c='red', marker='+')
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            plt.colorbar(scatter)
            plt.savefig("test_" + str(iteration) + ".pdf")
        finally:
            # One figure is created per iteration; close it so that they do
            # not pile up in memory.
            plt.close(fig)

    def _debug_plot(self, iteration, X, closest_cluster):
        """Plot the current state, projecting to 2-D with t-SNE when d > 2."""
        stacked = np.concatenate((X, self.C), axis=0)
        if X.shape[1] > 2:
            stacked = TSNE(n_components=2).fit_transform(stacked)
        n_samples = X.shape[0]
        self._plot_iteration(iteration, stacked[:n_samples], closest_cluster, stacked[n_samples:])

    def _train(self, features, K: int, tol: float, maxiter: int, debug: bool=False):
        """Run one k-means optimisation from a random initialisation.

        ``self.C``, ``self.L`` and ``self.K`` are used as working state and
        are left holding the result of this run.

        @param features: array of shape (N, d).
        @param K: number of clusters.
        @param tol: stop once the summed relative centroid shift (in
            percent) is not greater than this value.
        @param maxiter: maximum number of iterations; at least one
            iteration is always performed.
        @param debug: print progress and save one plot per iteration.
        @return: dict with keys ``"loss"``, ``"C"``, ``"K"`` and ``"L"``.
            The loss is the one measured at the start of the last
            iteration, before the final centroid update.
        """
        X = np.asarray(features, dtype=float)
        d = X.shape[1]

        self.C, self.L = self._initialize_model(X, K)
        self.K = K

        end_algo = False
        i = 0
        while not end_algo:
            if debug:
                print("Iteration: ", i)

            # Assignment step.
            distances = self._distance_matrix(X, self.C, self.L)
            closest_cluster = np.argmin(distances, axis=1)
            loss = np.sum(distances[np.arange(len(distances)), closest_cluster])
            if debug:
                print(f"loss {loss}")
                self._debug_plot(i, X, closest_cluster)

            # Update step.
            old_c = self.C.copy()
            for k in range(self.K):
                X_sub = X[closest_cluster == k]
                if X_sub.shape[0] == 0:
                    # Empty cluster: keep its previous centroid and metric.
                    continue

                C_new = np.mean(X_sub, axis=0)
                centered = X_sub - C_new
                covariance = centered.T.dot(centered) / X_sub.shape[0]
                self.C[k] = C_new

                if self.constrained:
                    # slogdet avoids the overflow/underflow of det() and
                    # reports singular matrices through its sign.
                    sign, logdet = np.linalg.slogdet(covariance)
                    if sign <= 0 or not np.isfinite(logdet):
                        # A singular covariance cannot be rescaled to unit
                        # determinant; keep the previous metric.
                        continue
                    covariance = covariance / np.exp(logdet / d)
                self.L[k] = np.linalg.pinv(covariance)

            # Relative centroid shift in percent. The denominator is floored
            # so that a zero coordinate yields neither NaN nor inf.
            denominator = np.maximum(np.absolute(old_c), np.finfo(float).eps)
            diff = np.sum(np.absolute(self.C - old_c) / denominator * 100)
            if diff > tol:
                end_algo = False
                if debug:
                    print(f"{diff}")
            else:
                if debug:
                    print(f"Tolerance threshold {tol} reached with diff {diff}")
                end_algo = True

            i = i + 1
            if i >= maxiter:
                end_algo = True
                if debug:
                    print(f"Iteration {maxiter} reached")
        return {
            "loss": loss,
            "C": self.C,
            "K": self.K,
            "L": self.L
        }