Commit 1c1f0f29a768a0519933f876ecc8a9494e8984a4
1 parent
a9912f135f
Exists in
master
Now we train n_init time with the basic multidistance implementation of k-means.
Showing 1 changed file with 48 additions and 5 deletions Side-by-side Diff
volia/clustering_modules/kmeans_multidistance.py
... | ... | @@ -2,12 +2,18 @@ |
2 | 2 | import pickle |
3 | 3 | from abstract_clustering import AbstractClustering |
4 | 4 | from KMeans_Multidistance.KMeans_Class import KMeans |
5 | +from random import seed | |
6 | +from random import random | |
7 | +import numpy as np | |
8 | +from sklearn.metrics import pairwise_distances | |
5 | 9 | |
6 | 10 | class kmeansMultidistance(): |
7 | 11 | def __init__(self, distance="cosine"): |
8 | - self.kmeans_model = None | |
12 | + self.kmeans_model = None # Best model | |
9 | 13 | self.centroids = None |
10 | 14 | self.distance = distance |
15 | + self.seed = None # Seed of the best | |
16 | + self.seeds = None | |
11 | 17 | |
12 | 18 | def predict(self, features): |
13 | 19 | """ |
... | ... | @@ -48,8 +54,45 @@ |
48 | 54 | @param k: |
49 | 55 | @return: |
50 | 56 | """ |
51 | - model = KMeans(k=5, maxiter=maxiter, distance=self.distance, record_heterogeneity=[], verbose=True, seed=123) | |
52 | - centroids, _ = model.fit(features) | |
53 | - self.centroids = centroids | |
54 | - self.kmeans_model = model | |
57 | + # Initialization | |
58 | + self.kmeans_model = None | |
59 | + self.centroids = None | |
60 | + self.seed = None | |
61 | + | |
62 | + # Compute seeds before using seeds | |
63 | + seed() | |
64 | + self.seeds = [random() for i in range(ninit)] | |
65 | + | |
66 | + # Learning k-means model | |
67 | + results = [] | |
68 | + for i in range(ninit): | |
69 | + model = KMeans(k=5, | |
70 | + maxiter=maxiter, | |
71 | + distance=self.distance, | |
72 | + record_heterogeneity=[], | |
73 | + verbose=True, | |
74 | + seed=self.seeds[i]) | |
75 | + centroids, closest_cluster = model.fit(features) | |
76 | + | |
77 | + # Compute distance | |
78 | + kwds = {} | |
79 | + if self.distance == "mahalanobis": | |
80 | + VI = np.linalg.pinv(np.cov(features.T)).T | |
81 | + kwds = {"VI": VI} | |
82 | + distances = pairwise_distances(features, centroids, metric=self.distance, **kwds) | |
83 | + | |
84 | + # Then compute the loss | |
85 | + loss = np.sum(distances[np.arange(len(distances)), closest_cluster]) | |
86 | + | |
87 | + results.append({ | |
88 | + "model": model, | |
89 | + "centroids": centroids, | |
90 | + "seed": self.seeds[i], | |
91 | + "loss": loss | |
92 | + }) | |
93 | + losses = [result["loss"] for result in results] | |
94 | + best = results[losses.index(min(losses))] | |
95 | + self.kmeans_model = results[best]["model"] | |
96 | + self.centroids = results[best]["centroids"] | |
97 | + self.seed = results[best]["seed"] |