Commit b7530e26935afb3b75313d776f3795b5b4b2ecb7

Authored by Mathias Quillot
1 parent 29644ae6c3
Exists in master

allow you to plot count matrix. For each cluster, the number of element belongin…

…g to each class is plot into a matrix.

Showing 1 changed file with 106 additions and 0 deletions Side-by-side Diff

bin/plot-count-matrix.py
  1 +'''
  2 +This script aims to plot matrix count.
  3 +'''
  4 +import argparse
  5 +import numpy as np
  6 +from data import read_file, index_by_id
  7 +from sklearn import preprocessing
  8 +import matplotlib.pyplot as plt
  9 +
  10 +
  11 +# TODO: Avoir la liste des personnages
  12 +# TODO: liste des clusters
  13 +parser = argparse.ArgumentParser(description="Plot count matrix")
  14 +parser.add_argument("clustering", type=str,
  15 + help="clustering file")
  16 +parser.add_argument("classlst", type=str,
  17 + help="List used for its classes.")
  18 +parser.add_argument("lst", type=str,
  19 + help="list")
  20 +parser.add_argument("--outfile", type=str, default="out.pdf",
  21 + help="output file path")
  22 +
  23 +args = parser.parse_args()
  24 +CLUSTERING = args.clustering
  25 +CLASS_LST = args.classlst
  26 +LST = args.lst
  27 +OUTFILE = args.outfile
  28 +
  29 +# -- READ FILES
  30 +clustering = read_file(CLUSTERING)
  31 +clustering_ind = index_by_id(clustering)
  32 +
  33 +class_lst = read_file(CLASS_LST)
  34 +class_lst_ind = index_by_id(class_lst)
  35 +
  36 +lst = read_file(LST)
  37 +
  38 +# -- GET CLASSES AND CLUSTERS
  39 +classes = np.asarray([class_lst_ind[x[0][0]][x[0][3]][0][1] for x in lst])
  40 +clusters = np.asarray([clustering_ind[x[0][0]][x[0][3]][0][1] for x in lst])
  41 +
  42 +def generate_count_matrix(classes, clusters):
  43 + '''
  44 + Generate matrices for the given set
  45 + Lines are clusters and columns are classes.
  46 + A cell is contains the number of character occurence
  47 + on a specific cluster.
  48 + '''
  49 +
  50 + # Index Classes
  51 + classe_unique = np.unique(classes)
  52 + #all_classes = np.unique(np.concatenate((classe_unique)))
  53 + all_classes = classe_unique
  54 +
  55 + # Label Encoder for classes
  56 + le = preprocessing.LabelEncoder()
  57 + le.fit(all_classes)
  58 +
  59 + # Index
  60 + cluster_unique = np.unique(clusters)
  61 +
  62 + #all_clusters = np.unique(np.concatenate((cluster_unique)))
  63 + all_clusters = cluster_unique
  64 + # Create matrix lin(clust) col(class)
  65 + counts_matrix = np.zeros((np.max(np.asarray(all_clusters, dtype=np.int)) + 1, len(all_classes)))
  66 +
  67 + for cluster in all_clusters:
  68 +
  69 + # Il faut d'abord extraire les classes présentes dans ce cluster
  70 + cc = np.extract(np.asarray(clusters) == cluster, np.asarray(classes))
  71 +
  72 + cc_unique, cc_counts = np.unique(cc, return_counts=True)
  73 + cc_ind = dict(zip(cc_unique, cc_counts))
  74 +
  75 + for class_ in all_classes:
  76 + class_id = le.transform([class_])[0]
  77 + if class_ in cc_ind:
  78 + counts_matrix[int(cluster)][int(class_id)] = cc_ind[class_]
  79 + return (counts_matrix, all_classes, all_clusters)
  80 +
  81 +count_matrix, all_classes, all_clusters = generate_count_matrix(classes, clusters)
  82 +
  83 +fig, ax = plt.subplots()
  84 +fig.set_size_inches(10, len(all_clusters) + 1)
  85 +im = ax.imshow(count_matrix)
  86 +
  87 +ax.set_xticks(np.arange(len(all_classes)))
  88 +ax.set_yticks(np.arange(len(all_clusters)))
  89 +
  90 +ax.set_xticklabels(all_classes)
  91 +ax.set_yticklabels(all_clusters)
  92 +
  93 +fig.colorbar(im, ax=ax)
  94 +
  95 +plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
  96 +
  97 +# Loop over data dimensions and create text annotations.
  98 +for i in range(count_matrix.shape[0]):
  99 + for j in range(count_matrix.shape[1]):
  100 + text = ax.text(j, i, int(count_matrix[i, j]),
  101 + ha="center", va="center", color="w")
  102 +
  103 +
  104 +ax.set_title("Count Matrix")
  105 +fig.tight_layout()
  106 +plt.savefig(OUTFILE, bbox_inches='tight')