# kmeans.py
from sklearn.cluster import KMeans
import pickle
from abstract_clustering import AbstractClustering

class kmeans:
    """Thin wrapper around scikit-learn's ``KMeans`` with pickle persistence.

    The underlying estimator lives on ``self.kmeans_model``. It is ``None``
    until :meth:`fit` or :meth:`load` has been called.

    NOTE(review): ``AbstractClustering`` is imported by this module but is not
    used as a base class here -- confirm whether ``kmeans`` is meant to
    subclass it.
    """

    def __init__(self):
        # Set by fit() or load(); None means "no model yet".
        self.kmeans_model = None

    def _require_model(self):
        """
        Return the current model, failing with a clear message if there is none.

        @return: the object stored in ``self.kmeans_model``.
        @raise AttributeError: if neither fit() nor load() has set a model.
            AttributeError is kept deliberately: it is what the unguarded
            ``None.predict`` call used to raise, so existing handlers still match.
        """
        if self.kmeans_model is None:
            raise AttributeError(
                "kmeans model is not available; call fit() or load() first"
            )
        return self.kmeans_model

    def predict(self, features):
        """
        Label each sample in ``features`` with a cluster from the current model.

        @param features: samples to label, passed unchanged to the model's
            ``predict`` (for a scikit-learn ``KMeans`` this is array-like of
            shape (n_samples, n_features)).
        @return: whatever the model's ``predict`` returns (for ``KMeans``,
            one cluster index per sample).
        @raise AttributeError: if no model has been fitted or loaded yet.
        """
        return self._require_model().predict(features)

    def load(self, model_path: str):
        """
        Replace the current model with one unpickled from ``model_path``.

        Security: ``pickle.load`` can execute arbitrary code while
        deserialising, so only load files that come from a trusted source.

        @param model_path: path to a file previously written by :meth:`save`.
        @return: None
        """
        with open(model_path, "rb") as f:
            # WARNING: never point this at untrusted input (see docstring).
            self.kmeans_model = pickle.load(f)

    def save(self, model_path: str):
        """
        Pickle the current model to ``model_path``, overwriting any existing file.

        If no model has been fitted or loaded, this writes a pickled ``None``.

        @param model_path: destination file path.
        @return: None
        """
        with open(model_path, "wb") as f:
            pickle.dump(self.kmeans_model, f)

    def fit(self, features, k: int, *, n_init: int = 10, random_state=0):
        """
        Train a new K-means model on ``features`` and store it on the instance.

        @param features: training samples, passed to ``KMeans.fit``.
        @param k: number of clusters (``KMeans(n_clusters=k)``).
        @param n_init: number of runs with different centroid seeds; scikit-learn
            keeps the best one. Defaults to 10, the value previously hard-coded.
        @param random_state: seed for centroid initialisation. The default of 0
            keeps results reproducible, as before; pass ``None`` for
            non-deterministic seeding.
        @return: None
        """
        self.kmeans_model = KMeans(
            n_clusters=k, n_init=n_init, random_state=random_state
        ).fit(features)