Blame view
bin/plot-count-matrix.py
3.24 KB
b7530e269 allow you to plot... |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
"""Plot a count matrix of classes per cluster.

For every entry of the LST file, its class is looked up in CLASSLST and its
cluster in CLUSTERING; the script then draws a heatmap whose rows are
clusters and whose columns are classes, and saves it to --outfile.
"""
import argparse
import numpy as np
from data import read_file, index_by_id
from sklearn import preprocessing
import matplotlib.pyplot as plt

# TODO: get the list of characters
# TODO: get the list of clusters


def generate_count_matrix(classes, clusters):
    """Count how many times each class occurs in each cluster.

    Rows are clusters and columns are classes: cell ``[i, j]`` holds the
    number of items of class ``all_classes[j]`` whose cluster id is ``i``.
    Rows are indexed by ``int(cluster id)``, so the matrix has
    ``max(cluster id) + 1`` rows; ids that never occur yield all-zero rows.

    :param classes: class label of every item (same length as ``clusters``).
    :param clusters: cluster id of every item; each id must be convertible
        to ``int`` (e.g. ``"3"`` or ``3``).
    :returns: ``(counts_matrix, all_classes, all_clusters)`` where
        ``counts_matrix`` is a float array and ``all_classes`` /
        ``all_clusters`` are the sorted unique values of ``classes`` /
        ``clusters`` (sorted as stored, i.e. lexicographically for strings).
    """
    classes = np.asarray(classes)
    clusters = np.asarray(clusters)
    all_classes = np.unique(classes)
    all_clusters = np.unique(clusters)

    # The encoder maps each class label to its index in the sorted unique
    # classes, which is exactly its column in the matrix.
    le = preprocessing.LabelEncoder()
    le.fit(all_classes)
    class_ids = le.transform(classes)

    # Plain ``int`` instead of ``np.int``: the alias was removed in NumPy 1.24.
    cluster_ids = clusters.astype(int)

    # Guard against an empty input, where ``max()`` would raise.
    n_rows = int(cluster_ids.max()) + 1 if cluster_ids.size else 0
    counts_matrix = np.zeros((n_rows, len(all_classes)))
    # ``np.add.at`` accumulates repeated (row, column) pairs, which plain
    # fancy-index assignment would overwrite instead of counting.
    np.add.at(counts_matrix, (cluster_ids, class_ids), 1)
    return (counts_matrix, all_classes, all_clusters)


parser = argparse.ArgumentParser(description="Plot count matrix")
parser.add_argument("clustering", type=str, help="clustering file")
parser.add_argument("classlst", type=str, help="List used for its classes.")
parser.add_argument("lst", type=str, help="list")
parser.add_argument("--outfile", type=str, default="out.pdf",
                    help="output file path")
args = parser.parse_args()

CLUSTERING = args.clustering
CLASS_LST = args.classlst
LST = args.lst
OUTFILE = args.outfile

# -- READ FILES
clustering = read_file(CLUSTERING)
clustering_ind = index_by_id(clustering)
class_lst = read_file(CLASS_LST)
class_lst_ind = index_by_id(class_lst)
lst = read_file(LST)

# -- GET CLASSES AND CLUSTERS
# NOTE(review): the meaning of x[0][0] / x[0][3] and of the trailing [0][1]
# depends on the record layout produced by data.read_file / index_by_id,
# which is not visible here — confirm against that module.
classes = np.asarray([class_lst_ind[x[0][0]][x[0][3]][0][1] for x in lst])
clusters = np.asarray([clustering_ind[x[0][0]][x[0][3]][0][1] for x in lst])

count_matrix, all_classes, all_clusters = generate_count_matrix(classes, clusters)

fig, ax = plt.subplots()
# One inch per drawn row (plus margin), so cells stay readable even when
# cluster ids are sparse and the matrix has more rows than distinct clusters.
fig.set_size_inches(10, count_matrix.shape[0] + 1)
im = ax.imshow(count_matrix)

ax.set_xticks(np.arange(len(all_classes)))
# NOTE(review): matrix rows are indexed by integer cluster id, so if ids are
# not contiguous from 0 the rows past len(all_clusters) - 1 get no y tick.
ax.set_yticks(np.arange(len(all_clusters)))
ax.set_xticklabels(all_classes)
87c56690a repair error from... |
91 |
# Each y tick is labelled with its row index; since matrix rows are indexed
# by integer cluster id, that index is the cluster id itself.
ax.set_yticklabels(list(range(len(all_clusters))))
b7530e269 allow you to plot... |
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
fig.colorbar(im, ax=ax)

# Rotate the class labels and anchor them at their right end so that long
# names do not overlap each other.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
         rotation_mode="anchor")

# Annotate every cell with its integer count. Black text is used on the
# bright half of the colormap and white on the dark half so every number
# stays legible (white on the brightest cells was unreadable).
for (i, j), value in np.ndenumerate(count_matrix):
    color = "k" if im.norm(value) > 0.5 else "w"
    ax.text(j, i, int(value), ha="center", va="center", color=color)

ax.set_title("Count Matrix")
fig.tight_layout()
fig.savefig(OUTFILE, bbox_inches='tight')