Commit 87c56690afdae6225b8ec4697e8799c07ee30cc0

Authored by Mathias Quillot
1 parent 6da8f6ca75
Exists in master

repair error from plot

Showing 1 changed file with 1 additions and 1 deletions Inline Diff

bin/plot-count-matrix.py
1 ''' 1 '''
2 This script aims to plot matrix count. 2 This script aims to plot matrix count.
3 ''' 3 '''
4 import argparse 4 import argparse
5 import numpy as np 5 import numpy as np
6 from data import read_file, index_by_id 6 from data import read_file, index_by_id
7 from sklearn import preprocessing 7 from sklearn import preprocessing
8 import matplotlib.pyplot as plt 8 import matplotlib.pyplot as plt
9 9
10 10
11 # TODO: Avoir la liste des personnages 11 # TODO: Avoir la liste des personnages
12 # TODO: liste des clusters 12 # TODO: liste des clusters
13 parser = argparse.ArgumentParser(description="Plot count matrix") 13 parser = argparse.ArgumentParser(description="Plot count matrix")
14 parser.add_argument("clustering", type=str, 14 parser.add_argument("clustering", type=str,
15 help="clustering file") 15 help="clustering file")
16 parser.add_argument("classlst", type=str, 16 parser.add_argument("classlst", type=str,
17 help="List used for its classes.") 17 help="List used for its classes.")
18 parser.add_argument("lst", type=str, 18 parser.add_argument("lst", type=str,
19 help="list") 19 help="list")
20 parser.add_argument("--outfile", type=str, default="out.pdf", 20 parser.add_argument("--outfile", type=str, default="out.pdf",
21 help="output file path") 21 help="output file path")
22 22
23 args = parser.parse_args() 23 args = parser.parse_args()
24 CLUSTERING = args.clustering 24 CLUSTERING = args.clustering
25 CLASS_LST = args.classlst 25 CLASS_LST = args.classlst
26 LST = args.lst 26 LST = args.lst
27 OUTFILE = args.outfile 27 OUTFILE = args.outfile
28 28
29 # -- READ FILES 29 # -- READ FILES
30 clustering = read_file(CLUSTERING) 30 clustering = read_file(CLUSTERING)
31 clustering_ind = index_by_id(clustering) 31 clustering_ind = index_by_id(clustering)
32 32
33 class_lst = read_file(CLASS_LST) 33 class_lst = read_file(CLASS_LST)
34 class_lst_ind = index_by_id(class_lst) 34 class_lst_ind = index_by_id(class_lst)
35 35
36 lst = read_file(LST) 36 lst = read_file(LST)
37 37
38 # -- GET CLASSES AND CLUSTERS 38 # -- GET CLASSES AND CLUSTERS
39 classes = np.asarray([class_lst_ind[x[0][0]][x[0][3]][0][1] for x in lst]) 39 classes = np.asarray([class_lst_ind[x[0][0]][x[0][3]][0][1] for x in lst])
40 clusters = np.asarray([clustering_ind[x[0][0]][x[0][3]][0][1] for x in lst]) 40 clusters = np.asarray([clustering_ind[x[0][0]][x[0][3]][0][1] for x in lst])
41 41
42 def generate_count_matrix(classes, clusters): 42 def generate_count_matrix(classes, clusters):
43 ''' 43 '''
44 Generate matrices for the given set 44 Generate matrices for the given set
45 Lines are clusters and columns are classes. 45 Lines are clusters and columns are classes.
46 A cell is contains the number of character occurence 46 A cell is contains the number of character occurence
47 on a specific cluster. 47 on a specific cluster.
48 ''' 48 '''
49 49
50 # Index Classes 50 # Index Classes
51 classe_unique = np.unique(classes) 51 classe_unique = np.unique(classes)
52 #all_classes = np.unique(np.concatenate((classe_unique))) 52 #all_classes = np.unique(np.concatenate((classe_unique)))
53 all_classes = classe_unique 53 all_classes = classe_unique
54 54
55 # Label Encoder for classes 55 # Label Encoder for classes
56 le = preprocessing.LabelEncoder() 56 le = preprocessing.LabelEncoder()
57 le.fit(all_classes) 57 le.fit(all_classes)
58 58
59 # Index 59 # Index
60 cluster_unique = np.unique(clusters) 60 cluster_unique = np.unique(clusters)
61 61
62 #all_clusters = np.unique(np.concatenate((cluster_unique))) 62 #all_clusters = np.unique(np.concatenate((cluster_unique)))
63 all_clusters = cluster_unique 63 all_clusters = cluster_unique
64 # Create matrix lin(clust) col(class) 64 # Create matrix lin(clust) col(class)
65 counts_matrix = np.zeros((np.max(np.asarray(all_clusters, dtype=np.int)) + 1, len(all_classes))) 65 counts_matrix = np.zeros((np.max(np.asarray(all_clusters, dtype=np.int)) + 1, len(all_classes)))
66 66
67 for cluster in all_clusters: 67 for cluster in all_clusters:
68 68
69 # Il faut d'abord extraire les classes présentes dans ce cluster 69 # Il faut d'abord extraire les classes présentes dans ce cluster
70 cc = np.extract(np.asarray(clusters) == cluster, np.asarray(classes)) 70 cc = np.extract(np.asarray(clusters) == cluster, np.asarray(classes))
71 71
72 cc_unique, cc_counts = np.unique(cc, return_counts=True) 72 cc_unique, cc_counts = np.unique(cc, return_counts=True)
73 cc_ind = dict(zip(cc_unique, cc_counts)) 73 cc_ind = dict(zip(cc_unique, cc_counts))
74 74
75 for class_ in all_classes: 75 for class_ in all_classes:
76 class_id = le.transform([class_])[0] 76 class_id = le.transform([class_])[0]
77 if class_ in cc_ind: 77 if class_ in cc_ind:
78 counts_matrix[int(cluster)][int(class_id)] = cc_ind[class_] 78 counts_matrix[int(cluster)][int(class_id)] = cc_ind[class_]
79 return (counts_matrix, all_classes, all_clusters) 79 return (counts_matrix, all_classes, all_clusters)
80 80
81 count_matrix, all_classes, all_clusters = generate_count_matrix(classes, clusters) 81 count_matrix, all_classes, all_clusters = generate_count_matrix(classes, clusters)
82 82
83 fig, ax = plt.subplots() 83 fig, ax = plt.subplots()
84 fig.set_size_inches(10, len(all_clusters) + 1) 84 fig.set_size_inches(10, len(all_clusters) + 1)
85 im = ax.imshow(count_matrix) 85 im = ax.imshow(count_matrix)
86 86
87 ax.set_xticks(np.arange(len(all_classes))) 87 ax.set_xticks(np.arange(len(all_classes)))
88 ax.set_yticks(np.arange(len(all_clusters))) 88 ax.set_yticks(np.arange(len(all_clusters)))
89 89
90 ax.set_xticklabels(all_classes) 90 ax.set_xticklabels(all_classes)
91 ax.set_yticklabels(all_clusters) 91 ax.set_yticklabels(list(np.arange(len(all_clusters))))
92 92
93 fig.colorbar(im, ax=ax) 93 fig.colorbar(im, ax=ax)
94 94
95 plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") 95 plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
96 96
97 # Loop over data dimensions and create text annotations. 97 # Loop over data dimensions and create text annotations.
98 for i in range(count_matrix.shape[0]): 98 for i in range(count_matrix.shape[0]):
99 for j in range(count_matrix.shape[1]): 99 for j in range(count_matrix.shape[1]):
100 text = ax.text(j, i, int(count_matrix[i, j]), 100 text = ax.text(j, i, int(count_matrix[i, j]),
101 ha="center", va="center", color="w") 101 ha="center", va="center", color="w")
102 102
103 103
104 ax.set_title("Count Matrix") 104 ax.set_title("Count Matrix")
105 fig.tight_layout() 105 fig.tight_layout()
106 plt.savefig(OUTFILE, bbox_inches='tight') 106 plt.savefig(OUTFILE, bbox_inches='tight')
107 107