Commit 1c1f0f29a768a0519933f876ecc8a9494e8984a4
1 parent
a9912f135f
Exists in
master
Now we train n_init time with the basic multidistance implementation of k-means.
Showing 1 changed file with 48 additions and 5 deletions Side-by-side Diff
volia/clustering_modules/kmeans_multidistance.py
| ... | ... | @@ -2,12 +2,18 @@ |
| 2 | 2 | import pickle |
| 3 | 3 | from abstract_clustering import AbstractClustering |
| 4 | 4 | from KMeans_Multidistance.KMeans_Class import KMeans |
| 5 | +from random import seed | |
| 6 | +from random import random | |
| 7 | +import numpy as np | |
| 8 | +from sklearn.metrics import pairwise_distances | |
| 5 | 9 | |
| 6 | 10 | class kmeansMultidistance(): |
| 7 | 11 | def __init__(self, distance="cosine"): |
| 8 | - self.kmeans_model = None | |
| 12 | + self.kmeans_model = None # Best model | |
| 9 | 13 | self.centroids = None |
| 10 | 14 | self.distance = distance |
| 15 | + self.seed = None # Seed of the best | |
| 16 | + self.seeds = None | |
| 11 | 17 | |
| 12 | 18 | def predict(self, features): |
| 13 | 19 | """ |
| ... | ... | @@ -48,8 +54,45 @@ |
| 48 | 54 | @param k: |
| 49 | 55 | @return: |
| 50 | 56 | """ |
| 51 | - model = KMeans(k=5, maxiter=maxiter, distance=self.distance, record_heterogeneity=[], verbose=True, seed=123) | |
| 52 | - centroids, _ = model.fit(features) | |
| 53 | - self.centroids = centroids | |
| 54 | - self.kmeans_model = model | |
| 57 | + # Initialization | |
| 58 | + self.kmeans_model = None | |
| 59 | + self.centroids = None | |
| 60 | + self.seed = None | |
| 61 | + | |
| 62 | + # Compute seeds before using seeds | |
| 63 | + seed() | |
| 64 | + self.seeds = [random() for i in range(ninit)] | |
| 65 | + | |
| 66 | + # Learning k-means model | |
| 67 | + results = [] | |
| 68 | + for i in range(ninit): | |
| 69 | + model = KMeans(k=5, | |
| 70 | + maxiter=maxiter, | |
| 71 | + distance=self.distance, | |
| 72 | + record_heterogeneity=[], | |
| 73 | + verbose=True, | |
| 74 | + seed=self.seeds[i]) | |
| 75 | + centroids, closest_cluster = model.fit(features) | |
| 76 | + | |
| 77 | + # Compute distance | |
| 78 | + kwds = {} | |
| 79 | + if self.distance == "mahalanobis": | |
| 80 | + VI = np.linalg.pinv(np.cov(features.T)).T | |
| 81 | + kwds = {"VI": VI} | |
| 82 | + distances = pairwise_distances(features, centroids, metric=self.distance, **kwds) | |
| 83 | + | |
| 84 | + # Then compute the loss | |
| 85 | + loss = np.sum(distances[np.arange(len(distances)), closest_cluster]) | |
| 86 | + | |
| 87 | + results.append({ | |
| 88 | + "model": model, | |
| 89 | + "centroids": centroids, | |
| 90 | + "seed": self.seeds[i], | |
| 91 | + "loss": loss | |
| 92 | + }) | |
| 93 | + losses = [result["loss"] for result in results] | |
| 94 | + best = results[losses.index(min(losses))] | |
| 95 | + self.kmeans_model = results[best]["model"] | |
| 96 | + self.centroids = results[best]["centroids"] | |
| 97 | + self.seed = results[best]["seed"] |