Commit 660d9960f95ade5bb2446df6177425097c9b71a9
Parent: 78e6974959
Exists in: master
Adding n_init parameters
Showing 2 changed files with 53 additions and 28 deletions
volia/clustering.py
@@ -88,6 +88,25 @@
     @param output: output file if kmax not specified, else, output directory
     @param mahalanobis: distance option of k-means.
     """
+    json_content = locals().copy()
+
+    def fit_model(k: int, output_file):
+        if debug:
+            print(f"Computing clustering with k={k}")
+        model = CLUSTERING_METHODS["k-means"]
+        if mahalanobis:
+            if debug:
+                print("Mahalanobis activated")
+            model = CLUSTERING_METHODS["k-means-mahalanobis"]
+        model.fit(X, k, tol, ninit, maxiter, debug)
+        model.save(output_file)
+        json_content["models"].append({
+            "model_file": output_file,
+            "k": k,
+        })
+
+    json_content["models"] = []
+
     # -- READ FILES --
     features_dict = read_features(features)
     lst_dict = read_lst(lst)
@@ -102,13 +121,7 @@
 
     # Mono value case
     if kmax is None and klist is None:
-        if debug:
-            print(f"Computing clustering with k={k}")
-        model = CLUSTERING_METHODS["k-means"]
-        if mahalanobis:
-            model = CLUSTERING_METHODS["k-means-mahalanobis"]
-        model.fit(X, k, tol, maxiter, debug)
-        model.save(output)
+        fit_model(k, output)
 
     # Multi values case with kmax
     if kmax is not None:
@@ -116,11 +129,7 @@
         mkdir(output)
         Ks = range(k, kmax + 1)
         for i in Ks:
-            model = CLUSTERING_METHODS["k-means"]
-            if mahalanobis:
-                model = CLUSTERING_METHODS["k-means-mahalanobis"]
-            model.fit(X, i, tol, maxiter, debug)
-            model.save(path.join(output, "clustering_" + str(i) + ".pkl"))
+            fit_model(i, path.join(output, "clustering_" + str(i) + ".pkl"))
 
     # Second multi values case with klist
     if klist is not None:
 
@@ -128,16 +137,10 @@
         mkdir(output)
         for k in klist:
             k = int(k)
-            model = CLUSTERING_METHODS["k-means"]
-            if mahalanobis:
-                model = CLUSTERING_METHODS["k-means-mahalanobis"]
-            model.fit(X, k, tol, maxiter, debug)
-            model.save(path.join(output, "clustering_" + str(k) + ".pkl"))
+            fit_model(k, path.join(output, "clustering_" + str(k) + ".pkl"))
 
-    # TODO: Output json to explain the end parameters like number of iteration, tol reached and stoped the process ?
-    # etc. (what distance, what parameters etc)
-    # TODO: Move example data into a directory.
-    # TODO: Add example receipts
+    print(json_content)
+    # TODO: compute loss with k-means mahalanobis.
     # TODO: n_init have to be taken into account for mahalanobis case of k-means algorithm.
 
 
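Note on the refactor above: the three duplicated fit/save blocks collapse into the single fit_model closure, which also records each saved model in json_content for the final dump. A minimal sketch of that pattern follows, with a stand-in model class; CLUSTERING_METHODS and the real wrappers are assumed from the diff, not shown here.

    import json

    class DummyModel:
        # Stand-in for volia's real k-means wrappers (not shown in this diff).
        def fit(self, X, k, tol, ninit, maxiter, debug):
            pass  # the real wrapper clusters X here

        def save(self, output_file):
            pass  # the real wrapper pickles itself here

    CLUSTERING_METHODS = {"k-means": DummyModel(),
                          "k-means-mahalanobis": DummyModel()}

    def cluster(X, klist, output, tol=1e-4, ninit=10, maxiter=300,
                mahalanobis=False, debug=False):
        json_content = locals().copy()  # snapshot of the call parameters
        json_content["models"] = []

        def fit_model(k, output_file):
            name = "k-means-mahalanobis" if mahalanobis else "k-means"
            model = CLUSTERING_METHODS[name]
            model.fit(X, k, tol, ninit, maxiter, debug)
            model.save(output_file)
            json_content["models"].append({"model_file": output_file, "k": k})

        for k in klist:
            fit_model(int(k), output + "/clustering_" + str(k) + ".pkl")
        print(json.dumps(json_content, default=str))

    cluster(X=[[0.0, 0.0], [1.0, 1.0]], klist=[2, 3], output="/tmp")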
volia/clustering_modules/kmeans_mahalanobis.py
@@ -60,8 +60,17 @@
         with open(modelpath, "wb") as f:
             pickle.dump(data, f)
 
-    def fit(self, features, k: int, tol: float = 0.0001, maxiter: int=300, debug: bool=False):
-        self._train(features, k, tol, maxiter, debug)
+    def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False):
+        results = []
+        for i in range(ninit):
+            results.append(self._train(features, k, tol, maxiter, debug))
+        losses = [v["loss"] for v in results]
+        best = results[losses.index(min(losses))]
+        if debug:
+            print(f"best loss: {best['loss']}")
+        self.C = best["C"]
+        self.L = best["L"]
+        self.K = best["K"]
 
     def _initialize_model(self, X, number_clusters):
         d = X.shape[1]
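The new fit() above implements an n-init restart: _train runs ninit times from fresh initializations and the run with the lowest loss wins. A minimal self-contained sketch of the same select-best pattern; the toy _train stand-in below is hypothetical, not volia code.

    import random

    def _toy_train(seed):
        # Hypothetical stand-in for _train: pretend each restart lands in a
        # different local optimum and report its loss and "centroids".
        random.seed(seed)
        return {"loss": random.uniform(1.0, 10.0), "C": "centroids-%d" % seed}

    def fit(ninit=5):
        results = [_toy_train(i) for i in range(ninit)]
        losses = [r["loss"] for r in results]
        best = results[losses.index(min(losses))]  # keep the lowest-loss run
        return best

    print(fit()["loss"])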
@@ -101,7 +110,6 @@
         N = X.shape[0]
         d = X.shape[1]
 
-        X_embedded = None
         C, L = self._initialize_model(X, K)
         self.C = C
         self.L = L
 
 
 
@@ -114,13 +122,18 @@
                 print("Iteration: ", i)
 
             # Compute distance matrix
-            distances = np.zeros((N, K))
+            distances = np.zeros((N, self.K))
 
             for n in range(N):
                 for k in range(self.K):
                     distances[n][k] = self._dist(X[n], self.C[k], self.L[k])
+
             closest_cluster = np.argmin(distances, axis=1)
+            loss = np.sum(distances[np.arange(len(distances)), closest_cluster])
+            if debug:
+                print(f"loss {loss}")
 
+
             # -- Debug tool ----------------------
             if debug and i % 10 == 0:
                 # TSNE if needed
 
 
 
@@ -161,17 +174,26 @@
                 self.C[k] = C_new
                 self.L[k] = K_new
 
+
             diff = np.sum(np.absolute((self.C - old_c) / old_c * 100))
             if diff > tol:
                 end_algo = False
                 if debug:
                     print(f"{diff}")
-            elif debug:
-                print(f"Tolerance threshold {tol} reached with diff {diff}")
+            else:
+                if debug:
+                    print(f"Tolerance threshold {tol} reached with diff {diff}")
                 end_algo = True
+
             i = i + 1
             if i > maxiter:
                 end_algo = True
                 if debug:
                     print(f"Iteration {maxiter} reached")
+        return {
+            "loss": loss,
+            "C": self.C,
+            "K": self.K,
+            "L": self.L
+        }
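_dist is not shown in this diff; assuming L[k] holds a per-cluster precision (inverse covariance) matrix, the usual Mahalanobis parameterization, the distance-matrix and loss computation above can be read as in the hedged sketch below (not the project's actual _dist).

    import numpy as np

    def mahalanobis_dist(x, c, Lk):
        # d(x, c) = sqrt((x - c)^T L (x - c)) with L a precision matrix
        delta = x - c
        return float(np.sqrt(delta @ Lk @ delta))

    rng = np.random.default_rng(0)
    X = rng.normal(size=(6, 2))              # N=6 points in d=2
    C = np.zeros((2, 2))                     # K=2 centroids at the origin
    L = np.stack([np.eye(2), np.eye(2)])     # identity precision = Euclidean

    distances = np.zeros((len(X), len(C)))
    for n in range(len(X)):
        for k in range(len(C)):
            distances[n][k] = mahalanobis_dist(X[n], C[k], L[k])

    closest_cluster = np.argmin(distances, axis=1)
    # Same loss as in the diff: sum each point's distance to its closest centroid.
    loss = np.sum(distances[np.arange(len(distances)), closest_cluster])
    print(loss)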