Commit 05afc43e5473588a2a1667026689f13ce67bdf7e
1 parent
2edadabee3
Exists in
master
Adding new implementation of kmeans to test. Simple implementation. Need to add …
…to take into account n_init parameter and to add random seed.
Showing 1 changed file with 54 additions and 0 deletions Side-by-side Diff
volia/clustering_modules/kmeans_multidistance.py
1 | + | |
2 | +import pickle | |
3 | +from abstract_clustering import AbstractClustering | |
4 | +from KMeans_Multidistance.KMeans_Class import KMeans | |
5 | + | |
6 | +class kmeansMultidistance(): | |
7 | + def __init__(self, distance="cosine"): | |
8 | + self.kmeans_model = None | |
9 | + self.centroids = None | |
10 | + self.distance = distance | |
11 | + | |
12 | + def predict(self, features): | |
13 | + """ | |
14 | + | |
15 | + @param features: | |
16 | + @return: | |
17 | + """ | |
18 | + return self.kmeans_model.assign_clusters(data=features, centroids=self.centroids, distance=self.kmeans_model.distance) | |
19 | + | |
20 | + def load(self, model_path: str): | |
21 | + """ | |
22 | + | |
23 | + @param model_path: | |
24 | + @return: | |
25 | + """ | |
26 | + with open(model_path, "rb") as f: | |
27 | + data = pickle.load(f) | |
28 | + self.kmeans_model = data["kmeans_model"] | |
29 | + self.centroids = data["centroids"] | |
30 | + self.distance = self.kmeans_model.distance | |
31 | + | |
32 | + def save(self, model_path: str): | |
33 | + """ | |
34 | + | |
35 | + @param model_path: | |
36 | + @return: | |
37 | + """ | |
38 | + with open(model_path, "wb") as f: | |
39 | + pickle.dump({ | |
40 | + "kmeans_model": self.kmeans_model, | |
41 | + "centroids": self.centroids | |
42 | + }, f) | |
43 | + | |
44 | + def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False): | |
45 | + """ | |
46 | + | |
47 | + @param features: | |
48 | + @param k: | |
49 | + @return: | |
50 | + """ | |
51 | + model = KMeans(k=5, maxiter=maxiter, distance=self.distance, record_heterogeneity=[], verbose=True, seed=123) | |
52 | + centroids, _ = model.fit(features) | |
53 | + self.centroids = centroids | |
54 | + self.kmeans_model = model |