Blame view

volia/clustering_modules/kmeans.py 931 Bytes
9191399c3   quillotm   Clustering and ev...
1
2
3
4
5
6
7
8
9
10
  
  from sklearn.cluster import KMeans
  import pickle
  from abstract_clustering import AbstractClustering
  
  class kmeans:
      """K-means clustering wrapper around scikit-learn's ``KMeans``.

      The fitted estimator is stored in ``self.kmeans_model``. It is ``None``
      until :meth:`fit` or :meth:`load` has been called; :meth:`predict` and
      :meth:`save` raise :class:`kmeans.NotFittedError` before that point.
      """

      class NotFittedError(ValueError, AttributeError):
          """Raised when the wrapper is used before a model is fitted or loaded.

          Mirrors scikit-learn's ``NotFittedError``, which derives from both
          ``ValueError`` and ``AttributeError``. The ``AttributeError`` base also
          keeps working any caller that guarded an unfitted wrapper with
          ``except AttributeError``.
          """

      def __init__(self):
          # No model until fit() or load() populates it.
          self.kmeans_model = None

      def _require_model(self):
          """Return the wrapped model, or raise NotFittedError if there is none."""
          if self.kmeans_model is None:
              raise self.NotFittedError(
                  "kmeans model is not fitted; call fit() or load() first"
              )
          return self.kmeans_model

      def predict(self, features):
          """Assign each sample to a cluster using the wrapped model.

          @param features: samples to cluster, array-like of shape
              (n_samples, n_features) as accepted by ``KMeans.predict``.
          @return: the cluster index of each sample, as returned by the wrapped
              model's ``predict``.
          @raise kmeans.NotFittedError: if no model has been fitted or loaded.
          """
          return self._require_model().predict(features)

      def load(self, model_path: str):
          """Replace the wrapped model with one unpickled from *model_path*.

          Security: ``pickle.load`` can execute arbitrary code embedded in the
          file. Only load model files written by :meth:`save` from a trusted
          source.

          @param model_path: path to a file previously written by :meth:`save`.
          @return: None.
          """
          with open(model_path, "rb") as f:
              self.kmeans_model = pickle.load(f)

      def save(self, model_path: str):
          """Pickle the wrapped model to *model_path*, overwriting any existing file.

          @param model_path: destination file path.
          @return: None.
          @raise kmeans.NotFittedError: if no model has been fitted or loaded.
              The check runs before the file is opened, so nothing is written.
          """
          model = self._require_model()
          with open(model_path, "wb") as f:
              pickle.dump(model, f)

      def fit(self, features, k: int, n_init: int = 10, random_state=0):
          """Fit a new ``KMeans`` model on *features* and store it.

          @param features: training samples, array-like of shape
              (n_samples, n_features).
          @param k: number of clusters (passed as ``n_clusters``).
          @param n_init: number of k-means runs with different centroid seeds.
              The run with the lowest inertia is kept. Defaults to 10.
          @param random_state: seed (int, ``None`` or a RandomState) for centroid
              initialisation. The default of 0 makes repeated fits reproducible.
          @return: None.
          """
          self.kmeans_model = KMeans(
              n_clusters=k, n_init=n_init, random_state=random_state
          ).fit(features)