Blame view
volia/clustering_modules/kmeans_multidistance.py
1.52 KB
05afc43e5
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import pickle

from abstract_clustering import AbstractClustering
from KMeans_Multidistance.KMeans_Class import KMeans


class kmeansMultidistance():
    """K-means clustering wrapper with a configurable distance.

    Thin adapter around ``KMeans_Multidistance.KMeans_Class.KMeans`` that
    keeps the trained model and its centroids together so they can be
    fitted, used for prediction, and persisted with pickle.
    """

    def __init__(self, distance="cosine"):
        """
        @param distance: name of the distance passed to ``KMeans`` when
            fitting (default ``"cosine"``).
        """
        self.kmeans_model = None  # set by fit() or load()
        self.centroids = None  # centroids returned by KMeans.fit() / restored by load()
        self.distance = distance

    def predict(self, features):
        """Assign each sample in ``features`` to a cluster.

        The model must have been trained with ``fit`` or restored with
        ``load`` first; otherwise ``self.kmeans_model`` is ``None`` and the
        call fails with ``AttributeError``.

        @param features: samples to assign; passed unchanged to
            ``KMeans.assign_clusters``.
        @return: whatever ``KMeans.assign_clusters`` returns (presumably one
            cluster index per sample -- TODO confirm).
        """
        return self.kmeans_model.assign_clusters(
            data=features,
            centroids=self.centroids,
            distance=self.kmeans_model.distance,
        )

    def load(self, model_path: str):
        """Restore a model previously written by ``save``.

        NOTE(security): ``pickle.load`` can execute arbitrary code while
        unpickling; only load model files from a trusted source.

        @param model_path: path to the pickle file produced by ``save``.
        """
        with open(model_path, "rb") as f:
            data = pickle.load(f)
        self.kmeans_model = data["kmeans_model"]
        self.centroids = data["centroids"]
        # The distance is not saved separately; recover it from the model.
        self.distance = self.kmeans_model.distance

    def save(self, model_path: str):
        """Pickle the trained model and its centroids to ``model_path``.

        @param model_path: destination file; overwritten if it exists.
        """
        with open(model_path, "wb") as f:
            pickle.dump({
                "kmeans_model": self.kmeans_model,
                "centroids": self.centroids,
            }, f)

    def fit(self, features, k: int, tol: float, ninit: int,
            maxiter: int = 300, debug: bool = False):
        """Train a k-means model on ``features`` and store it on ``self``.

        @param features: training samples, passed unchanged to ``KMeans.fit``.
        @param k: number of clusters.
        @param tol: currently unused; accepted for interface compatibility.
        @param ninit: currently unused; accepted for interface compatibility.
        @param maxiter: maximum number of iterations passed to ``KMeans``.
        @param debug: forwarded to ``KMeans`` as ``verbose``.
        """
        # Use the caller's k rather than a fixed cluster count. The seed is
        # fixed, presumably so repeated fits are reproducible.
        model = KMeans(k=k, maxiter=maxiter, distance=self.distance,
                       record_heterogeneity=[], verbose=debug, seed=123)
        centroids, _ = model.fit(features)
        self.centroids = centroids
        self.kmeans_model = model