Commit adbca3b1ce8ad1cd5bc482b687715ee5e3b2d3d8
1 parent
85d6f0944e
Exists in
master
Save the kmeans model
Showing 1 changed file with 4 additions and 0 deletions Inline Diff
scripts/evaluations/clustering.py
1 | ''' | 1 | ''' |
2 | This script allows the user to evaluate a classification system on new labels using clustering methods. | 2 | This script allows the user to evaluate a classification system on new labels using clustering methods. |
3 | The algorithms are applied on the given latent space (embedding). | 3 | The algorithms are applied on the given latent space (embedding). |
4 | ''' | 4 | ''' |
5 | import argparse | 5 | import argparse |
6 | import numpy as np | 6 | import numpy as np |
7 | import pandas as pd | 7 | import pandas as pd |
8 | import os | 8 | import os |
9 | import time | 9 | import time |
10 | import pickle | ||
10 | from sklearn.preprocessing import LabelEncoder | 11 | from sklearn.preprocessing import LabelEncoder |
11 | from sklearn.metrics.pairwise import pairwise_distances | 12 | from sklearn.metrics.pairwise import pairwise_distances |
12 | from sklearn.metrics import f1_score | 13 | from sklearn.metrics import f1_score |
13 | from sklearn.cluster import KMeans | 14 | from sklearn.cluster import KMeans |
14 | from sklearn.manifold import TSNE | 15 | from sklearn.manifold import TSNE |
15 | import matplotlib.pyplot as plt | 16 | import matplotlib.pyplot as plt |
16 | 17 | ||
17 | from volia.data_io import read_features,read_lst | 18 | from volia.data_io import read_features,read_lst |
18 | 19 | ||
if __name__ == "__main__":
    # Script entry point: cluster the given latent space with KMeans, name each
    # cluster after its most represented true label, report per-class
    # F1-scores, save the fitted model and plot a t-SNE projection.

    # Argparse
    parser = argparse.ArgumentParser("Compute clustering on a latent space")
    parser.add_argument("features")
    parser.add_argument("utt2",
                        type=str,
                        help="file with [utt] [value]")
    parser.add_argument("--idsfrom",
                        type=str,
                        default="utt2",
                        choices=[
                            "features",
                            "utt2"
                        ],
                        help="from features or from utt2?")
    parser.add_argument("--prefix",
                        default="",
                        type=str,
                        help="prefix of saved files")
    parser.add_argument("--outdir",
                        default=None,
                        type=str,
                        help="Output directory")

    args = parser.parse_args()

    # Validate explicitly: an `assert` would be stripped when running with -O.
    if not args.outdir:
        parser.error("--outdir is required")
    # Make sure the output directory exists before the expensive computations,
    # so that results are not lost when they are finally written.
    os.makedirs(args.outdir, exist_ok=True)

    start = time.time()

    # Load features and utt2
    features = read_features(args.features)
    utt2 = read_lst(args.utt2)

    # Take id list. argparse `choices` guarantees that --idsfrom is either
    # "features" or "utt2", so no fallback branch is needed.
    id_source = features if args.idsfrom == "features" else utt2
    ids = list(id_source.keys())

    # One row of `feats` and one entry of `classes` per id, in the same order
    feats = np.vstack([features[id_] for id_ in ids])
    classes = [utt2[id_] for id_ in ids]

    # Encode labels
    le = LabelEncoder()
    labels = le.fit_transform(classes)
    num_classes = len(le.classes_)

    # Compute KMEANS clustering on data (one cluster per known class)
    estimator = KMeans(
        n_clusters=num_classes,
        n_init=100,
        tol=1e-6,  # float tolerance; note that `10-6` would evaluate to 4
        algorithm="elkan"
    )
    estimator.fit(feats)
    print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}")

    # Save the fitted model, using the same prefix as the other output files
    # so that several runs can share one output directory.
    with open(os.path.join(args.outdir, args.prefix + "kmeans.pkl"), "wb") as f:
        pickle.dump(estimator, f)

    # contains distance to each cluster for each sample
    dist_space = estimator.transform(feats)
    predictions = np.argmin(dist_space, axis=1)

    # gives each cluster a name (considering most represented character)
    dataframe = pd.DataFrame({
        "label": pd.Series(le.classes_[labels]),
        "prediction": pd.Series(predictions)
    })

    def find_cluster_name_fn(c):
        """Return the most frequent true label among samples assigned to cluster `c`."""
        mask = dataframe["prediction"] == c
        return dataframe[mask]["label"].value_counts(sort=False).idxmax()

    cluster_names = [find_cluster_name_fn(c) for c in range(num_classes)]
    # Map every sample's cluster index to the encoded label of that cluster's name
    predicted_labels = le.transform(
        [cluster_names[pred] for pred in predictions])

    # F-measure (one score per class, then their unweighted mean)
    fscores = f1_score(labels, predicted_labels, average=None)
    fscores_str = "\n".join(
        "{0:25s}: {1:.4f}".format(name, score)
        for name, score in zip(le.classes_, fscores))
    report = (f"F1-scores for each classes:\n{fscores_str}\n"
              f"Global score : {np.mean(fscores)}")
    print(report)
    with open(os.path.join(args.outdir, args.prefix + "eval_clustering.log"), "w") as fd:
        print(report, file=fd)

    # Process t-SNE and plot
    tsne_estimator = TSNE()
    embeddings = tsne_estimator.fit_transform(feats)
    print("t-SNE: processed {0} iterations - KL_divergence={1:.4f}".format(
        tsne_estimator.n_iter_, tsne_estimator.kl_divergence_))

    # Left panel: samples coloured by true label.
    # Right panel: samples coloured by the name given to their cluster.
    fig, [axe1, axe2] = plt.subplots(1, 2, figsize=(10, 5))
    for c, name in enumerate(le.classes_):
        c_mask = np.where(labels == c)
        axe1.scatter(embeddings[c_mask][:, 0], embeddings[c_mask][:, 1], label=name, alpha=0.2, edgecolors=None)

        # A class may not be the majority label of any cluster; in that case
        # there is nothing to draw for it on the prediction panel.
        try:
            id_cluster = cluster_names.index(name)
        except ValueError:
            print("WARNING: no cluster found for {}".format(name))
            continue
        c_mask = np.where(predictions == id_cluster)
        axe2.scatter(embeddings[c_mask][:, 0], embeddings[c_mask][:, 1], label=name, alpha=0.2, edgecolors=None)

    axe1.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35))
    axe1.set_title("true labels")
    axe2.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35))
    axe2.set_title("predicted cluster label")

    plt.suptitle("Kmeans Clustering")

    loc = os.path.join(
        args.outdir,
        args.prefix + "kmeans.pdf"
    )
    plt.savefig(loc, bbox_inches="tight")
    plt.close()

    print("INFO: figure saved at {}".format(loc))

    end = time.time()
    print("program ended in {0:.2f} seconds".format(end-start))