From 05afc43e5473588a2a1667026689f13ce67bdf7e Mon Sep 17 00:00:00 2001 From: quillotm Date: Tue, 24 Aug 2021 09:08:10 +0200 Subject: [PATCH] Adding new implementation of kmeans to test. Simple implementation. Need to add to take into account n_init parameter and to add random seed. --- volia/clustering_modules/kmeans_multidistance.py | 54 ++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 volia/clustering_modules/kmeans_multidistance.py diff --git a/volia/clustering_modules/kmeans_multidistance.py b/volia/clustering_modules/kmeans_multidistance.py new file mode 100644 index 0000000..f1c7944 --- /dev/null +++ b/volia/clustering_modules/kmeans_multidistance.py @@ -0,0 +1,54 @@ + +import pickle +from abstract_clustering import AbstractClustering +from KMeans_Multidistance.KMeans_Class import KMeans + +class kmeansMultidistance(): + def __init__(self, distance="cosine"): + self.kmeans_model = None + self.centroids = None + self.distance = distance + + def predict(self, features): + """ + + @param features: + @return: + """ + return self.kmeans_model.assign_clusters(data=features, centroids=self.centroids, distance=self.kmeans_model.distance) + + def load(self, model_path: str): + """ + + @param model_path: + @return: + """ + with open(model_path, "rb") as f: + data = pickle.load(f) + self.kmeans_model = data["kmeans_model"] + self.centroids = data["centroids"] + self.distance = self.kmeans_model.distance + + def save(self, model_path: str): + """ + + @param model_path: + @return: + """ + with open(model_path, "wb") as f: + pickle.dump({ + "kmeans_model": self.kmeans_model, + "centroids": self.centroids + }, f) + + def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False): + """ + + @param features: + @param k: + @return: + """ + model = KMeans(k=5, maxiter=maxiter, distance=self.distance, record_heterogeneity=[], verbose=True, seed=123) + centroids, _ = model.fit(features) + self.centroids = centroids + self.kmeans_model = model \ No newline at end of file -- 1.8.2.3