Commit 0d218501aca82a10d32d2ee8e51897bc72bd91a9

Authored by Mathias
1 parent 3338829123
Exists in master

Add an option to specify which file the list of ids comes from

Showing 1 changed file with 17 additions and 1 deletion

scripts/evaluations/clustering.py
 '''
 This script allows the user to evaluate a classification system on new labels using clustering methods.
 The algorithms are applied on the given latent space (embedding).
 '''
 import argparse
 import numpy as np
 import pandas as pd
 import os
 import time
 from sklearn.preprocessing import LabelEncoder
 from sklearn.metrics.pairwise import pairwise_distances
 from sklearn.metrics import f1_score
 from sklearn.cluster import KMeans
 from sklearn.manifold import TSNE
 import matplotlib.pyplot as plt

 from volia.data_io import read_features,read_lst

 if __name__ == "__main__":
     # Argparse
     parser = argparse.ArgumentParser("Compute clustering on a latent space")
     parser.add_argument("features")
     parser.add_argument("utt2",
                         type=str,
                         help="file with [utt] [value]")
+    parser.add_argument("--idsfrom",
+                        type=str,
+                        default="utt2",
+                        choices=[
+                            "features",
+                            "utt2"
+                        ],
+                        help="from features or from utt2?")
     parser.add_argument("--prefix",
                         type=str,
                         help="prefix of saved files")
     parser.add_argument("--outdir",
                         default=None,
                         type=str,
                         help="Output directory")

     args = parser.parse_args()

     assert args.outdir

     start = time.time()

     # Load features and utt2
     features = read_features(args.features)
     utt2 = read_lst(args.utt2)

-    ids = list(features.keys())
+    # Take id list
+    if args.idsfrom == "features":
+        ids = list(features.keys())
+    elif args.idsfrom == "utt2":
+        ids = list(utt2.keys())
+    else:
+        print(f"idsfrom is not good: {args.idsfrom}")
+        exit(1)
+
     feats = np.vstack([ features[id_] for id_ in ids ])
     classes = [ utt2[id_] for id_ in ids ]

     # Encode labels
     le = LabelEncoder()
     labels = le.fit_transform(classes)
     num_classes = len(le.classes_)

     # Compute KMEANS clustering on data
     estimator = KMeans(
         n_clusters=num_classes,
         n_init=100,
         tol=10-6,
         algorithm="elkan"
     )
     estimator.fit(feats)
     print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}")

     # contains distance to each cluster for each sample
     dist_space = estimator.transform(feats)
     predictions = np.argmin(dist_space, axis=1)

     # gives each cluster a name (considering most represented character)
     dataframe = pd.DataFrame({
         "label": pd.Series(list(map(lambda x: le.classes_[x], labels))),
         "prediction": pd.Series(predictions)
     })

     def find_cluster_name_fn(c):
         mask = dataframe["prediction"] == c
         return dataframe[mask]["label"].value_counts(sort=False).idxmax()

     cluster_names = list(map(find_cluster_name_fn, range(num_classes)))
     predicted_labels = le.transform(
         [cluster_names[pred] for pred in predictions])

     # F-measure
     fscores = f1_score(labels, predicted_labels, average=None)
     fscores_str = "\n".join(map(lambda i: "{0:25s}: {1:.4f}".format(le.classes_[i], fscores[i]), range(len(fscores))))
     print(f"F1-scores for each classes:\n{fscores_str}")
     print(f"Global score : {np.mean(fscores)}")
     with open(os.path.join(args.outdir, args.prefix + "eval_clustering.log"), "w") as fd:
         print(f"F1-scores for each classes:\n{fscores_str}", file=fd)
         print(f"Global score : {np.mean(fscores)}", file=fd)

     # Process t-SNE and plot
     tsne_estimator = TSNE()
     embeddings = tsne_estimator.fit_transform(feats)
     print("t-SNE: processed {0} iterations - KL_divergence={1:.4f}".format(
         tsne_estimator.n_iter_, tsne_estimator.kl_divergence_))

     fig, [axe1, axe2] = plt.subplots(1, 2, figsize=(10, 5))
     for c, name in enumerate(le.classes_):
         c_mask = np.where(labels == c)
         axe1.scatter(embeddings[c_mask][:, 0], embeddings[c_mask][:, 1], label=name, alpha=0.2, edgecolors=None)

         try:
             id_cluster = cluster_names.index(name)
         except ValueError:
             print("WARNING: no cluster found for {}".format(name))
             continue
         c_mask = np.where(predictions == id_cluster)
         axe2.scatter(embeddings[c_mask][:, 0], embeddings[c_mask][:, 1], label=name, alpha=0.2, edgecolors=None)

     axe1.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35))
     axe1.set_title("true labels")
     axe2.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35))
     axe2.set_title("predicted cluster label")

     plt.suptitle("Kmeans Clustering")

     loc = os.path.join(
         args.outdir,
         args.prefix + "kmeans.pdf"
     )
     plt.savefig(loc, bbox_inches="tight")
     plt.close()

     print("INFO: figure saved at {}".format(loc))

     end = time.time()
     print("program ended in {0:.2f} seconds".format(end-start))