volia/clustering_modules/kmeans_mahalanobis.py
from sklearn.cluster import KMeans
import pickle
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

from abstract_clustering import AbstractClustering


class kmeansMahalanobis():
    def __init__(self, constrained: bool = False):
        """
        K-means variant where each cluster learns its own Mahalanobis metric
        (a precision matrix) instead of using a single Euclidean distance.

        @param constrained: if True, each cluster covariance is rescaled to
        unit determinant before inversion, so all clusters share the same
        volume.
        """
        self.C = None  # centroids, shape (K, d)
        self.L = None  # per-cluster precision matrices, shape (K, d, d)
        self.K = None  # number of clusters
        self.constrained = constrained
    def predict(self, features):
        """
        Assign each feature vector to its closest cluster.

        @param features: array of shape (N, d)
        @return: array of shape (N,) with the index of the closest cluster
        """
        N = features.shape[0]
        distances = np.zeros((N, self.K))
        for n in range(N):
            for k in range(self.K):
                distances[n][k] = self._dist(features[n], self.C[k], self.L[k])
        closest_cluster = np.argmin(distances, axis=1)
        return closest_cluster

    def load(self, model_path):
        """
        Load a model previously written by save().

        @param model_path: path of the pickled model file
        @return: None
        """
        data = None
        with open(model_path, "rb") as f:
            data = pickle.load(f)
        if data is None:
            raise Exception("The model could not be loaded")
        else:
            self.C = data["C"]
            self.L = data["L"]
            self.K = data["K"]
            self.constrained = data["constrained"]
    def save(self, modelpath: str):
        """
        Pickle the model parameters to disk.

        @param modelpath: destination path of the pickled model file
        @return: None
        """
        data = {
            "C": self.C,
            "L": self.L,
            "K": self.K,
            "constrained": self.constrained
        }
        with open(modelpath, "wb") as f:
            pickle.dump(data, f)

    def fit(self, features, k: int, tol: float, ninit: int, maxiter: int = 300, debug: bool = False):
        # Run the training ninit times and keep the run with the lowest loss.
        results = []
        for i in range(ninit):
            results.append(self._train(features, k, tol, maxiter, debug))
        losses = [v["loss"] for v in results]
        best = results[losses.index(min(losses))]
        if debug:
            print(f"best: {best['loss']} loss")
        self.C = best["C"]
        self.L = best["L"]
        self.K = best["K"]
    def _initialize_model(self, X, number_clusters):
        # Pick distinct random samples as the initial centroids and start
        # every cluster from the identity metric (plain Euclidean distance).
        d = X.shape[1]
        C = X[np.random.choice(X.shape[0], number_clusters, replace=False)]
        L = np.zeros((number_clusters, d, d))
        for k in range(number_clusters):
            L[k] = np.identity(d)
        return C, L

    def _dist(self, a, b, l):
        '''
        Squared Mahalanobis distance between a and b under the metric l:
        (a - b)^T l (a - b). With l = I this reduces to the squared
        Euclidean distance.
        '''
        a = np.reshape(a, (-1, 1))
        b = np.reshape(b, (-1, 1))
        result = np.transpose(a - b).dot(l).dot(a - b)[0][0]
        return result

    def _plot_iteration(self, iteration, points, clusters, centers):
        # Scatter plot of the points colored by cluster assignment, with the
        # centroids drawn as red crosses.
        fig = plt.figure()
        ax = fig.add_subplot(111)
        scatter = ax.scatter(points[:, 0], points[:, 1], c=clusters, s=50)
        ax.scatter(centers[:, 0], centers[:, 1], s=50, c='red', marker='+')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        plt.colorbar(scatter)
        plt.savefig("test_" + str(iteration) + ".pdf")
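    # The training loop below alternates, k-means style, between assigning
    # points to their nearest centroid under the current metric and
    # re-estimating the per-cluster parameters:
    #   c_k      <- mean of the points assigned to cluster k
    #   Sigma_k  <- (1 / N_k) * sum_n (x_n - c_k)(x_n - c_k)^T
    #   Lambda_k <- pinv(Sigma_k), with Sigma_k first rescaled to unit
    #               determinant when self.constrained is True.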
    def _train(self, features, K: int, tol: float, maxiter: int, debug: bool = False):
        X = features
        N = X.shape[0]
        d = X.shape[1]
        C, L = self._initialize_model(X, K)
        self.C = C
        self.L = L
        self.K = K
        end_algo = False
        i = 0
        while not end_algo:
            if debug:
                print("Iteration: ", i)

            # Compute the (N, K) matrix of point-to-centroid distances.
            distances = np.zeros((N, self.K))
            for n in range(N):
                for k in range(self.K):
                    distances[n][k] = self._dist(X[n], self.C[k], self.L[k])

            closest_cluster = np.argmin(distances, axis=1)

            # Loss: sum of each point's distance to its closest centroid.
            loss = np.sum(distances[np.arange(len(distances)), closest_cluster])
            if debug:
                print(f"loss {loss}")
            # -- Debug tool ----------------------
            if debug and i % 1 == 0:  # plots every iteration; raise the modulus to plot less often
                # Project points and centroids to 2D with TSNE if needed.
                X_embedded = np.concatenate((X, self.C), axis=0)
                if d > 2:
                    X_embedded = TSNE(n_components=2).fit_transform(np.concatenate((X, self.C), axis=0))

                # Then plot
                self._plot_iteration(
                    i,
                    X_embedded[:X.shape[0]],
                    closest_cluster,
                    X_embedded[X.shape[0]:]
                )
            # ------------------------------------

            old_c = self.C.copy()
            for k in range(self.K):
                # Find the subset of X assigned to the centroid c_k.
                X_sub = np.where(closest_cluster == k)
                X_sub = np.take(X, X_sub[0], axis=0)
                if X_sub.shape[0] == 0:
                    continue

                C_new = np.mean(X_sub, axis=0)

                # -- COMPUTE NEW LAMBDA (here named K_new) --
                tmp = np.zeros((self.L.shape[1], self.L.shape[2]))
                for x in X_sub:
                    x = np.reshape(x, (-1, 1))
                    c_tmp = np.reshape(C_new, (-1, 1))
                    # Accumulate the scatter matrix of the cluster.
                    tmp = tmp + (x - c_tmp).dot((x - c_tmp).transpose())
                if self.constrained:
                    # Rescale the covariance to unit determinant so that all
                    # clusters keep the same volume.
                    K_new = (tmp / X_sub.shape[0]) / np.power(np.linalg.det(tmp / X_sub.shape[0]), 1 / d)
                else:
                    K_new = tmp / X_sub.shape[0]
                # The metric is the pseudo-inverse of the covariance.
                K_new = np.linalg.pinv(K_new)

                self.C[k] = C_new
                self.L[k] = K_new

            # Relative centroid movement (in percent) summed over clusters.
            diff = np.sum(np.absolute((self.C - old_c) / old_c * 100))
            if diff > tol:
                end_algo = False
                if debug:
                    print(f"{diff}")
            else:
                if debug:
                    print(f"Tolerance threshold {tol} reached with diff {diff}")
                end_algo = True

            i = i + 1
            if i > maxiter:
                end_algo = True
                if debug:
                    print(f"Maximum iteration {maxiter} reached")
        return {
            "loss": loss,
            "C": self.C,
            "K": self.K,
            "L": self.L
        }
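A minimal usage sketch, not part of the original file: the synthetic 2-D blobs, parameter values, and output path below are illustrative assumptions, chosen only to exercise fit(), predict(), and save().

if __name__ == "__main__":
    # Hypothetical demo: fit on two synthetic 2-D Gaussian blobs, predict
    # the cluster of each point, then persist the model. All values here
    # are illustrative.
    rng = np.random.default_rng(0)
    blob_a = rng.normal(loc=(0.0, 0.0), scale=0.5, size=(100, 2))
    blob_b = rng.normal(loc=(5.0, 5.0), scale=0.5, size=(100, 2))
    X = np.concatenate((blob_a, blob_b), axis=0)

    model = kmeansMahalanobis(constrained=True)
    model.fit(X, k=2, tol=0.01, ninit=5, maxiter=100)
    labels = model.predict(X)
    print(labels)
    model.save("kmeans_mahalanobis.pkl")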