Commit b7530e26935afb3b75313d776f3795b5b4b2ecb7
1 parent
29644ae6c3
Exists in
master
Add a script that plots a count matrix: for each cluster, the number of elements belonging to each class is plotted in a matrix.
Showing 1 changed file with 106 additions and 0 deletions Inline Diff
bin/plot-count-matrix.py
File was created | 1 | ''' | |
2 | This script plots a count matrix: for each cluster, the number of elements belonging to each class. | ||
3 | ''' | ||
4 | import argparse | ||
5 | import numpy as np | ||
6 | from data import read_file, index_by_id | ||
7 | from sklearn import preprocessing | ||
8 | import matplotlib.pyplot as plt | ||
9 | |||
10 | |||
# TODO: get the list of characters
# TODO: get the list of clusters

# Command-line interface: three positional input files and an optional
# output path for the rendered figure.
parser = argparse.ArgumentParser(description="Plot count matrix")
parser.add_argument("clustering", type=str, help="clustering file")
parser.add_argument("classlst", type=str, help="List used for its classes.")
parser.add_argument("lst", type=str, help="list")
parser.add_argument(
    "--outfile",
    type=str,
    default="out.pdf",
    help="output file path",
)

args = parser.parse_args()

# Expose the parsed values as module-level constants used by the rest of
# the script.
CLUSTERING, CLASS_LST = args.clustering, args.classlst
LST, OUTFILE = args.lst, args.outfile
28 | |||
# -- READ FILES
# Each input is loaded with read_file; the clustering and class lists are
# additionally passed through index_by_id to obtain nested lookup tables.
# NOTE(review): the exact record layout comes from the project's `data`
# module and is not visible here -- confirm against read_file/index_by_id.
clustering = read_file(CLUSTERING)
class_lst = read_file(CLASS_LST)
lst = read_file(LST)

clustering_ind = index_by_id(clustering)
class_lst_ind = index_by_id(class_lst)

# -- GET CLASSES AND CLUSTERS
# For every entry of `lst`, entry[0][0] and entry[0][3] form the two-level
# key into each index; field [0][1] of the matching record is used as the
# label (class label from class_lst_ind, cluster label from clustering_ind).
classes = np.asarray(
    [class_lst_ind[entry[0][0]][entry[0][3]][0][1] for entry in lst]
)
clusters = np.asarray(
    [clustering_ind[entry[0][0]][entry[0][3]][0][1] for entry in lst]
)
41 | |||
def generate_count_matrix(classes, clusters):
    '''
    Build the cluster-by-class count matrix for the given set.

    Rows are clusters and columns are classes: cell (r, c) holds the
    number of elements assigned to the cluster whose integer id is r
    and whose class is the c-th entry of the sorted unique classes.

    classes  -- sequence of class labels, one per element.
    clusters -- sequence of cluster labels, one per element, aligned
                with `classes`. Every label must be convertible with
                int() because it is used as a row index.

    Returns a tuple (counts_matrix, all_classes, all_clusters) where
    counts_matrix is a float array of shape
    (max cluster id + 1, number of classes), and all_classes /
    all_clusters are the sorted unique labels as returned by np.unique.
    Rows whose id does not appear in `clusters` stay filled with zeros.
    Empty input yields a matrix with zero rows instead of raising.
    '''
    classes = np.asarray(classes)
    clusters = np.asarray(clusters)

    all_classes = np.unique(classes)
    all_clusters = np.unique(clusters)

    # Nothing to count: np.max below would raise on an empty array.
    if all_clusters.size == 0:
        return (np.zeros((0, len(all_classes))), all_classes, all_clusters)

    # Column of a class is its position in the sorted unique array; this
    # is the same index a LabelEncoder fitted on all_classes would give.
    class_column = {class_: col for col, class_ in enumerate(all_classes)}

    # Matrix: one line per cluster id, one column per class.
    # The builtin int replaces the np.int alias removed in NumPy 1.24.
    n_rows = np.max(np.asarray(all_clusters, dtype=int)) + 1
    counts_matrix = np.zeros((n_rows, len(all_classes)))

    for cluster in all_clusters:
        # First extract the classes present in this cluster.
        members = np.extract(clusters == cluster, classes)
        present, occurrences = np.unique(members, return_counts=True)

        row = int(cluster)
        for class_, count in zip(present, occurrences):
            counts_matrix[row][class_column[class_]] = count

    return (counts_matrix, all_classes, all_clusters)
80 | |||
# -- PLOT THE COUNT MATRIX AS AN ANNOTATED HEATMAP
count_matrix, all_classes, all_clusters = generate_count_matrix(classes, clusters)

fig, ax = plt.subplots()
fig.set_size_inches(10, len(all_clusters) + 1)
im = ax.imshow(count_matrix)

# One x tick per class column.
ax.set_xticks(np.arange(len(all_classes)))
# Rows of count_matrix are indexed by the integer cluster id, so each
# cluster label must sit at that row. Placing labels at 0..n-1 mislabels
# rows when ids are not contiguous or when string ids sort as "10" < "2".
ax.set_yticks(np.asarray(all_clusters, dtype=int))

ax.set_xticklabels(all_classes)
ax.set_yticklabels(all_clusters)

fig.colorbar(im, ax=ax)

# Tilt class names so long labels do not overlap.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# Loop over data dimensions and create text annotations.
for i in range(count_matrix.shape[0]):
    for j in range(count_matrix.shape[1]):
        ax.text(j, i, int(count_matrix[i, j]),
                ha="center", va="center", color="w")

ax.set_title("Count Matrix")
fig.tight_layout()
plt.savefig(OUTFILE, bbox_inches='tight')
107 |