Commit d4507c2683cd0fcec7ee302ad1a4e430b61cd6cd

Authored by quillotm
1 parent 4152e83df2
Exists in master

Do not update a cluster when no data point is assigned to it.

Showing 1 changed file with 3 additions and 3 deletions Inline Diff

volia/clustering_modules/kmeans_mahalanobis.py
1 1
2 2
3 from sklearn.cluster import KMeans 3 from sklearn.cluster import KMeans
4 import pickle 4 import pickle
5 import numpy as np 5 import numpy as np
6 import matplotlib.pyplot as plt 6 import matplotlib.pyplot as plt
7 from sklearn.manifold import TSNE 7 from sklearn.manifold import TSNE
8 from abstract_clustering import AbstractClustering 8 from abstract_clustering import AbstractClustering
9 9
class kmeansMahalanobis():
    """K-means clustering that uses one Mahalanobis-type metric per cluster.

    Each cluster ``k`` has a centroid ``C[k]`` and a metric matrix ``L[k]``
    (the pseudo-inverse of the covariance of the points assigned to it).
    Points are assigned to the cluster minimising
    ``(x - C[k])^T L[k] (x - C[k])``.

    Attributes (all ``None`` until ``fit`` or ``load`` has been called):
        C: array of shape (K, d) holding the centroids.
        L: array of shape (K, d, d) holding the per-cluster metric matrices.
        K: number of clusters.
    """

    def __init__(self):
        """Create an untrained model."""
        self.C = None
        self.L = None
        self.K = None

    def _check_is_trained(self):
        """Raise a clear error when the model has no parameters yet."""
        if self.C is None or self.L is None or self.K is None:
            raise RuntimeError(
                "The model is not trained: call fit() or load() first."
            )

    def _distance_matrix(self, X):
        """Return the (N, K) matrix of distances from each row of X to each cluster.

        Entry ``[n, k]`` equals ``self._dist(X[n], self.C[k], self.L[k])``;
        the computation is vectorised over the N points.
        """
        distances = np.zeros((X.shape[0], self.K))
        for k in range(self.K):
            diff = X - self.C[k]
            # Row-wise quadratic form diff_n^T L_k diff_n.
            distances[:, k] = np.sum(diff.dot(self.L[k]) * diff, axis=1)
        return distances

    def predict(self, features):
        """Assign each feature vector to its closest cluster.

        @param features: array-like of shape (N, d).
        @return: integer array of shape (N,) with the index of the closest
            cluster for every row of ``features``.
        @raise RuntimeError: if the model has not been fitted or loaded.
        """
        self._check_is_trained()
        X = np.asarray(features, dtype=float)
        closest_cluster = np.argmin(self._distance_matrix(X), axis=1)
        return closest_cluster

    def load(self, model_path):
        """Load the parameters C, L and K from a file written by ``save``.

        @param model_path: path of the pickle file to read.
        @return: None; the instance attributes are updated in place.
        @raise Exception: if the file contains no model data.
        """
        # NOTE: pickle can execute arbitrary code while loading; only load
        # model files that come from a trusted source.
        with open(model_path, "rb") as f:
            data = pickle.load(f)
        if data is None:
            raise Exception("Le modèle n'a pas pu être chargé")
        self.C = data["C"]
        self.L = data["L"]
        self.K = data["K"]

    def save(self, modelpath: str):
        """Write the parameters C, L and K to ``modelpath`` with pickle.

        @param modelpath: path of the file to create or overwrite.
        @return: None
        """
        data = {
            "C": self.C,
            "L": self.L,
            "K": self.K
        }
        with open(modelpath, "wb") as f:
            pickle.dump(data, f)

    def fit(self, features, K: int, max_iter=None, debug_plot=False):
        """Train the model on ``features`` with ``K`` clusters.

        @param features: array-like of shape (N, d).
        @param K: number of clusters.
        @param max_iter: optional upper bound on the number of iterations;
            ``None`` iterates until the centroids stop changing.
        @param debug_plot: when True, write a t-SNE scatter plot
            ``test_<iteration>.pdf`` at every iteration (slow).
        @return: None; the result is stored in ``self.C``, ``self.L``, ``self.K``.
        """
        self._train(features, K, max_iter=max_iter, debug_plot=debug_plot)

    def _initialize_model(self, X, number_clusters):
        """Pick initial centroids among the data points and identity metrics.

        @param X: array of shape (N, d).
        @param number_clusters: number of clusters to initialise.
        @return: tuple ``(C, L)`` with C of shape (number_clusters, d) and
            L of shape (number_clusters, d, d), every L[k] being the identity.
        """
        n_samples, d = X.shape[0], X.shape[1]
        # Draw distinct points whenever there are enough of them: duplicated
        # initial centroids would leave one of the clusters permanently empty.
        chosen = np.random.choice(
            n_samples, number_clusters, replace=number_clusters > n_samples
        )
        C = X[chosen]
        L = np.zeros((number_clusters, d, d))
        L[:] = np.identity(d)
        return C, L

    def _dist(self, a, b, l):
        """Return the squared Mahalanobis-type distance (a - b)^T l (a - b).

        With ``l`` equal to the identity this is the squared Euclidean distance.

        @param a: vector with d components.
        @param b: vector with d components.
        @param l: metric matrix of shape (d, d).
        @return: the scalar value of the quadratic form.
        """
        diff = np.reshape(a, (-1, 1)) - np.reshape(b, (-1, 1))
        result = np.transpose(diff).dot(l).dot(diff)[0][0]
        return result

    def _plot_iteration(self, iteration, points, clusters, centers):
        """Save a 2-D scatter plot of one iteration to ``test_<iteration>.pdf``.

        @param iteration: iteration number, used in the output file name.
        @param points: array of shape (N, 2) with the points to draw.
        @param clusters: array of shape (N,) used to colour the points.
        @param centers: array of shape (K, 2) drawn as red '+' markers.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111)
        scatter = ax.scatter(points[:, 0], points[:, 1], c=clusters, s=50)
        ax.scatter(centers[:, 0], centers[:, 1], s=50, c='red', marker='+')

        ax.set_xlabel('x')
        ax.set_ylabel('y')
        plt.colorbar(scatter)
        plt.savefig("test_" + str(iteration) + ".pdf")
        # Release the figure; otherwise one figure per iteration stays open.
        plt.close(fig)

    def _plot_debug(self, iteration, X, closest_cluster):
        """Embed points and centroids together with t-SNE and plot them."""
        embedded = TSNE(n_components=2).fit_transform(
            np.concatenate((X, self.C), axis=0)
        )
        n_samples = X.shape[0]
        self._plot_iteration(
            iteration,
            embedded[:n_samples],
            closest_cluster,
            embedded[n_samples:]
        )

    def _train(self, features, K: int, max_iter=None, debug_plot=False):
        """Run the alternating assignment / update loop.

        @param features: array-like of shape (N, d).
        @param K: number of clusters.
        @param max_iter: optional upper bound on the number of iterations;
            ``None`` iterates until no centroid changes.
        @param debug_plot: when True, produce the t-SNE debug plot at every
            iteration.
        @return: None; ``self.C``, ``self.L`` and ``self.K`` are updated.
        """
        # Work in floating point: with an integer array the centroid update
        # below would be truncated and the convergence test could never pass.
        X = np.asarray(features, dtype=float)

        self.C, self.L = self._initialize_model(X, K)
        self.K = K

        iteration = 0
        converged = False
        while not converged and (max_iter is None or iteration < max_iter):
            print("Iteration: ", iteration)
            closest_cluster = np.argmin(self._distance_matrix(X), axis=1)

            if debug_plot:
                self._plot_debug(iteration, X, closest_cluster)

            converged = True
            for k in range(K):
                # Points currently assigned to cluster k.
                X_sub = X[closest_cluster == k]
                if X_sub.shape[0] == 0:
                    # An empty cluster keeps its previous centroid and metric.
                    continue
                C_new = np.mean(X_sub, axis=0)

                # New metric: pseudo-inverse of the cluster covariance
                # (pinv copes with singular covariance matrices).
                diff = X_sub - C_new
                covariance = diff.T.dot(diff) / X_sub.shape[0]
                L_new = np.linalg.pinv(covariance)

                # Stop only when every centroid is exactly unchanged.
                if not (self.C[k] == C_new).all():
                    converged = False
                self.C[k] = C_new
                self.L[k] = L_new
            iteration += 1