Commit 55bcf758f37969188cf737d1a6ff23fe803ebddd
1 parent
8df8f5b238
Exists in
master
By default, kmeans mahalanobis is with constrains
Showing 1 changed file with 0 additions and 6 deletions Inline Diff
volia/clustering_modules/kmeans_mahalanobis.py
1 | 1 | ||
2 | 2 | ||
3 | from sklearn.cluster import KMeans | 3 | from sklearn.cluster import KMeans |
4 | import pickle | 4 | import pickle |
5 | import numpy as np | 5 | import numpy as np |
6 | import matplotlib.pyplot as plt | 6 | import matplotlib.pyplot as plt |
7 | from sklearn.manifold import TSNE | 7 | from sklearn.manifold import TSNE |
8 | from abstract_clustering import AbstractClustering | 8 | from abstract_clustering import AbstractClustering |
9 | 9 | ||
10 | class kmeansMahalanobis(): | 10 | class kmeansMahalanobis(): |
11 | def __init__(self, constrained: bool = False): | 11 | def __init__(self, constrained: bool = False): |
12 | """ | 12 | """ |
13 | 13 | ||
14 | """ | 14 | """ |
15 | self.C = None | 15 | self.C = None |
16 | self.L = None | 16 | self.L = None |
17 | self.K = None | 17 | self.K = None |
18 | self.constrained = constrained | 18 | self.constrained = constrained |
19 | 19 | ||
20 | def predict(self, features): | 20 | def predict(self, features): |
21 | """ | 21 | """ |
22 | 22 | ||
23 | @param features: | 23 | @param features: |
24 | @return: | 24 | @return: |
25 | """ | 25 | """ |
26 | N = features.shape[0] | 26 | N = features.shape[0] |
27 | distances = np.zeros((N, self.K)) | 27 | distances = np.zeros((N, self.K)) |
28 | for n in range(N): | 28 | for n in range(N): |
29 | for k in range(self.K): | 29 | for k in range(self.K): |
30 | distances[n][k] = self._dist(features[n], self.C[k], self.L[k]) | 30 | distances[n][k] = self._dist(features[n], self.C[k], self.L[k]) |
31 | closest_cluster = np.argmin(distances, axis=1) | 31 | closest_cluster = np.argmin(distances, axis=1) |
32 | return closest_cluster | 32 | return closest_cluster |
33 | 33 | ||
34 | def load(self, model_path): | 34 | def load(self, model_path): |
35 | """ | 35 | """ |
36 | 36 | ||
37 | @param model_path: | 37 | @param model_path: |
38 | @return: | 38 | @return: |
39 | """ | 39 | """ |
40 | data = None | 40 | data = None |
41 | with open(model_path, "rb") as f: | 41 | with open(model_path, "rb") as f: |
42 | data = pickle.load(f) | 42 | data = pickle.load(f) |
43 | if data is None: | 43 | if data is None: |
44 | raise Exception("Le modèle n'a pas pu être chargé") | 44 | raise Exception("Le modèle n'a pas pu être chargé") |
45 | else: | 45 | else: |
46 | self.C = data["C"] | 46 | self.C = data["C"] |
47 | self.L = data["L"] | 47 | self.L = data["L"] |
48 | self.K = data["K"] | 48 | self.K = data["K"] |
49 | self.constrained = data["constrained"] | 49 | self.constrained = data["constrained"] |
50 | 50 | ||
51 | def save(self, modelpath: str): | 51 | def save(self, modelpath: str): |
52 | """ | 52 | """ |
53 | 53 | ||
54 | @param modelpath: | 54 | @param modelpath: |
55 | @return: | 55 | @return: |
56 | """ | 56 | """ |
57 | data = { | 57 | data = { |
58 | "C": self.C, | 58 | "C": self.C, |
59 | "L": self.L, | 59 | "L": self.L, |
60 | "K": self.K, | 60 | "K": self.K, |
61 | "constrained": self.constrained | 61 | "constrained": self.constrained |
62 | } | 62 | } |
63 | with open(modelpath, "wb") as f: | 63 | with open(modelpath, "wb") as f: |
64 | pickle.dump(data, f) | 64 | pickle.dump(data, f) |
65 | 65 | ||
66 | def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False): | 66 | def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False): |
67 | results = [] | 67 | results = [] |
68 | for i in range(ninit): | 68 | for i in range(ninit): |
69 | results.append(self._train(features, k, tol, maxiter, debug)) | 69 | results.append(self._train(features, k, tol, maxiter, debug)) |
70 | losses = [v["loss"] for v in results] | 70 | losses = [v["loss"] for v in results] |
71 | best = results[losses.index(min(losses))] | 71 | best = results[losses.index(min(losses))] |
72 | if debug: | 72 | if debug: |
73 | print(f"best: {best['loss']} loss") | 73 | print(f"best: {best['loss']} loss") |
74 | self.C = best["C"] | 74 | self.C = best["C"] |
75 | self.L = best["L"] | 75 | self.L = best["L"] |
76 | self.K = best["K"] | 76 | self.K = best["K"] |
77 | 77 | ||
78 | def _initialize_model(self, X, number_clusters): | 78 | def _initialize_model(self, X, number_clusters): |
79 | d = X.shape[1] | 79 | d = X.shape[1] |
80 | C = X[np.random.choice(X.shape[0], number_clusters)] | 80 | C = X[np.random.choice(X.shape[0], number_clusters)] |
81 | L = np.zeros((number_clusters, d, d)) | 81 | L = np.zeros((number_clusters, d, d)) |
82 | for k in range(number_clusters): | 82 | for k in range(number_clusters): |
83 | L[k] = np.identity(d) | 83 | L[k] = np.identity(d) |
84 | return C, L | 84 | return C, L |
85 | 85 | ||
86 | def _dist(self, a, b, l): | 86 | def _dist(self, a, b, l): |
87 | ''' | 87 | ''' |
88 | Distance euclidienne with mahalanobis | 88 | Distance euclidienne with mahalanobis |
89 | ''' | 89 | ''' |
90 | a = np.reshape(a, (-1, 1)) | 90 | a = np.reshape(a, (-1, 1)) |
91 | b = np.reshape(b, (-1, 1)) | 91 | b = np.reshape(b, (-1, 1)) |
92 | result = np.transpose(a - b).dot(l).dot(a - b)[0][0] | 92 | result = np.transpose(a - b).dot(l).dot(a - b)[0][0] |
93 | return result | 93 | return result |
94 | 94 | ||
95 | def _plot_iteration(self, iteration, points, clusters, centers): | 95 | def _plot_iteration(self, iteration, points, clusters, centers): |
96 | fig = plt.figure() | 96 | fig = plt.figure() |
97 | ax = fig.add_subplot(111) | 97 | ax = fig.add_subplot(111) |
98 | scatter = ax.scatter(points[:, 0], points[:, 1], c=clusters, s=50) | 98 | scatter = ax.scatter(points[:, 0], points[:, 1], c=clusters, s=50) |
99 | |||
100 | #for center in centers: | ||
101 | # ax.scatter(center[0], center[1], s=50, c='red', marker='+') | ||
102 | ax.scatter(centers[:, 0], centers[:, 1], s=50, c='red', marker='+') | 99 | ax.scatter(centers[:, 0], centers[:, 1], s=50, c='red', marker='+') |
103 | |||
104 | ax.set_xlabel('x') | 100 | ax.set_xlabel('x') |
105 | ax.set_ylabel('y') | 101 | ax.set_ylabel('y') |
106 | plt.colorbar(scatter) | 102 | plt.colorbar(scatter) |
107 | #plt.ylim(0, 1) | ||
108 | #plt.xlim(0, 1) | ||
109 | plt.savefig("test_" + str(iteration) + ".pdf") | 103 | plt.savefig("test_" + str(iteration) + ".pdf") |
110 | 104 | ||
111 | def _train(self, features, K: int, tol: float, maxiter: int, debug: bool=False): | 105 | def _train(self, features, K: int, tol: float, maxiter: int, debug: bool=False): |
112 | X = features | 106 | X = features |
113 | N = X.shape[0] | 107 | N = X.shape[0] |
114 | d = X.shape[1] | 108 | d = X.shape[1] |
115 | 109 | ||
116 | C, L = self._initialize_model(X, K) | 110 | C, L = self._initialize_model(X, K) |
117 | self.C = C | 111 | self.C = C |
118 | self.L = L | 112 | self.L = L |
119 | self.K = K | 113 | self.K = K |
120 | 114 | ||
121 | end_algo = False | 115 | end_algo = False |
122 | i = 0 | 116 | i = 0 |
123 | while not end_algo: | 117 | while not end_algo: |
124 | if debug: | 118 | if debug: |
125 | print("Iteration: ", i) | 119 | print("Iteration: ", i) |
126 | 120 | ||
127 | # Calcul matrix distance | 121 | # Calcul matrix distance |
128 | distances = np.zeros((N, self.K)) | 122 | distances = np.zeros((N, self.K)) |
129 | 123 | ||
130 | for n in range(N): | 124 | for n in range(N): |
131 | for k in range(self.K): | 125 | for k in range(self.K): |
132 | distances[n][k] = self._dist(X[n], self.C[k], self.L[k]) | 126 | distances[n][k] = self._dist(X[n], self.C[k], self.L[k]) |
133 | 127 | ||
134 | closest_cluster = np.argmin(distances, axis=1) | 128 | closest_cluster = np.argmin(distances, axis=1) |
135 | 129 | ||
136 | loss = np.sum(distances[np.arange(len(distances)), closest_cluster]) | 130 | loss = np.sum(distances[np.arange(len(distances)), closest_cluster]) |
137 | if debug: | 131 | if debug: |
138 | print(f"loss {loss}") | 132 | print(f"loss {loss}") |
139 | 133 | ||
140 | 134 | ||
141 | # -- Debug tool ---------------------- | 135 | # -- Debug tool ---------------------- |
142 | if debug and i % 1 == 0: | 136 | if debug and i % 1 == 0: |
143 | # TSNE if needed | 137 | # TSNE if needed |
144 | X_embedded = np.concatenate((X, self.C), axis=0) | 138 | X_embedded = np.concatenate((X, self.C), axis=0) |
145 | if d > 2: | 139 | if d > 2: |
146 | X_embedded = TSNE(n_components=2).fit_transform(np.concatenate((X, self.C), axis=0)) | 140 | X_embedded = TSNE(n_components=2).fit_transform(np.concatenate((X, self.C), axis=0)) |
147 | 141 | ||
148 | # Then plot | 142 | # Then plot |
149 | self._plot_iteration( | 143 | self._plot_iteration( |
150 | i, | 144 | i, |
151 | X_embedded[:X.shape[0]], | 145 | X_embedded[:X.shape[0]], |
152 | closest_cluster, | 146 | closest_cluster, |
153 | X_embedded[X.shape[0]:] | 147 | X_embedded[X.shape[0]:] |
154 | ) | 148 | ) |
155 | # ------------------------------------ | 149 | # ------------------------------------ |
156 | 150 | ||
157 | old_c = self.C.copy() | 151 | old_c = self.C.copy() |
158 | for k in range(self.K): | 152 | for k in range(self.K): |
159 | # Find subset of X with values closed to the centroid c_k. | 153 | # Find subset of X with values closed to the centroid c_k. |
160 | X_sub = np.where(closest_cluster == k) | 154 | X_sub = np.where(closest_cluster == k) |
161 | X_sub = np.take(X, X_sub[0], axis=0) | 155 | X_sub = np.take(X, X_sub[0], axis=0) |
162 | if X_sub.shape[0] == 0: | 156 | if X_sub.shape[0] == 0: |
163 | continue | 157 | continue |
164 | 158 | ||
165 | C_new = np.mean(X_sub, axis=0) | 159 | C_new = np.mean(X_sub, axis=0) |
166 | 160 | ||
167 | # -- COMPUTE NEW LAMBDA (here named K) -- | 161 | # -- COMPUTE NEW LAMBDA (here named K) -- |
168 | K_new = np.zeros((self.L.shape[1], self.L.shape[2])) | 162 | K_new = np.zeros((self.L.shape[1], self.L.shape[2])) |
169 | tmp = np.zeros((self.L.shape[1], self.L.shape[2])) | 163 | tmp = np.zeros((self.L.shape[1], self.L.shape[2])) |
170 | for x in X_sub: | 164 | for x in X_sub: |
171 | x = np.reshape(x, (-1, 1)) | 165 | x = np.reshape(x, (-1, 1)) |
172 | c_tmp = np.reshape(C_new, (-1, 1)) | 166 | c_tmp = np.reshape(C_new, (-1, 1)) |
173 | #K_new = K_new + (x - c_tmp).dot((x - c_tmp).transpose()) | 167 | #K_new = K_new + (x - c_tmp).dot((x - c_tmp).transpose()) |
174 | 168 | ||
175 | tmp = tmp + (x - c_tmp).dot((x - c_tmp).transpose()) | 169 | tmp = tmp + (x - c_tmp).dot((x - c_tmp).transpose()) |
176 | if self.constrained: | 170 | if self.constrained: |
177 | K_new = (tmp / X_sub.shape[0]) / np.power(np.linalg.det((tmp / X_sub.shape[0])), 1/d) | 171 | K_new = (tmp / X_sub.shape[0]) / np.power(np.linalg.det((tmp / X_sub.shape[0])), 1/d) |
178 | else: | 172 | else: |
179 | K_new = tmp / X_sub.shape[0] | 173 | K_new = tmp / X_sub.shape[0] |
180 | K_new = np.linalg.pinv(K_new) | 174 | K_new = np.linalg.pinv(K_new) |
181 | 175 | ||
182 | #if end_algo and (not (self.C[k] == C_new).all()): # If the same stop | 176 | #if end_algo and (not (self.C[k] == C_new).all()): # If the same stop |
183 | # end_algo = False | 177 | # end_algo = False |
184 | self.C[k] = C_new | 178 | self.C[k] = C_new |
185 | self.L[k] = K_new | 179 | self.L[k] = K_new |
186 | 180 | ||
187 | 181 | ||
188 | diff = np.sum(np.absolute((self.C - old_c) / old_c * 100)) | 182 | diff = np.sum(np.absolute((self.C - old_c) / old_c * 100)) |
189 | if diff > tol: | 183 | if diff > tol: |
190 | end_algo = False | 184 | end_algo = False |
191 | if debug: | 185 | if debug: |
192 | print(f"{diff}") | 186 | print(f"{diff}") |
193 | else: | 187 | else: |
194 | if debug: | 188 | if debug: |
195 | print(f"Tolerance threshold {tol} reached with diff {diff}") | 189 | print(f"Tolerance threshold {tol} reached with diff {diff}") |
196 | end_algo = True | 190 | end_algo = True |
197 | 191 | ||
198 | i = i + 1 | 192 | i = i + 1 |
199 | if i > maxiter: | 193 | if i > maxiter: |
200 | end_algo = True | 194 | end_algo = True |
201 | if debug: | 195 | if debug: |
202 | print(f"Iteration {maxiter} reached") | 196 | print(f"Iteration {maxiter} reached") |
203 | return { | 197 | return { |
204 | "loss": loss, | 198 | "loss": loss, |
205 | "C": self.C, | 199 | "C": self.C, |
206 | "K": self.K, | 200 | "K": self.K, |
207 | "L": self.L | 201 | "L": self.L |
208 | } | 202 | } |