# kmeans_multidistance.py (1.52 KB)
import pickle
from abstract_clustering import AbstractClustering
from KMeans_Multidistance.KMeans_Class import KMeans

class kmeansMultidistance:
    """Thin wrapper around the project's ``KMeans`` adding pickle persistence.

    The distance metric name given to the constructor is forwarded to
    ``KMeans`` at fit time; after ``load`` it is taken from the restored
    model's ``distance`` attribute.
    """

    def __init__(self, distance="cosine"):
        """
        @param distance: name of the distance metric passed to ``KMeans``
            (default ``"cosine"``).
        """
        self.kmeans_model = None  # KMeans instance; None until fit() or load()
        self.centroids = None  # centroids from KMeans.fit(); None until fit() or load()
        self.distance = distance

    def predict(self, features):
        """Assign each sample in *features* to one of the stored centroids.

        @param features: samples to assign, in whatever form
            ``KMeans.assign_clusters`` accepts (assumed to match the data
            given to ``fit`` -- TODO confirm).
        @return: the result of ``KMeans.assign_clusters`` (presumably one
            cluster index per sample -- verify against ``KMeans_Class``).
        @raise RuntimeError: if neither ``fit`` nor ``load`` has been called.
        """
        if self.kmeans_model is None:
            raise RuntimeError("model is not fitted; call fit() or load() first")
        return self.kmeans_model.assign_clusters(
            data=features,
            centroids=self.centroids,
            distance=self.kmeans_model.distance,
        )

    def load(self, model_path: str):
        """Restore a model previously written by ``save``.

        WARNING: this uses ``pickle.load``, which can execute arbitrary code
        embedded in the file. Only load files from a trusted source.

        @param model_path: path to a pickle file produced by ``save``.
        @return: None
        """
        with open(model_path, "rb") as f:
            data = pickle.load(f)  # NOTE(security): never call on untrusted files
        self.kmeans_model = data["kmeans_model"]
        self.centroids = data["centroids"]
        # The metric is not stored separately; recover it from the model.
        self.distance = self.kmeans_model.distance

    def save(self, model_path: str):
        """Pickle the fitted model and its centroids to *model_path*.

        @param model_path: destination file path; overwritten if it exists.
        @return: None
        @raise RuntimeError: if there is no model to save. Checked before the
            file is opened, so an existing file is not truncated.
        """
        if self.kmeans_model is None:
            # Saving None would produce a file that load() cannot restore.
            raise RuntimeError("model is not fitted; call fit() or load() first")
        with open(model_path, "wb") as f:
            pickle.dump({
                "kmeans_model": self.kmeans_model,
                "centroids": self.centroids,
            }, f)

    def fit(self, features, k: int, tol: float, ninit: int, maxiter: int = 300,
            debug: bool = False, *, seed: int = 123):
        """Fit a ``KMeans`` model on *features* and keep it with its centroids.

        @param features: training samples, passed unchanged to ``KMeans.fit``.
        @param k: number of clusters, forwarded to ``KMeans``.
        @param tol: currently unused (not forwarded to ``KMeans``).
        @param ninit: currently unused (not forwarded to ``KMeans``).
        @param maxiter: maximum iterations, forwarded to ``KMeans``.
        @param debug: forwarded to ``KMeans`` as ``verbose``.
        @param seed: random seed forwarded to ``KMeans`` (default 123, the
            value that used to be hard-coded).
        @return: None
        """
        # NOTE(review): tol and ninit are accepted for interface compatibility
        # but ignored; forward them once KMeans is confirmed to support them.
        model = KMeans(
            k=k,  # previously hard-coded to 5, silently ignoring the caller's k
            maxiter=maxiter,
            distance=self.distance,
            record_heterogeneity=[],  # fresh list per call
            verbose=debug,
            seed=seed,
        )
        centroids, _ = model.fit(features)
        self.centroids = centroids
        self.kmeans_model = model