Commit 5eb3a27646632055665ec5fde59add00dbe7547a

Authored by Mathias Quillot
1 parent cce036f22f
Exists in master

Implementation of k-means with Mahalanobis distance

Showing 1 changed file with 113 additions and 0 deletions

bin/cluster_kmeans_mahalanobis.py
'''
A small experiment with clustering
using a Mahalanobis distance.
From the paper:
Convergence problems of Mahalanobis distance-based k-means clustering

One caveat: columns and rows are swapped in this script.
'''
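# For reference, what this script implements: each cluster k has a centroid c_k
# and a precision matrix Lambda_k (named L below), and the distance from a
# point x to cluster k is the squared Mahalanobis distance
#     d(x, c_k) = (x - c_k)^T Lambda_k (x - c_k).
# After each reassignment, Lambda_k is re-estimated as the inverse of the
# empirical covariance of the points currently assigned to cluster k.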

import matplotlib.pyplot as plt
import numpy as np
# from sklearn.manifold import TSNE


N = 50  # Number of individuals (data points)
d = 2   # Number of dimensions
K = 3   # Number of clusters


def initialize_model(X, number_clusters):
    # Pick distinct random data points as initial centroids and start every
    # cluster with an identity precision matrix.
    C = X[np.random.choice(X.shape[0], number_clusters, replace=False)]
    L = np.zeros((K, d, d))
    for k in range(K):
        L[k] = np.identity(d)
    return C, L
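
# Note: with identity precision matrices, the very first assignment step below
# is an ordinary (squared) Euclidean k-means assignment; the Lambda_k only
# start to differ from the identity after the first update step.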


X = np.random.rand(N, d)  # Features

C, L = initialize_model(X, K)


def dist(a, b, l):
    '''
    Squared Mahalanobis distance between a and b
    with precision matrix l.
    '''
    a = np.reshape(a, (-1, 1))
    b = np.reshape(b, (-1, 1))
    result = np.transpose(a - b).dot(l).dot(a - b)[0][0]
    return result

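# Illustrative sanity check (not executed): with an identity precision matrix
# the distance reduces to the squared Euclidean distance, e.g.
#     dist(np.zeros(2), np.ones(2), np.identity(2)) == 2.0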

def plot_iteration(iteration, points, clusters, centers):
    # Scatter the points coloured by cluster, mark the centroids with red
    # crosses and save the figure to disk.
    fig = plt.figure()
    ax = fig.add_subplot(111)
    scatter = ax.scatter(points[:, 0], points[:, 1], c=clusters, s=50)
    for i, j in centers:
        ax.scatter(i, j, s=50, c='red', marker='+')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    plt.colorbar(scatter)
    plt.ylim(0, 1)
    plt.xlim(0, 1)
    plt.savefig("test_" + str(iteration) + ".pdf")
    plt.close(fig)  # Avoid accumulating open figures across iterations.


end_algo = False
i = 0
while not end_algo:
    if i == 10:
        # Stop after 10 iterations even if the algorithm has not converged.
        break
    print("Iteration: ", i)
    # Compute the distance matrix between points and centroids.
    distances = np.zeros((N, K))

    for n in range(N):
        for k in range(K):
            distances[n][k] = dist(X[n], C[k], L[k])
    print(distances)
    closest_cluster = np.argmin(distances, axis=1)
    if i % 1 == 0:  # i.e. every iteration
        # -- Debug tool ----------------------
        # TSNE
        X_embedded = np.concatenate((X, C), axis=0)
        # X_embedded = TSNE(n_components=2).fit_transform(np.concatenate((X, C), axis=0))
        # Then plot
        plot_iteration(
            i,
            X_embedded[:X.shape[0]],
            closest_cluster,
            X_embedded[X.shape[0]:]
        )
        # ------------------------------------

    end_algo = True
    for k in range(K):
        # Find the subset of X assigned to centroid c_k.
        X_sub = np.where(closest_cluster == k)
        X_sub = np.take(X, X_sub[0], axis=0)
        if X_sub.shape[0] == 0:
            # Empty cluster: keep the previous centroid and precision matrix.
            continue
        C_new = np.mean(X_sub, axis=0)

        # -- COMPUTE NEW LAMBDA (here named K_new) --
        K_new = np.zeros((L.shape[1], L.shape[2]))
        for x in X_sub:
            x = np.reshape(x, (-1, 1))
            c_tmp = np.reshape(C_new, (-1, 1))
            K_new = K_new + (x - c_tmp).dot((x - c_tmp).transpose())
        K_new = K_new / X_sub.shape[0]
        # Note: this inversion fails if the empirical covariance is singular,
        # e.g. when the cluster contains very few points.
        K_new = np.linalg.inv(K_new)

        if end_algo and (not (C[k] == C_new).all()):  # Stop only once no centroid moves.
            end_algo = False
        C[k] = C_new
        L[k] = K_new
    i = i + 1

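# Caveat from the paper cited in the module docstring: alternating the centroid
# and precision-matrix updates is not guaranteed to converge, which is
# presumably why the loop above is capped at 10 iterations.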
# Final plot (note: X_embedded still holds the centroids from the start of the
# last iteration).
plot_iteration(
    i,
    X_embedded[:X.shape[0]],
    closest_cluster,
    X_embedded[X.shape[0]:]
)