Commit 78e69749597669ca1eeab42474eae0f73c132838
1 parent
957896bc97
Exists in
master
Adding n_init parameter
Showing 1 changed file with 2 additions and 2 deletions Side-by-side Diff
volia/clustering_modules/kmeans.py
| ... | ... | @@ -33,12 +33,12 @@ |
| 33 | 33 | with open(model_path, "wb") as f: |
| 34 | 34 | pickle.dump(self.kmeans_model, f) |
| 35 | 35 | |
| 36 | - def fit(self, features, k: int, tol: float, maxiter: int=300, debug: bool=False): | |
| 36 | + def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False): | |
| 37 | 37 | """ |
| 38 | 38 | |
| 39 | 39 | @param features: |
| 40 | 40 | @param k: |
| 41 | 41 | @return: |
| 42 | 42 | """ |
| 43 | - self.kmeans_model = KMeans(n_clusters=k, n_init=10, random_state=0, max_iter=maxiter, tol=tol).fit(features) | |
| 43 | + self.kmeans_model = KMeans(n_clusters=k, n_init=ninit, random_state=0, max_iter=maxiter, tol=tol).fit(features) |