Commit 78e69749597669ca1eeab42474eae0f73c132838
1 parent
957896bc97
Exists in
master
Adding n_init parameter
Showing 1 changed file with 2 additions and 2 deletions Inline Diff
volia/clustering_modules/kmeans.py
1 | 1 | ||
2 | from sklearn.cluster import KMeans | 2 | from sklearn.cluster import KMeans |
3 | import pickle | 3 | import pickle |
4 | from abstract_clustering import AbstractClustering | 4 | from abstract_clustering import AbstractClustering |
5 | 5 | ||
6 | class kmeans(): | 6 | class kmeans(): |
7 | def __init__(self): | 7 | def __init__(self): |
8 | self.kmeans_model = None | 8 | self.kmeans_model = None |
9 | 9 | ||
10 | def predict(self, features): | 10 | def predict(self, features): |
11 | """ | 11 | """ |
12 | 12 | ||
13 | @param features: | 13 | @param features: |
14 | @return: | 14 | @return: |
15 | """ | 15 | """ |
16 | return self.kmeans_model.predict(features) | 16 | return self.kmeans_model.predict(features) |
17 | 17 | ||
18 | def load(self, model_path: str): | 18 | def load(self, model_path: str): |
19 | """ | 19 | """ |
20 | 20 | ||
21 | @param model_path: | 21 | @param model_path: |
22 | @return: | 22 | @return: |
23 | """ | 23 | """ |
24 | with open(model_path, "rb") as f: | 24 | with open(model_path, "rb") as f: |
25 | self.kmeans_model = pickle.load(f) | 25 | self.kmeans_model = pickle.load(f) |
26 | 26 | ||
27 | def save(self, model_path: str): | 27 | def save(self, model_path: str): |
28 | """ | 28 | """ |
29 | 29 | ||
30 | @param model_path: | 30 | @param model_path: |
31 | @return: | 31 | @return: |
32 | """ | 32 | """ |
33 | with open(model_path, "wb") as f: | 33 | with open(model_path, "wb") as f: |
34 | pickle.dump(self.kmeans_model, f) | 34 | pickle.dump(self.kmeans_model, f) |
35 | 35 | ||
36 | def fit(self, features, k: int, tol: float, maxiter: int=300, debug: bool=False): | 36 | def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False): |
37 | """ | 37 | """ |
38 | 38 | ||
39 | @param features: | 39 | @param features: |
40 | @param k: | 40 | @param k: |
41 | @return: | 41 | @return: |
42 | """ | 42 | """ |
43 | self.kmeans_model = KMeans(n_clusters=k, n_init=10, random_state=0, max_iter=maxiter, tol=tol).fit(features) | 43 | self.kmeans_model = KMeans(n_clusters=k, n_init=ninit, random_state=0, max_iter=maxiter, tol=tol).fit(features) |
44 | 44 |