Commit 4309b4a340144c6dff7757892cd5539e89e20538
1 parent
88d1d67e9d
Exists in
master
Adding constrained Mahalanobis distance to help convergence
Showing 2 changed files with 23 additions and 12 deletions (side-by-side diff)
volia/clustering.py
... | ... | @@ -17,7 +17,8 @@ |
17 | 17 | |
18 | 18 | CLUSTERING_METHODS = { |
19 | 19 | "k-means": kmeans(), |
20 | - "k-means-mahalanobis": kmeansMahalanobis() | |
20 | + "k-means-mahalanobis": kmeansMahalanobis(), | |
21 | + "k-means-mahalanobis-constrained": kmeansMahalanobis(constrained=True) | |
21 | 22 | } |
22 | 23 | |
23 | 24 | EVALUATION_METHODS = { |
volia/clustering_modules/kmeans_mahalanobis.py
... | ... | @@ -8,13 +8,14 @@ |
8 | 8 | from abstract_clustering import AbstractClustering |
9 | 9 | |
10 | 10 | class kmeansMahalanobis(): |
11 | - def __init__(self): | |
11 | + def __init__(self, constrained: bool = False): | |
12 | 12 | """ |
13 | 13 | |
14 | 14 | """ |
15 | 15 | self.C = None |
16 | 16 | self.L = None |
17 | 17 | self.K = None |
18 | + self.constrained = constrained | |
18 | 19 | |
19 | 20 | def predict(self, features): |
20 | 21 | """ |
... | ... | @@ -45,6 +46,7 @@ |
45 | 46 | self.C = data["C"] |
46 | 47 | self.L = data["L"] |
47 | 48 | self.K = data["K"] |
49 | + self.constrained = data["constrained"] | |
48 | 50 | |
49 | 51 | def save(self, modelpath: str): |
50 | 52 | """ |
... | ... | @@ -55,7 +57,8 @@ |
55 | 57 | data = { |
56 | 58 | "C": self.C, |
57 | 59 | "L": self.L, |
58 | - "K": self.K | |
60 | + "K": self.K, | |
61 | + "constrained": self.constrained | |
59 | 62 | } |
60 | 63 | with open(modelpath, "wb") as f: |
61 | 64 | pickle.dump(data, f) |
62 | 65 | |
... | ... | @@ -82,11 +85,11 @@ |
82 | 85 | |
83 | 86 | def _dist(self, a, b, l): |
84 | 87 | ''' |
85 | - Distance euclidienne | |
88 | + Distance euclidienne with mahalanobis | |
86 | 89 | ''' |
87 | 90 | a = np.reshape(a, (-1, 1)) |
88 | 91 | b = np.reshape(b, (-1, 1)) |
89 | - result = np.transpose(a - b).dot(l).dot(a-b)[0][0] | |
92 | + result = np.transpose(a - b).dot(l).dot(a - b)[0][0] | |
90 | 93 | return result |
91 | 94 | |
92 | 95 | def _plot_iteration(self, iteration, points, clusters, centers): |
93 | 96 | |
94 | 97 | |
... | ... | @@ -129,17 +132,18 @@ |
129 | 132 | distances[n][k] = self._dist(X[n], self.C[k], self.L[k]) |
130 | 133 | |
131 | 134 | closest_cluster = np.argmin(distances, axis=1) |
135 | + | |
132 | 136 | loss = np.sum(distances[np.arange(len(distances)), closest_cluster]) |
133 | 137 | if debug: |
134 | 138 | print(f"loss {loss}") |
135 | 139 | |
136 | 140 | |
137 | 141 | # -- Debug tool ---------------------- |
138 | - if debug and i % 10 == 0: | |
142 | + if debug and i % 1 == 0: | |
139 | 143 | # TSNE if needed |
140 | 144 | X_embedded = np.concatenate((X, self.C), axis=0) |
141 | 145 | if d > 2: |
142 | - X_embedded = TSNE(n_components=2).fit_transform(np.concatenate((X, C), axis=0)) | |
146 | + X_embedded = TSNE(n_components=2).fit_transform(np.concatenate((X, self.C), axis=0)) | |
143 | 147 | |
144 | 148 | # Then plot |
145 | 149 | self._plot_iteration( |
146 | 150 | |
147 | 151 | |
148 | 152 | |
... | ... | @@ -151,22 +155,28 @@ |
151 | 155 | # ------------------------------------ |
152 | 156 | |
153 | 157 | old_c = self.C.copy() |
154 | - for k in range(K): | |
158 | + for k in range(self.K): | |
155 | 159 | # Find subset of X with values closed to the centroid c_k. |
156 | 160 | X_sub = np.where(closest_cluster == k) |
157 | 161 | X_sub = np.take(X, X_sub[0], axis=0) |
158 | 162 | if X_sub.shape[0] == 0: |
159 | 163 | continue |
160 | - np.mean(X_sub, axis=0) | |
164 | + | |
161 | 165 | C_new = np.mean(X_sub, axis=0) |
162 | 166 | |
163 | 167 | # -- COMPUTE NEW LAMBDA (here named K) -- |
164 | - K_new = np.zeros((L.shape[1], L.shape[2])) | |
168 | + K_new = np.zeros((self.L.shape[1], self.L.shape[2])) | |
169 | + tmp = np.zeros((self.L.shape[1], self.L.shape[2])) | |
165 | 170 | for x in X_sub: |
166 | 171 | x = np.reshape(x, (-1, 1)) |
167 | 172 | c_tmp = np.reshape(C_new, (-1, 1)) |
168 | - K_new = K_new + (x - c_tmp).dot((x - c_tmp).transpose()) | |
169 | - K_new = K_new / X_sub.shape[0] | |
173 | + #K_new = K_new + (x - c_tmp).dot((x - c_tmp).transpose()) | |
174 | + | |
175 | + tmp = tmp + (x - c_tmp).dot((x - c_tmp).transpose()) | |
176 | + if self.constrained: | |
177 | + K_new = (tmp / X_sub.shape[0]) / np.power(np.linalg.det((tmp / X_sub.shape[0])), 1/d) | |
178 | + else: | |
179 | + K_new = tmp / X_sub.shape[0] | |
170 | 180 | K_new = np.linalg.pinv(K_new) |
171 | 181 | |
172 | 182 | #if end_algo and (not (self.C[k] == C_new).all()): # If the same stop |