Commit e403ed5fb6202dae56d47815d5961cced00f1c85

Authored by Mathias
1 parent 11ee97e2cc
Exists in master

Add a script that allows the user to evaluate a representation using classification labels.

Showing 1 changed file with 126 additions and 0 deletions Inline Diff

scripts/evaluations/clustering.py
File was created 1 '''
2 This script allows the user to evaluate a classification system on new labels using clustering methods.
3 The algorithms are applied on the given latent space (embedding).
4 '''
5 import argparse
6 import numpy as np
7 import pandas as pd
8 import os
9 import time
10 from sklearn.preprocessing import LabelEncoder
11 from sklearn.metrics.pairwise import pairwise_distances
12 from sklearn.metrics import f1_score
13 from sklearn.cluster import KMeans
14 from sklearn.manifold import TSNE
15 import matplotlib.pyplot as plt
16
17 from volia.data_io import read_features,read_lst
18
def _parse_args():
    """Parse the command line.

    Returns:
        The parsed ``argparse.Namespace`` with ``features``, ``utt2``,
        ``prefix`` (empty string when not given) and ``outdir``.

    Exits with a usage error (status 2) when ``--outdir`` is missing or empty.
    """
    # The text is a description, not the program name: pass it by keyword so
    # that the usage line keeps showing the real script name.
    parser = argparse.ArgumentParser(
        description="Compute clustering on a latent space")
    parser.add_argument("features")
    parser.add_argument("utt2",
                        type=str,
                        help="file with [utt] [value]")
    # Default to "" so that the prefix can always be concatenated to a
    # file name, even when the option is omitted.
    parser.add_argument("--prefix",
                        default="",
                        type=str,
                        help="prefix of saved files")
    parser.add_argument("--outdir",
                        default=None,
                        type=str,
                        help="Output directory")

    args = parser.parse_args()

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not args.outdir:
        parser.error("--outdir is required")
    return args


def _name_clusters(true_names, predictions, num_clusters):
    """Give each cluster the name of its most represented true label.

    Args:
        true_names: true label name of every sample.
        predictions: cluster index assigned to every sample (same order).
        num_clusters: number of clusters; indices ``0..num_clusters-1``
            are looked up.

    Returns:
        A list of length ``num_clusters`` whose ``c``-th entry is the most
        frequent true label among the samples assigned to cluster ``c``.
        Several clusters may end up with the same name.
    """
    dataframe = pd.DataFrame({
        "label": pd.Series(list(true_names)),
        "prediction": pd.Series(predictions)
    })

    names = []
    for cluster in range(num_clusters):
        members = dataframe.loc[dataframe["prediction"] == cluster, "label"]
        names.append(members.value_counts(sort=False).idxmax())
    return names


def _format_report(class_names, fscores):
    """Build the textual F1 report (per-class scores, then their mean)."""
    per_class = "\n".join(
        "{0:25s}: {1:.4f}".format(name, score)
        for name, score in zip(class_names, fscores)
    )
    return (f"F1-scores for each classes:\n{per_class}\n"
            f"Global score : {np.mean(fscores)}")


def _plot_clusters(embeddings, class_names, labels, predicted_labels, path):
    """Save a two-panel scatter plot: true labels vs. predicted cluster labels.

    Args:
        embeddings: 2-D coordinates of every sample (columns 0 and 1 are
            plotted).
        class_names: name of each encoded class, indexed by class id.
        labels: encoded true class of every sample.
        predicted_labels: encoded class given to every sample through the
            name of its cluster.
        path: destination file of the figure.
    """
    fig, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(10, 5))

    for class_id, name in enumerate(class_names):
        true_points = embeddings[labels == class_id]
        ax_true.scatter(true_points[:, 0], true_points[:, 1],
                        label=name, alpha=0.2, edgecolors=None)

        # Select on the predicted label rather than on a single cluster
        # index: several clusters can carry the same name, and all of their
        # samples belong on the plot.
        pred_points = embeddings[predicted_labels == class_id]
        if len(pred_points) == 0:
            print("WARNING: no cluster found for {}".format(name))
            continue
        ax_pred.scatter(pred_points[:, 0], pred_points[:, 1],
                        label=name, alpha=0.2, edgecolors=None)

    ax_true.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35))
    ax_true.set_title("true labels")
    ax_pred.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35))
    ax_pred.set_title("predicted cluster label")

    fig.suptitle("Kmeans Clustering")
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)


def main():
    """Cluster a latent space with KMeans and score it against class labels.

    Steps: load features and labels, run KMeans with one cluster per class,
    name every cluster after its majority label, report per-class F1 scores
    (stdout and ``<outdir>/<prefix>eval_clustering.log``), then save a t-SNE
    visualisation in ``<outdir>/<prefix>kmeans.pdf``.
    """
    args = _parse_args()

    # Make sure the outputs can be written before any expensive computation.
    os.makedirs(args.outdir, exist_ok=True)

    start = time.time()

    # Load features and utt2.
    # NOTE(review): both readers are used as mappings keyed by utterance id
    # (`.keys()` / indexing) - confirm against volia.data_io.
    features = read_features(args.features)
    utt2 = read_lst(args.utt2)

    ids = list(features.keys())
    feats = np.vstack([features[id_] for id_ in ids])
    classes = [utt2[id_] for id_ in ids]

    # Encode labels
    le = LabelEncoder()
    labels = le.fit_transform(classes)
    num_classes = len(le.classes_)

    # Compute KMEANS clustering on data.
    # tol must be written 1e-6: `10-6` is an integer subtraction equal to 4,
    # which makes the algorithm declare convergence almost immediately.
    estimator = KMeans(
        n_clusters=num_classes,
        n_init=100,
        tol=1e-6,
        algorithm="elkan"
    )
    estimator.fit(feats)
    print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}")

    # contains distance to each cluster for each sample
    dist_space = estimator.transform(feats)
    predictions = np.argmin(dist_space, axis=1)

    # gives each cluster a name (considering most represented character)
    cluster_names = _name_clusters(le.classes_[labels], predictions, num_classes)
    predicted_labels = le.transform(
        [cluster_names[pred] for pred in predictions])

    # F-measure
    fscores = f1_score(labels, predicted_labels, average=None)
    report = _format_report(le.classes_, fscores)
    print(report)
    log_path = os.path.join(args.outdir, args.prefix + "eval_clustering.log")
    # Explicit encoding so that non-ASCII label names do not depend on the
    # locale of the machine.
    with open(log_path, "w", encoding="utf-8") as fd:
        print(report, file=fd)

    # Process t-SNE and plot
    tsne_estimator = TSNE()
    embeddings = tsne_estimator.fit_transform(feats)
    print("t-SNE: processed {0} iterations - KL_divergence={1:.4f}".format(
        tsne_estimator.n_iter_, tsne_estimator.kl_divergence_))

    loc = os.path.join(
        args.outdir,
        args.prefix + "kmeans.pdf"
    )
    _plot_clusters(embeddings, le.classes_, labels, predicted_labels, loc)

    print("INFO: figure saved at {}".format(loc))

    end = time.time()
    print("program ended in {0:.2f} seconds".format(end - start))


if __name__ == "__main__":
    main()