Commit 0d218501aca82a10d32d2ee8e51897bc72bd91a9
Parent: 3338829123
Exists in: master

Add an option to specify which file is the list of ids

Showing 1 changed file with 17 additions and 1 deletion (inline diff)
scripts/evaluations/clustering.py
 '''
 This script allows the user to evaluate a classification system on new labels using clustering methods.
 The algorithms are applied on the given latent space (embedding).
 '''
 import argparse
 import numpy as np
 import pandas as pd
 import os
 import time
 from sklearn.preprocessing import LabelEncoder
 from sklearn.metrics.pairwise import pairwise_distances
 from sklearn.metrics import f1_score
 from sklearn.cluster import KMeans
 from sklearn.manifold import TSNE
 import matplotlib.pyplot as plt
 
 from volia.data_io import read_features,read_lst
 
 if __name__ == "__main__":
     # Argparse
     parser = argparse.ArgumentParser("Compute clustering on a latent space")
     parser.add_argument("features")
     parser.add_argument("utt2",
                         type=str,
                         help="file with [utt] [value]")
+    parser.add_argument("--idsfrom",
+                        type=str,
+                        default="utt2",
+                        choices=[
+                            "features",
+                            "utt2"
+                        ],
+                        help="from features or from utt2?")
     parser.add_argument("--prefix",
                         type=str,
                         help="prefix of saved files")
     parser.add_argument("--outdir",
                         default=None,
                         type=str,
                         help="Output directory")
 
     args = parser.parse_args()
 
     assert args.outdir
 
     start = time.time()
 
     # Load features and utt2
     features = read_features(args.features)
     utt2 = read_lst(args.utt2)
 
-    ids = list(features.keys())
+    # Take id list
+    if args.idsfrom == "features":
+        ids = list(features.keys())
+    elif args.idsfrom == "utt2":
+        ids = list(utt2.keys())
+    else:
+        print(f"idsfrom is not good: {args.idsfrom}")
+        exit(1)
+
     feats = np.vstack([ features[id_] for id_ in ids ])
     classes = [ utt2[id_] for id_ in ids ]
 
     # Encode labels
     le = LabelEncoder()
     labels = le.fit_transform(classes)
     num_classes = len(le.classes_)
 
     # Compute KMEANS clustering on data
     estimator = KMeans(
         n_clusters=num_classes,
         n_init=100,
         tol=1e-6,
         algorithm="elkan"
     )
     estimator.fit(feats)
61 | print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}") | 77 | print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}") |
 
     # contains distance to each cluster for each sample
     dist_space = estimator.transform(feats)
     predictions = np.argmin(dist_space, axis=1)
 
     # gives each cluster a name (considering most represented character)
     dataframe = pd.DataFrame({
         "label": pd.Series(list(map(lambda x: le.classes_[x], labels))),
         "prediction": pd.Series(predictions)
     })
 
     def find_cluster_name_fn(c):
         mask = dataframe["prediction"] == c
         return dataframe[mask]["label"].value_counts(sort=False).idxmax()
 
     cluster_names = list(map(find_cluster_name_fn, range(num_classes)))
     predicted_labels = le.transform(
         [cluster_names[pred] for pred in predictions])
 
     # F-measure
     fscores = f1_score(labels, predicted_labels, average=None)
     fscores_str = "\n".join(map(lambda i: "{0:25s}: {1:.4f}".format(le.classes_[i], fscores[i]), range(len(fscores))))
84 | print(f"F1-scores for each classes:\n{fscores_str}") | 100 | print(f"F1-scores for each classes:\n{fscores_str}") |
85 | print(f"Global score : {np.mean(fscores)}") | 101 | print(f"Global score : {np.mean(fscores)}") |
86 | with open(os.path.join(args.outdir, args.prefix + "eval_clustering.log"), "w") as fd: | 102 | with open(os.path.join(args.outdir, args.prefix + "eval_clustering.log"), "w") as fd: |
87 | print(f"F1-scores for each classes:\n{fscores_str}", file=fd) | 103 | print(f"F1-scores for each classes:\n{fscores_str}", file=fd) |
88 | print(f"Global score : {np.mean(fscores)}", file=fd) | 104 | print(f"Global score : {np.mean(fscores)}", file=fd) |
 
     # Process t-SNE and plot
     tsne_estimator = TSNE()
     embeddings = tsne_estimator.fit_transform(feats)
     print("t-SNE: processed {0} iterations - KL_divergence={1:.4f}".format(
         tsne_estimator.n_iter_, tsne_estimator.kl_divergence_))
 
     fig, [axe1, axe2] = plt.subplots(1, 2, figsize=(10, 5))
     for c, name in enumerate(le.classes_):
         c_mask = np.where(labels == c)
         axe1.scatter(embeddings[c_mask][:, 0], embeddings[c_mask][:, 1], label=name, alpha=0.2, edgecolors=None)
 
         try:
             id_cluster = cluster_names.index(name)
         except ValueError:
             print("WARNING: no cluster found for {}".format(name))
             continue
         c_mask = np.where(predictions == id_cluster)
         axe2.scatter(embeddings[c_mask][:, 0], embeddings[c_mask][:, 1], label=name, alpha=0.2, edgecolors=None)
 
     axe1.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35))
     axe1.set_title("true labels")
     axe2.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35))
     axe2.set_title("predicted cluster label")
 
     plt.suptitle("Kmeans Clustering")
 
     loc = os.path.join(
         args.outdir,
         args.prefix + "kmeans.pdf"
     )
     plt.savefig(loc, bbox_inches="tight")
     plt.close()
 
     print("INFO: figure saved at {}".format(loc))
 
     end = time.time()
     print("program ended in {0:.2f} seconds".format(end-start))