Commit 78e69749597669ca1eeab42474eae0f73c132838

Authored by quillotm
1 parent 957896bc97
Exists in master

Adding n_init parameter

Showing 1 changed file with 2 additions and 2 deletions Inline Diff

volia/clustering_modules/kmeans.py
1 1
2 from sklearn.cluster import KMeans 2 from sklearn.cluster import KMeans
3 import pickle 3 import pickle
4 from abstract_clustering import AbstractClustering 4 from abstract_clustering import AbstractClustering
5 5
6 class kmeans(): 6 class kmeans():
7 def __init__(self): 7 def __init__(self):
8 self.kmeans_model = None 8 self.kmeans_model = None
9 9
10 def predict(self, features): 10 def predict(self, features):
11 """ 11 """
12 12
13 @param features: 13 @param features:
14 @return: 14 @return:
15 """ 15 """
16 return self.kmeans_model.predict(features) 16 return self.kmeans_model.predict(features)
17 17
18 def load(self, model_path: str): 18 def load(self, model_path: str):
19 """ 19 """
20 20
21 @param model_path: 21 @param model_path:
22 @return: 22 @return:
23 """ 23 """
24 with open(model_path, "rb") as f: 24 with open(model_path, "rb") as f:
25 self.kmeans_model = pickle.load(f) 25 self.kmeans_model = pickle.load(f)
26 26
27 def save(self, model_path: str): 27 def save(self, model_path: str):
28 """ 28 """
29 29
30 @param model_path: 30 @param model_path:
31 @return: 31 @return:
32 """ 32 """
33 with open(model_path, "wb") as f: 33 with open(model_path, "wb") as f:
34 pickle.dump(self.kmeans_model, f) 34 pickle.dump(self.kmeans_model, f)
35 35
36 def fit(self, features, k: int, tol: float, maxiter: int=300, debug: bool=False): 36 def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False):
37 """ 37 """
38 38
39 @param features: 39 @param features:
40 @param k: 40 @param k:
41 @return: 41 @return:
42 """ 42 """
43 self.kmeans_model = KMeans(n_clusters=k, n_init=10, random_state=0, max_iter=maxiter, tol=tol).fit(features) 43 self.kmeans_model = KMeans(n_clusters=k, n_init=ninit, random_state=0, max_iter=maxiter, tol=tol).fit(features)
44 44