Blame view

volia/clustering_modules/kmeans_multidistance.py 2.92 KB
05afc43e5   quillotm   Adding new implem...
1
2
3
4
  
  import pickle
  from abstract_clustering import AbstractClustering
  from KMeans_Multidistance.KMeans_Class import KMeans
1c1f0f29a   quillotm   Now we train n_in...
5
  from random import seed
0774ae544   quillotm   Fixed issues
6
  from random import randint
1c1f0f29a   quillotm   Now we train n_in...
7
8
  import numpy as np
  from sklearn.metrics import pairwise_distances
05afc43e5   quillotm   Adding new implem...
9
10
11
  
  class kmeansMultidistance():
      """
      K-means clustering with a configurable distance, trained with several
      random restarts; only the restart with the lowest loss is kept.

      Attributes:
          kmeans_model: KMeans instance of the best restart (None until fit/load).
          centroids: centroids returned by the best restart (None until fit/load).
          distance: name of the distance used for training and for the loss.
          seed: seed of the best restart (None until fit).
          seeds: seeds drawn for every restart of the last fit (None until fit).
      """

      def __init__(self, distance="cosine"):
          """
          Create an untrained clustering object.

          @param distance: name of the distance, forwarded to KMeans and to
              sklearn's pairwise_distances as ``metric``.
          """
          self.kmeans_model = None  # Best model
          self.centroids = None
          self.distance = distance
          self.seed = None  # Seed of the best
          self.seeds = None

      def predict(self, features):
          """
          Assign the given samples to clusters with the stored model.

          The distance used is the one stored on the model itself
          (``self.kmeans_model.distance``), not ``self.distance``.

          @param features: samples to assign, passed as ``data`` to
              ``assign_clusters``.
          @return: the result of ``KMeans.assign_clusters`` (presumably one
              cluster index per sample -- confirm against KMeans_Class).
          @raise AttributeError: if no model has been fitted or loaded yet.
          """
          model = self.kmeans_model
          return model.assign_clusters(data=features,
                                       centroids=self.centroids,
                                       distance=model.distance)

      def load(self, model_path: str):
          """
          Restore the model and centroids previously written by ``save``.

          ``self.distance`` is taken from the loaded model. ``seed`` and
          ``seeds`` are not part of the file and are left untouched.

          @param model_path: path of the pickle file to read.
          @return: None
          """
          # SECURITY: pickle executes arbitrary code while loading; only call
          # this on files coming from a trusted source.
          with open(model_path, "rb") as f:
              data = pickle.load(f)
          self.kmeans_model = data["kmeans_model"]
          self.centroids = data["centroids"]
          self.distance = self.kmeans_model.distance

      def save(self, model_path: str):
          """
          Write the current model and centroids to a pickle file.

          @param model_path: path of the pickle file to write.
          @return: None
          """
          payload = {
              "kmeans_model": self.kmeans_model,
              "centroids": self.centroids,
          }
          with open(model_path, "wb") as f:
              pickle.dump(payload, f)

      def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False):
          """
          Train ``ninit`` k-means models with different random seeds and keep
          the one with the lowest loss.

          The loss of a run is the sum, over all samples, of the distance
          between the sample and the centroid of the cluster it was assigned to.
          Ties are resolved in favour of the earliest run.

          @param features: training samples (assumed to be a 2-D numpy array,
              one row per sample -- ``features.T`` is used for "mahalanobis").
          @param k: number of clusters.
          @param tol: accepted for interface compatibility; currently unused
              (it is not forwarded to KMeans).
          @param ninit: number of random restarts, must be at least 1.
          @param maxiter: maximum number of iterations, forwarded to KMeans.
          @param debug: forwarded to KMeans as ``verbose``.
          @return: None
          @raise ValueError: if ``ninit`` is lower than 1.
          """
          # Fail early and explicitly, before the previous model is discarded.
          if ninit < 1:
              raise ValueError("ninit must be at least 1, got {}".format(ninit))

          # Initialization
          self.kmeans_model = None
          self.centroids = None
          self.seed = None

          # Compute seeds before using seeds.
          # NOTE: seed() without argument re-seeds the *global* random generator.
          seed()
          self.seeds = [randint(1, 100000) for _ in range(ninit)]

          # Extra arguments of the distance depend only on the features, so
          # they are computed once for all restarts.
          kwds = {}
          if self.distance == "mahalanobis":
              kwds["VI"] = np.linalg.pinv(np.cov(features.T)).T

          # Learning k-means models, remembering only the best one
          best_loss = None
          best_model = None
          best_centroids = None
          best_seed = None
          for run_seed in self.seeds:
              model = KMeans(k=k,
                             maxiter=maxiter,
                             distance=self.distance,
                             record_heterogeneity=[],
                             verbose=debug,
                             seed=run_seed)
              centroids, closest_cluster = model.fit(features)

              # Distance of every sample to every centroid
              distances = pairwise_distances(features, centroids, metric=self.distance, **kwds)

              # Loss: distance of each sample to its own centroid, summed
              loss = np.sum(distances[np.arange(len(distances)), closest_cluster])

              # Strict comparison keeps the first run in case of equal losses
              if best_loss is None or loss < best_loss:
                  best_loss = loss
                  best_model = model
                  best_centroids = centroids
                  best_seed = run_seed

          # State is only updated once every restart has succeeded
          self.kmeans_model = best_model
          self.centroids = best_centroids
          self.seed = best_seed