Commit 87c56690afdae6225b8ec4697e8799c07ee30cc0
1 parent
6da8f6ca75
Exists in
master
repair error from plot
Showing 1 changed file with 1 additions and 1 deletions Inline Diff
bin/plot-count-matrix.py
1 | ''' | 1 | ''' |
2 | This script plots a count matrix of classes per cluster. | 2 | This script plots a count matrix of classes per cluster. |
3 | ''' | 3 | ''' |
4 | import argparse | 4 | import argparse |
5 | import numpy as np | 5 | import numpy as np |
6 | from data import read_file, index_by_id | 6 | from data import read_file, index_by_id |
7 | from sklearn import preprocessing | 7 | from sklearn import preprocessing |
8 | import matplotlib.pyplot as plt | 8 | import matplotlib.pyplot as plt |
9 | 9 | ||
10 | 10 | ||
# TODO: get the list of characters
# TODO: list of clusters

# -- COMMAND LINE
parser = argparse.ArgumentParser(description="Plot count matrix")
# The three positional arguments share the same type; declare them from a
# table so the names and help strings sit side by side.
for _arg_name, _arg_help in (
        ("clustering", "clustering file"),
        ("classlst", "List used for its classes."),
        ("lst", "list"),
):
    parser.add_argument(_arg_name, type=str, help=_arg_help)
parser.add_argument("--outfile", type=str, default="out.pdf",
                    help="output file path")

args = parser.parse_args()
CLUSTERING, CLASS_LST, LST, OUTFILE = (
    args.clustering, args.classlst, args.lst, args.outfile)

# -- READ FILES
# read_file / index_by_id come from the project's `data` module; the index
# is presumably a nested mapping keyed by two fields of each record
# (TODO confirm against data.py).
clustering = read_file(CLUSTERING)
clustering_ind = index_by_id(clustering)

class_lst = read_file(CLASS_LST)
class_lst_ind = index_by_id(class_lst)

lst = read_file(LST)


# -- GET CLASSES AND CLUSTERS
def _label_of(index, entry):
    '''
    Look up the label stored in `index` for one entry of `lst`.
    Fields 0 and 3 of the entry's first element are used as the two
    lookup keys; the label is read at position [0][1] of the match.
    '''
    record = entry[0]
    return index[record[0]][record[3]][0][1]


classes = np.asarray([_label_of(class_lst_ind, entry) for entry in lst])
clusters = np.asarray([_label_of(clustering_ind, entry) for entry in lst])
41 | 41 | ||
def generate_count_matrix(classes, clusters):
    '''
    Build the cluster-by-class count matrix.

    Rows are clusters and columns are classes: cell (i, j) holds the
    number of items assigned to cluster id ``i`` whose class is the j-th
    entry of the returned ``all_classes``.

    Parameters:
        classes: sequence of class labels, one per item.
        clusters: sequence of cluster labels, one per item, aligned with
            ``classes``. Labels must be convertible to integers (e.g. 3
            or "3") and are expected to be non-negative, because they
            are used directly as row indices.

    Returns:
        A tuple ``(counts_matrix, all_classes, all_clusters)`` where
        ``counts_matrix`` is a float array of shape
        ``(max cluster id + 1, number of distinct classes)``,
        ``all_classes`` is the sorted array of distinct class labels
        (column order) and ``all_clusters`` is the sorted array of
        distinct cluster labels.

    Raises:
        ValueError: if ``classes`` and ``clusters`` differ in length.
    '''
    classes = np.asarray(classes).ravel()
    clusters = np.asarray(clusters).ravel()
    if classes.shape != clusters.shape:
        raise ValueError("classes and clusters must have the same length")

    # np.unique returns the labels sorted; the inverse array gives, for
    # every item, the column index of its class in that sorted order.
    all_classes, class_cols = np.unique(classes, return_inverse=True)
    all_clusters = np.unique(clusters)

    # No item at all: there is no maximum cluster id to size the matrix.
    if clusters.size == 0:
        return (np.zeros((0, len(all_classes))), all_classes, all_clusters)

    # Cluster labels may arrive as strings; the integer value is the row.
    # The builtin int replaces the np.int alias removed from NumPy.
    cluster_rows = clusters.astype(int)

    counts_matrix = np.zeros((cluster_rows.max() + 1, len(all_classes)))
    # Unbuffered add, so repeated (row, column) pairs are all counted.
    np.add.at(counts_matrix, (cluster_rows, class_cols.ravel()), 1)
    return (counts_matrix, all_classes, all_clusters)
80 | 80 | ||
count_matrix, all_classes, all_clusters = generate_count_matrix(classes, clusters)

# The matrix has one row per cluster id from 0 to the largest id, which can
# be more rows than there are distinct clusters when ids are not contiguous.
# Size the figure and place the y ticks from the row count so every
# displayed row gets a tick, and label each row with its own index.
n_rows = count_matrix.shape[0]

fig, ax = plt.subplots()
fig.set_size_inches(10, n_rows + 1)
im = ax.imshow(count_matrix)

ax.set_xticks(np.arange(len(all_classes)))
ax.set_yticks(np.arange(n_rows))

ax.set_xticklabels(all_classes)
ax.set_yticklabels(list(np.arange(n_rows)))

fig.colorbar(im, ax=ax)

# Class names can be long: tilt them so they do not overlap.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# Loop over data dimensions and create text annotations.
for i in range(count_matrix.shape[0]):
    for j in range(count_matrix.shape[1]):
        ax.text(j, i, int(count_matrix[i, j]),
                ha="center", va="center", color="w")


ax.set_title("Count Matrix")
fig.tight_layout()
plt.savefig(OUTFILE, bbox_inches='tight')
107 | 107 |