Commit 05afc43e5473588a2a1667026689f13ce67bdf7e

Authored by quillotm
1 parent 2edadabee3
Exists in master

Adding new implementation of kmeans to test. Simple implementation. Need to add …

…to take into account n_init parameter and to add random seed.

Showing 1 changed file with 54 additions and 0 deletions Side-by-side Diff

volia/clustering_modules/kmeans_multidistance.py
  1 +
  2 +import pickle
  3 +from abstract_clustering import AbstractClustering
  4 +from KMeans_Multidistance.KMeans_Class import KMeans
  5 +
  6 +class kmeansMultidistance():
  7 + def __init__(self, distance="cosine"):
  8 + self.kmeans_model = None
  9 + self.centroids = None
  10 + self.distance = distance
  11 +
  12 + def predict(self, features):
  13 + """
  14 +
  15 + @param features:
  16 + @return:
  17 + """
  18 + return self.kmeans_model.assign_clusters(data=features, centroids=self.centroids, distance=self.kmeans_model.distance)
  19 +
  20 + def load(self, model_path: str):
  21 + """
  22 +
  23 + @param model_path:
  24 + @return:
  25 + """
  26 + with open(model_path, "rb") as f:
  27 + data = pickle.load(f)
  28 + self.kmeans_model = data["kmeans_model"]
  29 + self.centroids = data["centroids"]
  30 + self.distance = self.kmeans_model.distance
  31 +
  32 + def save(self, model_path: str):
  33 + """
  34 +
  35 + @param model_path:
  36 + @return:
  37 + """
  38 + with open(model_path, "wb") as f:
  39 + pickle.dump({
  40 + "kmeans_model": self.kmeans_model,
  41 + "centroids": self.centroids
  42 + }, f)
  43 +
  44 + def fit(self, features, k: int, tol: float, ninit: int, maxiter: int=300, debug: bool=False):
  45 + """
  46 +
  47 + @param features:
  48 + @param k:
  49 + @return:
  50 + """
  51 + model = KMeans(k=5, maxiter=maxiter, distance=self.distance, record_heterogeneity=[], verbose=True, seed=123)
  52 + centroids, _ = model.fit(features)
  53 + self.centroids = centroids
  54 + self.kmeans_model = model