Commit 85d6f0944e7bad5842e2a9df67e604c012ce85a4
1 parent
0d218501ac
Exists in
master
Add default value to prefix
Showing 1 changed file with 2 additions and 1 deletions Inline Diff
scripts/evaluations/clustering.py
1 | ''' | 1 | ''' |
2 | This script allows the user to evaluate a classification system on new labels using clustering methods. | 2 | This script allows the user to evaluate a classification system on new labels using clustering methods. |
3 | The algorithms are applied on the given latent space (embedding). | 3 | The algorithms are applied on the given latent space (embedding). |
4 | ''' | 4 | ''' |
5 | import argparse | 5 | import argparse |
6 | import numpy as np | 6 | import numpy as np |
7 | import pandas as pd | 7 | import pandas as pd |
8 | import os | 8 | import os |
9 | import time | 9 | import time |
10 | from sklearn.preprocessing import LabelEncoder | 10 | from sklearn.preprocessing import LabelEncoder |
11 | from sklearn.metrics.pairwise import pairwise_distances | 11 | from sklearn.metrics.pairwise import pairwise_distances |
12 | from sklearn.metrics import f1_score | 12 | from sklearn.metrics import f1_score |
13 | from sklearn.cluster import KMeans | 13 | from sklearn.cluster import KMeans |
14 | from sklearn.manifold import TSNE | 14 | from sklearn.manifold import TSNE |
15 | import matplotlib.pyplot as plt | 15 | import matplotlib.pyplot as plt |
16 | 16 | ||
17 | from volia.data_io import read_features,read_lst | 17 | from volia.data_io import read_features,read_lst |
18 | 18 | ||
19 | if __name__ == "__main__": | 19 | if __name__ == "__main__": |
20 | # Argparse | 20 | # Argparse |
21 | parser = argparse.ArgumentParser("Compute clustering on a latent space") | 21 | parser = argparse.ArgumentParser("Compute clustering on a latent space") |
22 | parser.add_argument("features") | 22 | parser.add_argument("features") |
23 | parser.add_argument("utt2", | 23 | parser.add_argument("utt2", |
24 | type=str, | 24 | type=str, |
25 | help="file with [utt] [value]") | 25 | help="file with [utt] [value]") |
26 | parser.add_argument("--idsfrom", | 26 | parser.add_argument("--idsfrom", |
27 | type=str, | 27 | type=str, |
28 | default="utt2", | 28 | default="utt2", |
29 | choices=[ | 29 | choices=[ |
30 | "features", | 30 | "features", |
31 | "utt2" | 31 | "utt2" |
32 | ], | 32 | ], |
33 | help="from features or from utt2?") | 33 | help="from features or from utt2?") |
34 | parser.add_argument("--prefix", | 34 | parser.add_argument("--prefix", |
35 | default="", | ||
35 | type=str, | 36 | type=str, |
36 | help="prefix of saved files") | 37 | help="prefix of saved files") |
37 | parser.add_argument("--outdir", | 38 | parser.add_argument("--outdir", |
38 | default=None, | 39 | default=None, |
39 | type=str, | 40 | type=str, |
40 | help="Output directory") | 41 | help="Output directory") |
41 | 42 | ||
42 | args = parser.parse_args() | 43 | args = parser.parse_args() |
43 | 44 | ||
44 | assert args.outdir | 45 | assert args.outdir |
45 | 46 | ||
46 | start = time.time() | 47 | start = time.time() |
47 | 48 | ||
48 | # Load features and utt2 | 49 | # Load features and utt2 |
49 | features = read_features(args.features) | 50 | features = read_features(args.features) |
50 | utt2 = read_lst(args.utt2) | 51 | utt2 = read_lst(args.utt2) |
51 | 52 | ||
52 | # Take id list | 53 | # Take id list |
53 | if args.idsfrom == "features": | 54 | if args.idsfrom == "features": |
54 | ids = list(features.keys()) | 55 | ids = list(features.keys()) |
55 | elif args.idsfrom == "utt2": | 56 | elif args.idsfrom == "utt2": |
56 | ids = list(utt2.keys()) | 57 | ids = list(utt2.keys()) |
57 | else: | 58 | else: |
58 | print(f"idsfrom is not good: {args.idsfrom}") | 59 | print(f"idsfrom is not good: {args.idsfrom}") |
59 | exit(1) | 60 | exit(1) |
60 | 61 | ||
61 | feats = np.vstack([ features[id_] for id_ in ids ]) | 62 | feats = np.vstack([ features[id_] for id_ in ids ]) |
62 | classes = [ utt2[id_] for id_ in ids ] | 63 | classes = [ utt2[id_] for id_ in ids ] |
63 | 64 | ||
64 | # Encode labels | 65 | # Encode labels |
65 | le = LabelEncoder() | 66 | le = LabelEncoder() |
66 | labels = le.fit_transform(classes) | 67 | labels = le.fit_transform(classes) |
67 | num_classes = len(le.classes_) | 68 | num_classes = len(le.classes_) |
68 | 69 | ||
69 | # Compute KMEANS clustering on data | 70 | # Compute KMEANS clustering on data |
70 | estimator = KMeans( | 71 | estimator = KMeans( |
71 | n_clusters=num_classes, | 72 | n_clusters=num_classes, |
72 | n_init=100, | 73 | n_init=100, |
73 | tol=10-6, | 74 | tol=10-6, |
74 | algorithm="elkan" | 75 | algorithm="elkan" |
75 | ) | 76 | ) |
76 | estimator.fit(feats) | 77 | estimator.fit(feats) |
77 | print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}") | 78 | print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}") |
78 | 79 | ||
79 | # contains distance to each cluster for each sample | 80 | # contains distance to each cluster for each sample |
80 | dist_space = estimator.transform(feats) | 81 | dist_space = estimator.transform(feats) |
81 | predictions = np.argmin(dist_space, axis=1) | 82 | predictions = np.argmin(dist_space, axis=1) |
82 | 83 | ||
83 | # gives each cluster a name (considering most represented character) | 84 | # gives each cluster a name (considering most represented character) |
84 | dataframe = pd.DataFrame({ | 85 | dataframe = pd.DataFrame({ |
85 | "label": pd.Series(list(map(lambda x: le.classes_[x], labels))), | 86 | "label": pd.Series(list(map(lambda x: le.classes_[x], labels))), |
86 | "prediction": pd.Series(predictions) | 87 | "prediction": pd.Series(predictions) |
87 | }) | 88 | }) |
88 | 89 | ||
89 | def find_cluster_name_fn(c): | 90 | def find_cluster_name_fn(c): |
90 | mask = dataframe["prediction"] == c | 91 | mask = dataframe["prediction"] == c |
91 | return dataframe[mask]["label"].value_counts(sort=False).idxmax() | 92 | return dataframe[mask]["label"].value_counts(sort=False).idxmax() |
92 | 93 | ||
93 | cluster_names = list(map(find_cluster_name_fn, range(num_classes))) | 94 | cluster_names = list(map(find_cluster_name_fn, range(num_classes))) |
94 | predicted_labels = le.transform( | 95 | predicted_labels = le.transform( |
95 | [cluster_names[pred] for pred in predictions]) | 96 | [cluster_names[pred] for pred in predictions]) |
96 | 97 | ||
97 | # F-measure | 98 | # F-measure |
98 | fscores = f1_score(labels, predicted_labels, average=None) | 99 | fscores = f1_score(labels, predicted_labels, average=None) |
99 | fscores_str = "\n".join(map(lambda i: "{0:25s}: {1:.4f}".format(le.classes_[i], fscores[i]), range(len(fscores)))) | 100 | fscores_str = "\n".join(map(lambda i: "{0:25s}: {1:.4f}".format(le.classes_[i], fscores[i]), range(len(fscores)))) |
100 | print(f"F1-scores for each classes:\n{fscores_str}") | 101 | print(f"F1-scores for each classes:\n{fscores_str}") |
101 | print(f"Global score : {np.mean(fscores)}") | 102 | print(f"Global score : {np.mean(fscores)}") |
102 | with open(os.path.join(args.outdir, args.prefix + "eval_clustering.log"), "w") as fd: | 103 | with open(os.path.join(args.outdir, args.prefix + "eval_clustering.log"), "w") as fd: |
103 | print(f"F1-scores for each classes:\n{fscores_str}", file=fd) | 104 | print(f"F1-scores for each classes:\n{fscores_str}", file=fd) |
104 | print(f"Global score : {np.mean(fscores)}", file=fd) | 105 | print(f"Global score : {np.mean(fscores)}", file=fd) |
105 | 106 | ||
106 | # Process t-SNE and plot | 107 | # Process t-SNE and plot |
107 | tsne_estimator = TSNE() | 108 | tsne_estimator = TSNE() |
108 | embeddings = tsne_estimator.fit_transform(feats) | 109 | embeddings = tsne_estimator.fit_transform(feats) |
109 | print("t-SNE: processed {0} iterations - KL_divergence={1:.4f}".format( | 110 | print("t-SNE: processed {0} iterations - KL_divergence={1:.4f}".format( |
110 | tsne_estimator.n_iter_, tsne_estimator.kl_divergence_)) | 111 | tsne_estimator.n_iter_, tsne_estimator.kl_divergence_)) |
111 | 112 | ||
112 | fig, [axe1, axe2] = plt.subplots(1, 2, figsize=(10, 5)) | 113 | fig, [axe1, axe2] = plt.subplots(1, 2, figsize=(10, 5)) |
113 | for c, name in enumerate(le.classes_): | 114 | for c, name in enumerate(le.classes_): |
114 | c_mask = np.where(labels == c) | 115 | c_mask = np.where(labels == c) |
115 | axe1.scatter(embeddings[c_mask][:, 0], embeddings[c_mask][:, 1], label=name, alpha=0.2, edgecolors=None) | 116 | axe1.scatter(embeddings[c_mask][:, 0], embeddings[c_mask][:, 1], label=name, alpha=0.2, edgecolors=None) |
116 | 117 | ||
117 | try: | 118 | try: |
118 | id_cluster = cluster_names.index(name) | 119 | id_cluster = cluster_names.index(name) |
119 | except ValueError: | 120 | except ValueError: |
120 | print("WARNING: no cluster found for {}".format(name)) | 121 | print("WARNING: no cluster found for {}".format(name)) |
121 | continue | 122 | continue |
122 | c_mask = np.where(predictions == id_cluster) | 123 | c_mask = np.where(predictions == id_cluster) |
123 | axe2.scatter(embeddings[c_mask][:, 0], embeddings[c_mask][:, 1], label=name, alpha=0.2, edgecolors=None) | 124 | axe2.scatter(embeddings[c_mask][:, 0], embeddings[c_mask][:, 1], label=name, alpha=0.2, edgecolors=None) |
124 | 125 | ||
125 | axe1.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35)) | 126 | axe1.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35)) |
126 | axe1.set_title("true labels") | 127 | axe1.set_title("true labels") |
127 | axe2.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35)) | 128 | axe2.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35)) |
128 | axe2.set_title("predicted cluster label") | 129 | axe2.set_title("predicted cluster label") |
129 | 130 | ||
130 | plt.suptitle("Kmeans Clustering") | 131 | plt.suptitle("Kmeans Clustering") |
131 | 132 | ||
132 | loc = os.path.join( | 133 | loc = os.path.join( |
133 | args.outdir, | 134 | args.outdir, |
134 | args.prefix + "kmeans.pdf" | 135 | args.prefix + "kmeans.pdf" |
135 | ) | 136 | ) |
136 | plt.savefig(loc, bbox_inches="tight") | 137 | plt.savefig(loc, bbox_inches="tight") |
137 | plt.close() | 138 | plt.close() |
138 | 139 | ||
139 | print("INFO: figure saved at {}".format(loc)) | 140 | print("INFO: figure saved at {}".format(loc)) |
140 | 141 | ||
141 | end = time.time() | 142 | end = time.time() |
142 | print("program ended in {0:.2f} seconds".format(end-start)) | 143 | print("program ended in {0:.2f} seconds".format(end-start)) |