Commit 05afc43e5473588a2a1667026689f13ce67bdf7e
1 parent
2edadabee3
Exists in
master
Add a new, simple implementation of k-means for testing. Still to do: take the
n_init parameter into account and add a random seed.
Showing 1 changed file with 54 additions and 0 deletions Inline Diff
volia/clustering_modules/kmeans_multidistance.py
File was created | 1 | ||
2 | import pickle | ||
3 | from abstract_clustering import AbstractClustering | ||
4 | from KMeans_Multidistance.KMeans_Class import KMeans | ||
5 | |||
class kmeansMultidistance():
    """Thin wrapper around the project's multi-distance ``KMeans`` class.

    Keeps the fitted model object together with its centroids and exposes
    fit / predict / save / load on top of them.
    """

    def __init__(self, distance="cosine"):
        """Create an unfitted wrapper.

        @param distance: distance identifier forwarded to ``KMeans`` by ``fit``.
        """
        self.kmeans_model = None
        self.centroids = None
        self.distance = distance

    def predict(self, features):
        """Assign the given samples to the stored centroids.

        ``fit`` or ``load`` must have been called first; otherwise
        ``kmeans_model`` is still ``None`` and the call fails with
        ``AttributeError``.

        @param features: samples, passed as ``data`` to ``assign_clusters``.
        @return: the value returned by the model's ``assign_clusters``.
        """
        model = self.kmeans_model
        return model.assign_clusters(
            data=features,
            centroids=self.centroids,
            distance=model.distance,
        )

    def load(self, model_path: str):
        """Restore the model and centroids previously written by ``save``.

        @param model_path: path of the pickle file to read.
        @return: None
        """
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # model files that come from a trusted source.
        with open(model_path, "rb") as f:
            data = pickle.load(f)
        self.kmeans_model = data["kmeans_model"]
        self.centroids = data["centroids"]
        # A file saved from an unfitted instance holds no model; keep the
        # current distance in that case instead of failing on ``None``.
        if self.kmeans_model is not None:
            self.distance = self.kmeans_model.distance

    def save(self, model_path: str):
        """Pickle the model and its centroids to ``model_path``.

        @param model_path: path of the pickle file to write.
        @return: None
        """
        with open(model_path, "wb") as f:
            pickle.dump({
                "kmeans_model": self.kmeans_model,
                "centroids": self.centroids
            }, f)

    def fit(self, features, k: int, tol: float, ninit: int, maxiter: int = 300,
            debug: bool = False, seed: int = 123):
        """Fit a ``KMeans`` model on ``features`` and keep it on the instance.

        @param features: training samples, passed to ``KMeans.fit``.
        @param k: number of clusters requested from ``KMeans``.
        @param tol: currently unused.
        @param ninit: currently unused (single initialisation only).
        @param maxiter: iteration cap forwarded to ``KMeans``.
        @param debug: currently unused; the model is always built verbose.
        @param seed: random seed forwarded to ``KMeans`` (defaults to 123).
        @return: None; results are stored in ``centroids`` / ``kmeans_model``.
        """
        # Use the caller's k (it was previously hard-coded to 5).
        model = KMeans(
            k=k,
            maxiter=maxiter,
            distance=self.distance,
            record_heterogeneity=[],
            verbose=True,
            seed=seed,
        )
        centroids, _ = model.fit(features)
        self.centroids = centroids
        self.kmeans_model = model