diff --git a/volia/clustering_modules/kmeans.py b/volia/clustering_modules/kmeans.py index 23ad00a..013c02b 100644 --- a/volia/clustering_modules/kmeans.py +++ b/volia/clustering_modules/kmeans.py @@ -33,11 +33,11 @@ class kmeans(): with open(model_path, "wb") as f: pickle.dump(self.kmeans_model, f) - def fit(self, features, k: int, tol: float, maxiter: int=300, debug: bool=False): + def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False): """ @param features: @param k: @return: """ - self.kmeans_model = KMeans(n_clusters=k, n_init=10, random_state=0, max_iter=maxiter, tol=tol).fit(features) + self.kmeans_model = KMeans(n_clusters=k, n_init=ninit, random_state=0, max_iter=maxiter, tol=tol).fit(features)