Commit a79344696dda8c5c86f4fb4cc8d0ce6a149870f2

Authored by Mathias Quillot
1 parent 8797ed0e3d
Exists in master

Implementation of cluster kmeans with constrained mahalanobis

Showing 1 changed file with 125 additions and 0 deletions Side-by-side Diff

bin/cluster_kmeans_constrained_mahalanobis.py
  1 +'''
  2 +Un petit test pour faire du clustering
  3 +avec une distance de mahalanobis
  4 +From paper:
  5 +Convergence problems of Mahalanobis distance-based k-means clustering,
  6 +Itshak Lapidot
  7 +
  8 +Just one thing: Column and lines are inversed in this script.
  9 +TODO: Random selection from the set
  10 +'''
  11 +
  12 +import matplotlib.pyplot as plt
  13 +import numpy as np
  14 +# from sklearn.manifold import TSNE
  15 +
  16 +N = 50 # Number of individus
  17 +d = 2 # Number of dimensions
  18 +K = 3 # number of clusters
  19 +
  20 +
  21 +def initialize_model(X, number_clusters):
  22 + rand = np.random.choice(X.shape[0], number_clusters)
  23 + C = X[rand]
  24 + L = np.zeros((K, d, d))
  25 + for k in range(K):
  26 + L[k] = np.identity(d)
  27 + return C, L
  28 +
  29 +
  30 +X = np.random.rand(N, d) # Features
  31 +C, L = initialize_model(X, K)
  32 +
  33 +
  34 +def gaussian(x, c, l):
  35 + '''
  36 + G function
  37 + '''
  38 + p1 = np.power(np.linalg.det(l), 0.5) / np.power((2 * np.pi), d/2)
  39 + p2 = np.exp(np.transpose(x - c).dot(1/2).dot(l).dot(x - c))
  40 + return (p1 * p2).reshape(1)
  41 +
  42 +
  43 +def dist(a, b, l):
  44 + '''
  45 + Distance euclidienne
  46 + '''
  47 + a = np.reshape(a, (-1, 1))
  48 + b = np.reshape(b, (-1, 1))
  49 + result = np.transpose(a - b).dot(l).dot(a-b)[0][0]
  50 + return result
  51 +
  52 +
  53 +def plot_iteration(iteration, points, clusters, centers):
  54 + fig = plt.figure()
  55 + ax = fig.add_subplot(111)
  56 + scatter = ax.scatter(points[:, 0], points[:, 1], c=clusters, s=50)
  57 + for i, j in centers:
  58 + ax.scatter(i, j, s=50, c='red', marker='+')
  59 + ax.set_xlabel('x')
  60 + ax.set_ylabel('y')
  61 + plt.colorbar(scatter)
  62 + plt.ylim(0, 1)
  63 + plt.xlim(0, 1)
  64 + plt.savefig("test_" + str(iteration) + ".pdf")
  65 +
  66 +
  67 +end_algo = False
  68 +i = 0
  69 +while not end_algo:
  70 + if i == 10:
  71 + exit(1)
  72 + print("Iteration: ", i)
  73 + # -- Calcul matrix distance
  74 + distances = np.zeros((N, K))
  75 +
  76 + # -- Calcul closest cluster
  77 + for n in range(N):
  78 + for k in range(K):
  79 + distances[n][k] = dist(X[n], C[k], L[k])
  80 + closest_cluster = np.argmin(distances, axis=1)
  81 +
  82 + # -- Debug tool ----------------------
  83 + if i % 1 == 0:
  84 + # TSNE
  85 + X_embedded = np.concatenate((X, C), axis=0)
  86 + # X_embedded = TSNE(n_components=2).fit_transform(np.concatenate((X, C), axis=0))
  87 + # Then plot
  88 + plot_iteration(
  89 + i,
  90 + X_embedded[:X.shape[0]],
  91 + closest_cluster,
  92 + X_embedded[X.shape[0]:]
  93 + )
  94 + # ------------------------------------
  95 +
  96 + end_algo = True
  97 + for k in range(K):
  98 + # Find subset of X with values closed to the centroid c_k.
  99 + X_sub = np.where(closest_cluster == k)
  100 + X_sub = np.take(X, X_sub[0], axis=0)
  101 + np.mean(X_sub, axis=0)
  102 + C_new = np.mean(X_sub, axis=0)
  103 +
  104 + # -- COMPUTE NEW LAMBDA (here named K) --
  105 + K_new = np.zeros((L.shape[1], L.shape[2]))
  106 + for x in X_sub:
  107 + x = np.reshape(x, (-1, 1))
  108 + c_tmp = np.reshape(C_new, (-1, 1))
  109 + K_new = K_new + (x - c_tmp).dot((x - c_tmp).transpose())
  110 + K_new = K_new / X_sub.shape[0]
  111 + K_new = K_new / np.power(np.linalg.det(K_new), 1/d)
  112 + K_new = np.linalg.inv(K_new)
  113 +
  114 + if end_algo and (not (C[k] == C_new).all()): # If the same stop
  115 + end_algo = False
  116 + C[k] = C_new
  117 + L[k] = K_new
  118 + i = i + 1
  119 +
  120 +plot_iteration(
  121 + i,
  122 + X_embedded[:X.shape[0]],
  123 + closest_cluster,
  124 + X_embedded[X.shape[0]:]
  125 +)