Commit 05afc43e5473588a2a1667026689f13ce67bdf7e
1 parent
2edadabee3
Exists in
master
Adding new implementation of kmeans to test. Simple implementation. Need to add …
…to take into account n_init parameter and to add random seed.
Showing 1 changed file with 54 additions and 0 deletions Side-by-side Diff
volia/clustering_modules/kmeans_multidistance.py
| 1 | + | |
| 2 | +import pickle | |
| 3 | +from abstract_clustering import AbstractClustering | |
| 4 | +from KMeans_Multidistance.KMeans_Class import KMeans | |
| 5 | + | |
| 6 | +class kmeansMultidistance(): | |
| 7 | + def __init__(self, distance="cosine"): | |
| 8 | + self.kmeans_model = None | |
| 9 | + self.centroids = None | |
| 10 | + self.distance = distance | |
| 11 | + | |
| 12 | + def predict(self, features): | |
| 13 | + """ | |
| 14 | + | |
| 15 | + @param features: | |
| 16 | + @return: | |
| 17 | + """ | |
| 18 | + return self.kmeans_model.assign_clusters(data=features, centroids=self.centroids, distance=self.kmeans_model.distance) | |
| 19 | + | |
| 20 | + def load(self, model_path: str): | |
| 21 | + """ | |
| 22 | + | |
| 23 | + @param model_path: | |
| 24 | + @return: | |
| 25 | + """ | |
| 26 | + with open(model_path, "rb") as f: | |
| 27 | + data = pickle.load(f) | |
| 28 | + self.kmeans_model = data["kmeans_model"] | |
| 29 | + self.centroids = data["centroids"] | |
| 30 | + self.distance = self.kmeans_model.distance | |
| 31 | + | |
| 32 | + def save(self, model_path: str): | |
| 33 | + """ | |
| 34 | + | |
| 35 | + @param model_path: | |
| 36 | + @return: | |
| 37 | + """ | |
| 38 | + with open(model_path, "wb") as f: | |
| 39 | + pickle.dump({ | |
| 40 | + "kmeans_model": self.kmeans_model, | |
| 41 | + "centroids": self.centroids | |
| 42 | + }, f) | |
| 43 | + | |
| 44 | + def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False): | |
| 45 | + """ | |
| 46 | + | |
| 47 | + @param features: | |
| 48 | + @param k: | |
| 49 | + @return: | |
| 50 | + """ | |
| 51 | + model = KMeans(k=5, maxiter=maxiter, distance=self.distance, record_heterogeneity=[], verbose=True, seed=123) | |
| 52 | + centroids, _ = model.fit(features) | |
| 53 | + self.centroids = centroids | |
| 54 | + self.kmeans_model = model |