Blame view

volia/clustering_modules/kmeans_mahalanobis.py 6.22 KB
4152e83df   quillotm   Addind kmeans mah...
1
2
3
4
5
6
7
8
9
10
  
  
  from sklearn.cluster import KMeans
  import pickle
  import numpy as np
  import matplotlib.pyplot as plt
  from sklearn.manifold import TSNE
  from abstract_clustering import AbstractClustering
  
  class kmeansMahalanobis():
4309b4a34   quillotm   Adding constraine...
11
      def __init__(self, constrained: bool = False):
4152e83df   quillotm   Addind kmeans mah...
12
13
14
15
16
17
          """
  
          """
          self.C = None
          self.L = None
          self.K = None
4309b4a34   quillotm   Adding constraine...
18
          self.constrained = constrained
4152e83df   quillotm   Addind kmeans mah...
19
20
21
22
23
24
25
26
27
28
29
30
  
      def predict(self, features):
          """
  
          @param features:
          @return:
          """
          N = features.shape[0]
          distances = np.zeros((N, self.K))
          for n in range(N):
              for k in range(self.K):
                  distances[n][k] = self._dist(features[n], self.C[k], self.L[k])
4152e83df   quillotm   Addind kmeans mah...
31
32
33
34
35
36
37
38
39
40
          closest_cluster = np.argmin(distances, axis=1)
          return closest_cluster
  
      def load(self, model_path):
          """
  
          @param model_path:
          @return:
          """
          data = None
ed89325d5   quillotm   Now, we can give ...
41
42
          with open(model_path, "rb") as f:
              data = pickle.load(f)
4152e83df   quillotm   Addind kmeans mah...
43
44
45
46
47
48
          if data is None:
              raise Exception("Le modèle n'a pas pu être chargé")
          else:
              self.C = data["C"]
              self.L = data["L"]
              self.K = data["K"]
4309b4a34   quillotm   Adding constraine...
49
              self.constrained = data["constrained"]
4152e83df   quillotm   Addind kmeans mah...
50
51
52
53
54
55
56
57
58
59
  
      def save(self, modelpath: str):
          """
  
          @param modelpath:
          @return:
          """
          data = {
              "C": self.C,
              "L": self.L,
4309b4a34   quillotm   Adding constraine...
60
61
              "K": self.K,
              "constrained": self.constrained
4152e83df   quillotm   Addind kmeans mah...
62
63
64
          }
          with open(modelpath, "wb") as f:
              pickle.dump(data, f)
660d9960f   quillotm   Adding n init par...
65
66
67
68
69
70
71
72
73
74
75
      def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False):
          results = []
          for i in range(ninit):
              results.append(self._train(features, k, tol, maxiter, debug))
          losses = [v["loss"] for v in results]
          best = results[losses.index(min(losses))]
          if debug:
              print(f"best: {best['loss']} loss")
          self.C = best["C"]
          self.L = best["L"]
          self.K = best["K"]
4152e83df   quillotm   Addind kmeans mah...
76
77
78
79
80
81
82
83
84
85
86
  
      def _initialize_model(self, X, number_clusters):
          d = X.shape[1]
          C = X[np.random.choice(X.shape[0], number_clusters)]
          L = np.zeros((number_clusters, d, d))
          for k in range(number_clusters):
              L[k] = np.identity(d)
          return C, L
  
      def _dist(self, a, b, l):
          '''
4309b4a34   quillotm   Adding constraine...
87
          Distance euclidienne with mahalanobis
4152e83df   quillotm   Addind kmeans mah...
88
89
90
          '''
          a = np.reshape(a, (-1, 1))
          b = np.reshape(b, (-1, 1))
4309b4a34   quillotm   Adding constraine...
91
          result = np.transpose(a - b).dot(l).dot(a - b)[0][0]
4152e83df   quillotm   Addind kmeans mah...
92
93
94
95
96
97
          return result
  
      def _plot_iteration(self, iteration, points, clusters, centers):
          fig = plt.figure()
          ax = fig.add_subplot(111)
          scatter = ax.scatter(points[:, 0], points[:, 1], c=clusters, s=50)
4152e83df   quillotm   Addind kmeans mah...
98
          ax.scatter(centers[:, 0], centers[:, 1], s=50, c='red', marker='+')
4152e83df   quillotm   Addind kmeans mah...
99
100
101
          ax.set_xlabel('x')
          ax.set_ylabel('y')
          plt.colorbar(scatter)
4152e83df   quillotm   Addind kmeans mah...
102
          plt.savefig("test_" + str(iteration) + ".pdf")
ed89325d5   quillotm   Now, we can give ...
103
      def _train(self, features, K: int, tol: float, maxiter: int, debug: bool=False):
4152e83df   quillotm   Addind kmeans mah...
104
105
106
107
108
109
110
111
112
113
114
115
          X = features
          N = X.shape[0]
          d = X.shape[1]
  
          C, L = self._initialize_model(X, K)
          self.C = C
          self.L = L
          self.K = K
  
          end_algo = False
          i = 0
          while not end_algo:
ed89325d5   quillotm   Now, we can give ...
116
117
              if debug:
                  print("Iteration: ", i)
4152e83df   quillotm   Addind kmeans mah...
118
              # Calcul matrix distance
660d9960f   quillotm   Adding n init par...
119
              distances = np.zeros((N, self.K))
4152e83df   quillotm   Addind kmeans mah...
120
121
122
123
  
              for n in range(N):
                  for k in range(self.K):
                      distances[n][k] = self._dist(X[n], self.C[k], self.L[k])
660d9960f   quillotm   Adding n init par...
124

4152e83df   quillotm   Addind kmeans mah...
125
              closest_cluster = np.argmin(distances, axis=1)
4309b4a34   quillotm   Adding constraine...
126

660d9960f   quillotm   Adding n init par...
127
128
129
              loss = np.sum(distances[np.arange(len(distances)), closest_cluster])
              if debug:
                  print(f"loss {loss}")
ed89325d5   quillotm   Now, we can give ...
130
131
  
              # -- Debug tool ----------------------
4309b4a34   quillotm   Adding constraine...
132
              if debug and i % 1 == 0:
ed89325d5   quillotm   Now, we can give ...
133
134
135
                  # TSNE if needed
                  X_embedded = np.concatenate((X, self.C), axis=0)
                  if d > 2:
4309b4a34   quillotm   Adding constraine...
136
                      X_embedded = TSNE(n_components=2).fit_transform(np.concatenate((X, self.C), axis=0))
ed89325d5   quillotm   Now, we can give ...
137

4152e83df   quillotm   Addind kmeans mah...
138
139
140
141
142
143
144
                  # Then plot
                  self._plot_iteration(
                      i,
                      X_embedded[:X.shape[0]],
                      closest_cluster,
                      X_embedded[X.shape[0]:]
                  )
ed89325d5   quillotm   Now, we can give ...
145
              # ------------------------------------
4152e83df   quillotm   Addind kmeans mah...
146

ed89325d5   quillotm   Now, we can give ...
147
              old_c = self.C.copy()
4309b4a34   quillotm   Adding constraine...
148
              for k in range(self.K):
4152e83df   quillotm   Addind kmeans mah...
149
150
151
                  # Find subset of X with values closed to the centroid c_k.
                  X_sub = np.where(closest_cluster == k)
                  X_sub = np.take(X, X_sub[0], axis=0)
d4507c268   quillotm   We do not update ...
152
153
                  if X_sub.shape[0] == 0:
                      continue
4309b4a34   quillotm   Adding constraine...
154

4152e83df   quillotm   Addind kmeans mah...
155
156
157
                  C_new = np.mean(X_sub, axis=0)
  
                  # -- COMPUTE NEW LAMBDA (here named K) --
4309b4a34   quillotm   Adding constraine...
158
159
                  K_new = np.zeros((self.L.shape[1], self.L.shape[2]))
                  tmp = np.zeros((self.L.shape[1], self.L.shape[2]))
4152e83df   quillotm   Addind kmeans mah...
160
161
162
                  for x in X_sub:
                      x = np.reshape(x, (-1, 1))
                      c_tmp = np.reshape(C_new, (-1, 1))
4309b4a34   quillotm   Adding constraine...
163
164
165
166
167
168
169
                      #K_new = K_new + (x - c_tmp).dot((x - c_tmp).transpose())
  
                      tmp = tmp + (x - c_tmp).dot((x - c_tmp).transpose())
                  if self.constrained:
                      K_new = (tmp / X_sub.shape[0]) / np.power(np.linalg.det((tmp / X_sub.shape[0])), 1/d)
                  else:
                      K_new = tmp / X_sub.shape[0]
d4507c268   quillotm   We do not update ...
170
                  K_new = np.linalg.pinv(K_new)
4152e83df   quillotm   Addind kmeans mah...
171

ed89325d5   quillotm   Now, we can give ...
172
173
                  #if end_algo and (not (self.C[k] == C_new).all()):  # If the same stop
                  #    end_algo = False
4152e83df   quillotm   Addind kmeans mah...
174
175
                  self.C[k] = C_new
                  self.L[k] = K_new
ed89325d5   quillotm   Now, we can give ...
176

660d9960f   quillotm   Adding n init par...
177

ed89325d5   quillotm   Now, we can give ...
178
179
180
181
182
              diff = np.sum(np.absolute((self.C - old_c) / old_c * 100))
              if diff > tol:
                  end_algo = False
                  if debug:
                      print(f"{diff}")
660d9960f   quillotm   Adding n init par...
183
184
185
              else:
                  if debug:
                      print(f"Tolerance threshold {tol} reached with diff {diff}")
ed89325d5   quillotm   Now, we can give ...
186
                  end_algo = True
660d9960f   quillotm   Adding n init par...
187

4152e83df   quillotm   Addind kmeans mah...
188
              i = i + 1
ed89325d5   quillotm   Now, we can give ...
189
190
191
192
              if i > maxiter:
                  end_algo = True
                  if debug:
                      print(f"Iteration {maxiter} reached")
660d9960f   quillotm   Adding n init par...
193
194
195
196
197
198
          return {
              "loss": loss,
              "C": self.C,
              "K": self.K,
              "L": self.L
          }