Commit 1c1f0f29a768a0519933f876ecc8a9494e8984a4

Authored by quillotm
1 parent a9912f135f
Exists in master

Now we train n_init times with the basic multidistance implementation of k-means.

Showing 1 changed file with 48 additions and 5 deletions Inline Diff

volia/clustering_modules/kmeans_multidistance.py
1 1
2 import pickle 2 import pickle
3 from abstract_clustering import AbstractClustering 3 from abstract_clustering import AbstractClustering
4 from KMeans_Multidistance.KMeans_Class import KMeans 4 from KMeans_Multidistance.KMeans_Class import KMeans
5 from random import seed
6 from random import random
7 import numpy as np
8 from sklearn.metrics import pairwise_distances
5 9
class kmeansMultidistance():
    """
    Wrapper around the multidistance ``KMeans`` implementation.

    ``fit`` trains the model ``ninit`` times with different seeds and keeps
    the run whose summed point-to-assigned-centroid distance is the lowest.
    """

    def __init__(self, distance="cosine"):
        """
        @param distance: name of the distance metric, forwarded to ``KMeans``
            and to ``sklearn.metrics.pairwise_distances``.
        """
        self.kmeans_model = None  # Best model
        self.centroids = None  # Centroids of the best model
        self.distance = distance
        self.seed = None  # Seed of the best
        self.seeds = None  # Seeds drawn for every restart of the last fit

    def predict(self, features):
        """
        Assign each sample to a cluster using the centroids of the best model.

        ``fit`` or ``load`` must have been called first; otherwise
        ``self.kmeans_model`` is still None and the call fails.

        @param features: samples to assign.
        @return: whatever ``KMeans.assign_clusters`` returns for these samples.
        """
        return self.kmeans_model.assign_clusters(
            data=features,
            centroids=self.centroids,
            distance=self.kmeans_model.distance,
        )

    def load(self, model_path: str):
        """
        Restore a model previously written by ``save``.

        @param model_path: path of the pickle file to read.
        @return: None
        """
        # SECURITY: pickle.load executes arbitrary code embedded in the file;
        # only load model files that come from a trusted source.
        with open(model_path, "rb") as f:
            data = pickle.load(f)
        self.kmeans_model = data["kmeans_model"]
        self.centroids = data["centroids"]
        self.distance = self.kmeans_model.distance
        # Older model files do not contain the seeds: fall back to None.
        self.seed = data.get("seed")
        self.seeds = data.get("seeds")

    def save(self, model_path: str):
        """
        Pickle the best model, its centroids and the seeds to a file.

        @param model_path: path of the pickle file to write.
        @return: None
        """
        with open(model_path, "wb") as f:
            pickle.dump({
                "kmeans_model": self.kmeans_model,
                "centroids": self.centroids,
                "seed": self.seed,
                "seeds": self.seeds,
            }, f)

    def fit(self, features, k: int, tol: float, ninit: int, maxiter: int = 300, debug: bool = False):
        """
        Train ``ninit`` k-means models and keep the one with the lowest loss.

        The loss of a run is the sum, over all samples, of the distance
        between the sample and the centroid it was assigned to.

        @param features: training samples; ``features.T`` is used for the
            covariance when the distance is "mahalanobis".
        @param k: number of clusters.
        @param tol: currently unused.
        @param ninit: number of restarts, must be at least 1.
        @param maxiter: maximum number of iterations, forwarded to ``KMeans``.
        @param debug: currently unused.
        @return: None; results are stored in ``kmeans_model``, ``centroids``,
            ``seed`` and ``seeds``.
        @raise ValueError: if ``ninit`` is lower than 1.
        """
        if ninit < 1:
            raise ValueError("ninit must be at least 1, got {}".format(ninit))

        # Initialization
        self.kmeans_model = None
        self.centroids = None
        self.seed = None

        # Compute seeds before using seeds.
        # seed() without argument re-seeds the global generator from the
        # system, so every call to fit draws a fresh list of seeds.
        seed()
        # NOTE(review): integer seeds are drawn because integer-seeded
        # generators reject or truncate floats in [0, 1); confirm what
        # KMeans does with its ``seed`` argument.
        self.seeds = [int(random() * 2 ** 32) for _ in range(ninit)]

        # The inverse covariance only depends on the features, so it is
        # computed once instead of once per restart.
        kwds = {}
        if self.distance == "mahalanobis":
            kwds["VI"] = np.linalg.pinv(np.cov(features.T)).T

        # Learning k-means models, remembering only the best run so far.
        best_model = None
        best_centroids = None
        best_seed = None
        best_loss = None
        for run_seed in self.seeds:
            model = KMeans(k=k,
                           maxiter=maxiter,
                           distance=self.distance,
                           record_heterogeneity=[],
                           verbose=True,
                           seed=run_seed)
            centroids, closest_cluster = model.fit(features)

            # Distance of every sample to every centroid
            distances = pairwise_distances(features, centroids, metric=self.distance, **kwds)

            # Then compute the loss: keep, for each sample, only the
            # distance to the centroid it was assigned to.
            loss = np.sum(distances[np.arange(len(distances)), closest_cluster])

            # Strict comparison: on ties the earliest run wins.
            if best_loss is None or loss < best_loss:
                best_model = model
                best_centroids = centroids
                best_seed = run_seed
                best_loss = loss

        self.kmeans_model = best_model
        self.centroids = best_centroids
        self.seed = best_seed
98