# kmeans_mahalanobis.py 6.37 KB
from sklearn.cluster import KMeans
import pickle
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from abstract_clustering import AbstractClustering

class kmeansMahalanobis():
    """K-means clustering with one Mahalanobis-style metric per cluster.

    Cluster ``k`` is described by a centroid ``C[k]`` and a metric matrix
    ``L[k]``.  The dissimilarity between a sample ``x`` and the cluster is the
    quadratic form ``(x - C[k])^T L[k] (x - C[k])``.  During training ``L[k]``
    is set to the pseudo-inverse of the covariance of the samples currently
    assigned to the cluster.
    """

    def __init__(self, constrained: bool = False):
        """Create an unfitted model.

        @param constrained: when True, each cluster covariance is divided by
            ``det(covariance) ** (1 / d)`` before being inverted.
        """
        self.C = None  # centroids, shape (K, d) once trained
        self.L = None  # per-cluster metric matrices, shape (K, d, d) once trained
        self.K = None  # number of clusters
        self.constrained = constrained

    def predict(self, features):
        """Assign every sample to its closest cluster.

        @param features: array of samples, one row per sample.
        @return: integer array with, for each sample, the index of the
            cluster whose distance to the sample is the smallest.
        """
        distances = self._pairwise_distances(features, self.C, self.L)
        return np.argmin(distances, axis=1)

    def load(self, model_path):
        """Restore a model previously written by `save`.

        @param model_path: path of the pickle file to read.
        @return: None; ``C``, ``L``, ``K`` and ``constrained`` are overwritten.
        @raise Exception: if the file unpickles to None.
        """
        # SECURITY: unpickling can execute arbitrary code.  Only load model
        # files that come from a trusted source.
        with open(model_path, "rb") as f:
            data = pickle.load(f)
        if data is None:
            raise Exception("Le modèle n'a pas pu être chargé")
        self.C = data["C"]
        self.L = data["L"]
        self.K = data["K"]
        self.constrained = data["constrained"]

    def save(self, modelpath: str):
        """Write the model parameters to a pickle file.

        @param modelpath: destination path; the file is overwritten.
        @return: None
        """
        data = {
            "C": self.C,
            "L": self.L,
            "K": self.K,
            "constrained": self.constrained,
        }
        with open(modelpath, "wb") as f:
            pickle.dump(data, f)

    def fit(self, features, k: int, tol: float, ninit: int, maxiter: int = 300, debug: bool = False):
        """Train the model, keeping the best of several random restarts.

        @param features: array of shape (N, d), one row per sample.
        @param k: number of clusters.
        @param tol: convergence threshold on the summed percent change of the
            centroid coordinates between two iterations.
        @param ninit: number of independent trainings; the one with the
            smallest loss is kept.
        @param maxiter: maximum number of iterations per training.
        @param debug: when True, print progress and save one plot per
            iteration.
        @return: None; ``C``, ``L`` and ``K`` hold the best run.
        @raise ValueError: if ``ninit`` is smaller than 1.
        """
        if ninit < 1:
            raise ValueError(f"ninit must be at least 1, got {ninit}")
        runs = [self._train(features, k, tol, maxiter, debug) for _ in range(ninit)]
        best = min(runs, key=lambda run: run["loss"])
        if debug:
            print(f"best: {best['loss']} loss")
        self.C = best["C"]
        self.L = best["L"]
        self.K = best["K"]

    def _initialize_model(self, X, number_clusters):
        """Pick initial centroids among the samples and identity metrics.

        @param X: array of shape (N, d).
        @param number_clusters: number of clusters to initialise.
        @return: tuple ``(C, L)`` with ``C`` of shape (K, d) and ``L`` of
            shape (K, d, d), both floating point.
        """
        X = np.asarray(X)
        n_samples = X.shape[0]
        d = X.shape[1]
        # Draw distinct rows so that two clusters never start on the same
        # sample; sampling with replacement is only used when more clusters
        # than samples are requested (where distinct rows are impossible).
        picks = np.random.choice(
            n_samples, number_clusters, replace=number_clusters > n_samples
        )
        # Force a float copy: with integer features the centroid updates
        # (means) would otherwise be truncated when written back into C.
        C = np.array(X[picks], dtype=float)
        L = np.tile(np.identity(d), (number_clusters, 1, 1))
        return C, L

    def _dist(self, a, b, l):
        """Return the quadratic form ``(a - b)^T l (a - b)`` for one pair.

        @param a: first vector.
        @param b: second vector.
        @param l: square metric matrix.
        @return: scalar value of the quadratic form.
        """
        delta = np.reshape(a, (-1, 1)) - np.reshape(b, (-1, 1))
        return np.transpose(delta).dot(l).dot(delta)[0][0]

    @staticmethod
    def _pairwise_distances(features, centroids, metrics):
        """Return the (N, K) matrix of sample-to-cluster distances.

        Entry ``[n, k]`` is ``(x_n - c_k)^T L_k (x_n - c_k)``, i.e. the value
        `_dist` computes for a single pair, evaluated for all pairs at once.
        """
        points = np.asarray(features, dtype=float)
        # Flatten each sample / centroid to a vector, as `_dist` does.
        points = points.reshape(points.shape[0], int(np.prod(points.shape[1:])))
        centers = np.asarray(centroids, dtype=float)
        centers = centers.reshape(centers.shape[0], int(np.prod(centers.shape[1:])))
        deltas = points[:, np.newaxis, :] - centers[np.newaxis, :, :]
        return np.einsum(
            "nki,kij,nkj->nk", deltas, np.asarray(metrics, dtype=float), deltas
        )

    def _estimate_cluster(self, members, previous_metric):
        """Compute the centroid and metric of one non-empty cluster.

        @param members: array of shape (n, d) with the samples of the cluster.
        @param previous_metric: metric currently used by the cluster; returned
            unchanged when the constrained normalisation is impossible.
        @return: tuple ``(centroid, metric)``.
        """
        center = np.mean(members, axis=0)
        centered = members - center
        covariance = centered.T.dot(centered) / members.shape[0]
        if self.constrained:
            dim = covariance.shape[0]
            det = np.linalg.det(covariance)
            if not np.isfinite(det) or det <= 0:
                # Singular covariance (e.g. a single-point cluster): dividing
                # by det ** (1 / d) would produce NaN/inf, so keep the last
                # valid metric instead.
                return center, previous_metric
            covariance = covariance / np.power(det, 1 / dim)
        return center, np.linalg.pinv(covariance)

    def _plot_iteration(self, iteration, points, clusters, centers):
        """Save a 2-D scatter plot of the current clustering.

        The figure is written to ``test_<iteration>.pdf`` in the working
        directory and then closed.

        @param iteration: iteration number, used in the file name.
        @param points: array whose first two columns are plotted.
        @param clusters: cluster index of each point, used as colour.
        @param centers: centroid coordinates, drawn as red crosses.
        """
        fig = plt.figure()
        try:
            ax = fig.add_subplot(111)
            scatter = ax.scatter(points[:, 0], points[:, 1], c=clusters, s=50)
            ax.scatter(centers[:, 0], centers[:, 1], s=50, c='red', marker='+')
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            fig.colorbar(scatter, ax=ax)
            fig.savefig("test_" + str(iteration) + ".pdf")
        finally:
            # One figure is created per iteration; close it so they do not
            # accumulate in memory.
            plt.close(fig)

    def _train(self, features, K: int, tol: float, maxiter: int, debug: bool = False):
        """Run one training from a random initialisation.

        At least one iteration is always executed.  Training stops when the
        summed percent change of the centroid coordinates is not greater than
        ``tol`` or after ``maxiter`` iterations.

        @param features: array of shape (N, d).
        @param K: number of clusters.
        @param tol: convergence threshold (see `fit`).
        @param maxiter: maximum number of iterations.
        @param debug: print progress and save a plot at every iteration.
        @return: dict with keys ``"loss"``, ``"C"``, ``"K"`` and ``"L"``.  The
            loss is the one measured at the start of the last iteration, i.e.
            before the final parameter update.
        """
        X = np.asarray(features)
        N = X.shape[0]
        d = X.shape[1]

        self.C, self.L = self._initialize_model(X, K)
        self.K = K

        iteration = 0
        end_algo = False
        while not end_algo:
            if debug:
                print("Iteration: ", iteration)

            # Assignment step: every sample goes to its closest cluster.
            distances = self._pairwise_distances(X, self.C, self.L)
            closest_cluster = np.argmin(distances, axis=1)

            loss = np.sum(distances[np.arange(N), closest_cluster])
            if debug:
                print(f"loss {loss}")

                # Project samples and centroids together to 2-D when needed,
                # then plot the current assignment.
                stacked = np.concatenate((X, self.C), axis=0)
                if d > 2:
                    stacked = TSNE(n_components=2).fit_transform(stacked)
                self._plot_iteration(
                    iteration, stacked[:N], closest_cluster, stacked[N:]
                )

            # Update step: re-estimate centroid and metric of each cluster.
            old_c = self.C.copy()
            for k in range(self.K):
                members = X[closest_cluster == k]
                if members.shape[0] == 0:
                    # An empty cluster keeps its centroid and metric.
                    continue
                self.C[k], self.L[k] = self._estimate_cluster(members, self.L[k])

            # Summed percent change of the centroid coordinates.  Coordinates
            # that were exactly zero are measured in absolute terms to avoid
            # a 0/0 or x/0 division.
            scale = np.where(old_c == 0, 1.0, np.absolute(old_c))
            diff = np.sum(np.absolute(self.C - old_c) / scale * 100)
            if diff > tol:
                if debug:
                    print(f"{diff}")
            else:
                if debug:
                    print(f"Tolerance threshold {tol} reached with diff {diff}")
                end_algo = True

            iteration += 1
            if iteration >= maxiter:
                end_algo = True
                if debug:
                    print(f"Iteration {maxiter} reached")
        return {
            "loss": loss,
            "C": self.C,
            "K": self.K,
            "L": self.L,
        }