Blame view
volia/clustering_modules/kmeans.py
931 Bytes
9191399c3 Clustering and ev... |
1 2 3 4 5 6 7 8 9 10 |
import pickle

from sklearn.cluster import KMeans
from sklearn.exceptions import NotFittedError

# NOTE(review): imported but unused below — kmeans does not subclass it;
# confirm whether it is meant to.
from abstract_clustering import AbstractClustering


class kmeans:
    """Thin wrapper around scikit-learn's KMeans with pickle-based persistence.

    The wrapped estimator is created by ``fit`` or restored by ``load``.
    Until one of those has been called, ``kmeans_model`` is ``None`` and
    ``predict`` raises ``NotFittedError``.
    """

    def __init__(self):
        # Estimator built by fit() or unpickled by load(); None until then.
        self.kmeans_model = None

    def predict(self, features):
        """
        Assign each sample in ``features`` to a cluster of the current model.

        @param features: array-like of samples, presumably in the same feature
            layout used for ``fit`` — confirm against callers.
        @return: whatever ``self.kmeans_model.predict(features)`` returns; for a
            scikit-learn KMeans model, the cluster index of each sample.
        @raise NotFittedError: if neither ``fit`` nor ``load`` has been called.
            NotFittedError subclasses AttributeError, so callers that caught
            the previous ``'NoneType' ... 'predict'`` AttributeError still
            catch it.
        """
        if self.kmeans_model is None:
            raise NotFittedError(
                "kmeans model is not fitted; call fit() or load() before predict()."
            )
        return self.kmeans_model.predict(features)

    def load(self, model_path: str):
        """
        Replace the current model with one unpickled from ``model_path``.

        @param model_path: path to a file, presumably written by ``save``.
        @return: None.
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only load model files that come from a trusted source.
        with open(model_path, "rb") as f:
            self.kmeans_model = pickle.load(f)

    def save(self, model_path: str):
        """
        Pickle the current model to ``model_path``.

        @param model_path: destination file path; an existing file is
            truncated and overwritten.
        @return: None.
        """
        # NOTE(review): on an unfitted instance this pickles None, which load()
        # restores as "unfitted" — confirm whether that should be an error.
        with open(model_path, "wb") as f:
            pickle.dump(self.kmeans_model, f)

    def fit(self, features, k: int, *, n_init: int = 10, random_state=0):
        """
        Fit a new KMeans model on ``features``, replacing any existing model.

        @param features: array-like training samples.
        @param k: number of clusters (passed to KMeans as ``n_clusters``).
        @param n_init: number of KMeans runs with different centroid seeds;
            the best run is kept. Defaults to 10, the previously fixed value.
        @param random_state: seed (or RandomState / None) passed to KMeans.
            Defaults to 0, the previously fixed value, so results stay
            reproducible.
        @return: None.
        """
        self.kmeans_model = KMeans(
            n_clusters=k, n_init=n_init, random_state=random_state
        ).fit(features)