Commit adbca3b1ce8ad1cd5bc482b687715ee5e3b2d3d8

Authored by Mathias
1 parent 85d6f0944e
Exists in master

Save the kmeans model

Showing 1 changed file with 4 additions and 0 deletions Inline Diff

scripts/evaluations/clustering.py
1 ''' 1 '''
2 This script allows the user to evaluate a classification system on new labels using clustering methods. 2 This script allows the user to evaluate a classification system on new labels using clustering methods.
3 The algorithms are applied on the given latent space (embedding). 3 The algorithms are applied on the given latent space (embedding).
4 ''' 4 '''
5 import argparse 5 import argparse
6 import numpy as np 6 import numpy as np
7 import pandas as pd 7 import pandas as pd
8 import os 8 import os
9 import time 9 import time
10 import pickle
10 from sklearn.preprocessing import LabelEncoder 11 from sklearn.preprocessing import LabelEncoder
11 from sklearn.metrics.pairwise import pairwise_distances 12 from sklearn.metrics.pairwise import pairwise_distances
12 from sklearn.metrics import f1_score 13 from sklearn.metrics import f1_score
13 from sklearn.cluster import KMeans 14 from sklearn.cluster import KMeans
14 from sklearn.manifold import TSNE 15 from sklearn.manifold import TSNE
15 import matplotlib.pyplot as plt 16 import matplotlib.pyplot as plt
16 17
17 from volia.data_io import read_features,read_lst 18 from volia.data_io import read_features,read_lst
18 19
def find_cluster_names(true_names, predictions, num_clusters):
    """Name each cluster after the most frequent true label among its members.

    Args:
        true_names: sequence holding the true label name of every sample.
        predictions: sequence holding the cluster index assigned to every
            sample, aligned with ``true_names``.
        num_clusters: number of clusters; indices ``0 .. num_clusters - 1``
            are named.

    Returns:
        A list whose item ``c`` is the majority label of cluster ``c``.

    Raises:
        ValueError: if a cluster has no sample assigned to it.
    """
    dataframe = pd.DataFrame({
        "label": list(true_names),
        "prediction": list(predictions),
    })
    cluster_names = []
    for c in range(num_clusters):
        members = dataframe.loc[dataframe["prediction"] == c, "label"]
        if members.empty:
            # Fail with an explicit message instead of the opaque error that
            # idxmax() raises on an empty series.
            raise ValueError(f"cluster {c} has no sample assigned to it")
        cluster_names.append(members.value_counts(sort=False).idxmax())
    return cluster_names


if __name__ == "__main__":
    # Argparse
    parser = argparse.ArgumentParser("Compute clustering on a latent space")
    parser.add_argument("features")
    parser.add_argument("utt2",
                        type=str,
                        help="file with [utt] [value]")
    parser.add_argument("--idsfrom",
                        type=str,
                        default="utt2",
                        choices=[
                            "features",
                            "utt2",
                        ],
                        help="from features or from utt2?")
    parser.add_argument("--prefix",
                        default="",
                        type=str,
                        help="prefix of saved files")
    parser.add_argument("--outdir",
                        default=None,
                        type=str,
                        help="Output directory")

    args = parser.parse_args()

    # Explicit validation: an `assert` would be stripped under `python -O`.
    if not args.outdir:
        parser.error("--outdir is required")
    # Every output below is written into outdir, so make sure it exists.
    os.makedirs(args.outdir, exist_ok=True)

    start = time.time()

    # Load features and utt2
    # (both are used as mappings keyed by utterance id below)
    features = read_features(args.features)
    utt2 = read_lst(args.utt2)

    # Take id list
    if args.idsfrom == "features":
        ids = list(features.keys())
    elif args.idsfrom == "utt2":
        ids = list(utt2.keys())
    else:
        # Defensive only: argparse `choices` already rejects other values.
        print(f"idsfrom is not good: {args.idsfrom}")
        raise SystemExit(1)

    # One feature row and one class per id, in the same order.
    feats = np.vstack([features[id_] for id_ in ids])
    classes = [utt2[id_] for id_ in ids]

    # Encode labels
    le = LabelEncoder()
    labels = le.fit_transform(classes)
    num_classes = len(le.classes_)

    # Compute KMEANS clustering on data, with one cluster per known class.
    estimator = KMeans(
        n_clusters=num_classes,
        n_init=100,
        # The previous value `10-6` is an integer subtraction (== 4), not a
        # tolerance; 1e-6 is presumably what was intended.
        tol=1e-6,
        algorithm="elkan",
    )
    estimator.fit(feats)
    print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}")

    # Save the fitted model.
    # NOTE(review): unlike the other outputs, this file name ignores
    # --prefix, so runs sharing an outdir overwrite each other's model.
    with open(os.path.join(args.outdir, "kmeans.pkl"), "wb") as f:
        pickle.dump(estimator, f)

    # contains distance to each cluster for each sample
    dist_space = estimator.transform(feats)
    predictions = np.argmin(dist_space, axis=1)

    # gives each cluster a name (considering most represented character)
    cluster_names = find_cluster_names(le.classes_[labels], predictions, num_classes)
    predicted_labels = le.transform(
        [cluster_names[pred] for pred in predictions])

    # F-measure (one score per class, then their unweighted mean)
    fscores = f1_score(labels, predicted_labels, average=None)
    fscores_str = "\n".join(
        "{0:25s}: {1:.4f}".format(name, score)
        for name, score in zip(le.classes_, fscores)
    )
    # Build the report once so stdout and the log file cannot diverge.
    report = (
        f"F1-scores for each classes:\n{fscores_str}\n"
        f"Global score : {np.mean(fscores)}"
    )
    print(report)
    with open(os.path.join(args.outdir, args.prefix + "eval_clustering.log"), "w") as fd:
        print(report, file=fd)

    # Process t-SNE and plot
    tsne_estimator = TSNE()
    embeddings = tsne_estimator.fit_transform(feats)
    print("t-SNE: processed {0} iterations - KL_divergence={1:.4f}".format(
        tsne_estimator.n_iter_, tsne_estimator.kl_divergence_))

    fig, [axe1, axe2] = plt.subplots(1, 2, figsize=(10, 5))
    for c, name in enumerate(le.classes_):
        # Left panel: samples whose true label is this class.
        true_mask = labels == c
        axe1.scatter(embeddings[true_mask, 0], embeddings[true_mask, 1], label=name, alpha=0.2, edgecolors=None)

        # Right panel: samples whose cluster was named after this class.
        # Masking on predicted_labels (rather than on the first cluster index
        # carrying that name) keeps every cluster sharing the name, matching
        # what the F1 scores above are computed on.
        pred_mask = predicted_labels == c
        if not pred_mask.any():
            print("WARNING: no cluster found for {}".format(name))
            continue
        axe2.scatter(embeddings[pred_mask, 0], embeddings[pred_mask, 1], label=name, alpha=0.2, edgecolors=None)

    axe1.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35))
    axe1.set_title("true labels")
    axe2.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35))
    axe2.set_title("predicted cluster label")

    plt.suptitle("Kmeans Clustering")

    loc = os.path.join(
        args.outdir,
        args.prefix + "kmeans.pdf"
    )
    plt.savefig(loc, bbox_inches="tight")
    plt.close()

    print("INFO: figure saved at {}".format(loc))

    end = time.time()
    print("program ended in {0:.2f} seconds".format(end-start))