Commit 05afc43e5473588a2a1667026689f13ce67bdf7e

Authored by quillotm
1 parent 2edadabee3
Exists in master

Adding new implementation of kmeans to test. Simple implementation. Need to add …

…to take into account n_init parameter and to add random seed.

Showing 1 changed file with 54 additions and 0 deletions Inline Diff

volia/clustering_modules/kmeans_multidistance.py
File was created 1
2 import pickle
3 from abstract_clustering import AbstractClustering
4 from KMeans_Multidistance.KMeans_Class import KMeans
5
6 class kmeansMultidistance():
7 def __init__(self, distance="cosine"):
8 self.kmeans_model = None
9 self.centroids = None
10 self.distance = distance
11
12 def predict(self, features):
13 """
14
15 @param features:
16 @return:
17 """
18 return self.kmeans_model.assign_clusters(data=features, centroids=self.centroids, distance=self.kmeans_model.distance)
19
20 def load(self, model_path: str):
21 """
22
23 @param model_path:
24 @return:
25 """
26 with open(model_path, "rb") as f:
27 data = pickle.load(f)
28 self.kmeans_model = data["kmeans_model"]
29 self.centroids = data["centroids"]
30 self.distance = self.kmeans_model.distance
31
32 def save(self, model_path: str):
33 """
34
35 @param model_path:
36 @return:
37 """
38 with open(model_path, "wb") as f:
39 pickle.dump({
40 "kmeans_model": self.kmeans_model,
41 "centroids": self.centroids
42 }, f)
43
44 def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False):
45 """
46
47 @param features:
48 @param k:
49 @return:
50 """
51 model = KMeans(k=5, maxiter=maxiter, distance=self.distance, record_heterogeneity=[], verbose=True, seed=123)
52 centroids, _ = model.fit(features)
53 self.centroids = centroids
54 self.kmeans_model = model