Blame view

bin/plot-count-matrix.py 3.24 KB
b7530e269   Mathias Quillot   allow you to plot...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
  '''
  Plot a count matrix: how many items of each class fall into each cluster.
  '''
  import argparse
  import numpy as np
  from data import read_file, index_by_id
  from sklearn import preprocessing
  import matplotlib.pyplot as plt
  
  
  # TODO: get the list of characters
  # TODO: get the list of clusters
  parser = argparse.ArgumentParser(description="Plot count matrix")
  parser.add_argument("clustering", type=str, help="clustering file")
  parser.add_argument("classlst", type=str, help="List used for its classes.")
  parser.add_argument("lst", type=str, help="list")
  parser.add_argument("--outfile", type=str, default="out.pdf",
                      help="output file path")
  args = parser.parse_args()

  CLUSTERING, CLASS_LST = args.clustering, args.classlst
  LST, OUTFILE = args.lst, args.outfile


  def _label_of(index, item):
      '''
      Return the label stored for ``item`` in an index built by ``index_by_id``.

      NOTE(review): assumes ``item[0]`` is the record, that ``index`` is keyed
      first by field 0 and then by field 3 of that record, and that the label
      is field 1 of the first matching entry -- confirm against ``data``.
      '''
      record = item[0]
      return index[record[0]][record[3]][0][1]


  # -- READ FILES
  clustering = read_file(CLUSTERING)
  clustering_ind = index_by_id(clustering)

  class_lst = read_file(CLASS_LST)
  class_lst_ind = index_by_id(class_lst)

  lst = read_file(LST)

  # -- GET CLASSES AND CLUSTERS
  # One class label and one cluster label per entry of ``lst``, same order.
  classes = np.asarray([_label_of(class_lst_ind, x) for x in lst])
  clusters = np.asarray([_label_of(clustering_ind, x) for x in lst])
  
  def generate_count_matrix(classes, clusters):
      '''
      Build the cluster-by-class count matrix.

      Rows are indexed by the integer value of each cluster label and columns
      by the position of each class in the sorted list of unique classes.
      Cell (i, j) holds the number of items assigned to cluster i that carry
      class j.

      :param classes: sequence of class labels, one per item.
      :param clusters: sequence of cluster labels aligned with ``classes``;
          each label must be convertible to ``int`` (e.g. "3" or 3).
      :return: ``(counts_matrix, all_classes, all_clusters)`` where
          ``counts_matrix`` is a float array of shape
          ``(max(int(cluster)) + 1, len(all_classes))``, ``all_classes`` is the
          sorted array of unique classes and ``all_clusters`` is the sorted
          array of unique cluster labels (in their original type, e.g. str).
      '''
      classes = np.asarray(classes)
      clusters = np.asarray(clusters)

      all_classes = np.unique(classes)
      all_clusters = np.unique(clusters)

      if clusters.size == 0:
          # Nothing to count; np.max would fail on an empty array.
          return (np.zeros((0, len(all_classes))), all_classes, all_clusters)

      # Builtin int rather than np.int, which was removed in NumPy 1.24.
      cluster_ids = clusters.astype(int)
      # all_classes is sorted and unique, so searchsorted gives each item's
      # column -- the same index a LabelEncoder fitted on it would assign.
      class_ids = np.searchsorted(all_classes, classes)

      counts_matrix = np.zeros((cluster_ids.max() + 1, len(all_classes)))
      # Unbuffered in-place add, so repeated (cluster, class) pairs all count.
      np.add.at(counts_matrix, (cluster_ids, class_ids), 1)

      return (counts_matrix, all_classes, all_clusters)
  
  count_matrix, all_classes, all_clusters = generate_count_matrix(classes, clusters)

  n_classes = len(all_classes)
  n_clusters = len(all_clusters)

  # Figure is 10 inches wide; height grows by one inch per cluster, plus one.
  fig, ax = plt.subplots()
  fig.set_size_inches(10, n_clusters + 1)
  im = ax.imshow(count_matrix)

  # One tick per class on x and one per cluster on y.
  # NOTE(review): count_matrix has max(int(cluster)) + 1 rows, so when cluster
  # ids are not contiguous from 0 some rows receive no tick.
  ax.set_xticks(np.arange(n_classes))
  ax.set_yticks(np.arange(n_clusters))
  ax.set_xticklabels(all_classes)
87c56690a   Mathias Quillot   repair error from...
91
  # Label rows by their position (0 .. len(all_clusters) - 1) rather than by
  # the raw cluster labels.
  row_labels = list(np.arange(len(all_clusters)))
  ax.set_yticklabels(row_labels)
b7530e269   Mathias Quillot   allow you to plot...
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
  
  fig.colorbar(im, ax=ax)

  # Tilt the class names so long labels do not overlap.
  plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

  # Write each cell's count, as an integer, at the centre of the cell
  # (row-major order, same as the nested index loops).
  for (row, col), value in np.ndenumerate(count_matrix):
      ax.text(col, row, int(value), ha="center", va="center", color="w")

  ax.set_title("Count Matrix")
  fig.tight_layout()
  plt.savefig(OUTFILE, bbox_inches='tight')