Commit 78e69749597669ca1eeab42474eae0f73c132838
1 parent
957896bc97
Exists in
master
Adding n_init parameter
Showing 1 changed file with 2 additions and 2 deletions Side-by-side Diff
volia/clustering_modules/kmeans.py
... | ... | @@ -33,12 +33,12 @@ |
33 | 33 | with open(model_path, "wb") as f: |
34 | 34 | pickle.dump(self.kmeans_model, f) |
35 | 35 | |
36 | - def fit(self, features, k: int, tol: float, maxiter: int=300, debug: bool=False): | |
36 | + def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False): | |
37 | 37 | """ |
38 | 38 | |
39 | 39 | @param features: |
40 | 40 | @param k: |
41 | 41 | @return: |
42 | 42 | """ |
43 | - self.kmeans_model = KMeans(n_clusters=k, n_init=10, random_state=0, max_iter=maxiter, tol=tol).fit(features) | |
43 | + self.kmeans_model = KMeans(n_clusters=k, n_init=ninit, random_state=0, max_iter=maxiter, tol=tol).fit(features) |