# plot-count-matrix.py (3.24 KB)
'''
Plot a count matrix showing how many items of each class fall into each cluster.
'''
import argparse
import numpy as np
from data import read_file, index_by_id
from sklearn import preprocessing
import matplotlib.pyplot as plt


# TODO: get the list of characters
# TODO: list of clusters
parser = argparse.ArgumentParser(description="Plot count matrix")
parser.add_argument("clustering", type=str, help="clustering file")
parser.add_argument("classlst", type=str, help="List used for its classes.")
parser.add_argument("lst", type=str, help="list")
parser.add_argument("--outfile", type=str, default="out.pdf",
                    help="output file path")

args = parser.parse_args()
CLUSTERING, CLASS_LST, LST, OUTFILE = (
    args.clustering, args.classlst, args.lst, args.outfile)

# -- READ FILES
# The clustering file and the class list are each loaded and then indexed
# with index_by_id; the list of entries to plot is kept as loaded.
clustering = read_file(CLUSTERING)
clustering_ind = index_by_id(clustering)
class_lst = read_file(CLASS_LST)
class_lst_ind = index_by_id(class_lst)
lst = read_file(LST)


def _label_for(index, entry):
    '''
    Look up one entry of ``lst`` in an index built by ``index_by_id``.

    The index is keyed first by ``entry[0][0]`` and then by ``entry[0][3]``,
    and element ``[0][1]`` of the looked-up value is returned.
    NOTE(review): what those positions mean depends on ``read_file`` and
    ``index_by_id`` in the ``data`` module; confirm there.
    '''
    return index[entry[0][0]][entry[0][3]][0][1]


# -- GET CLASSES AND CLUSTERS
# Classes come from the class list, clusters from the clustering file,
# both in the order of the entries of ``lst``.
classes = np.asarray([_label_for(class_lst_ind, entry) for entry in lst])
clusters = np.asarray([_label_for(clustering_ind, entry) for entry in lst])

def generate_count_matrix(classes, clusters):
    '''
    Count how many items of each class fall into each cluster.

    Rows of the matrix are cluster ids and columns are classes:
    ``counts_matrix[c, j]`` is the number of positions ``i`` where
    ``clusters[i]`` converts to the integer ``c`` and ``classes[i]`` equals
    ``all_classes[j]``.

    Parameters
    ----------
    classes : array-like
        Class label of each item (any values ``np.unique`` can sort).
    clusters : array-like
        Cluster id of each item, same length as ``classes``. Each id must be
        convertible to a non-negative integer (e.g. ``3`` or ``"3"``) because
        it is used directly as a row index.

    Returns
    -------
    counts_matrix : numpy.ndarray of float
        Shape ``(max_cluster_id + 1, len(all_classes))``. Ids that never
        occur get an all-zero row, so row ``c`` always belongs to cluster
        ``c``. Shape is ``(0, 0)`` for empty input.
    all_classes : numpy.ndarray
        Sorted unique class labels, in the column order of ``counts_matrix``.
    all_clusters : numpy.ndarray
        Sorted unique cluster ids, as given (not converted to int).

    Raises
    ------
    ValueError
        If the inputs differ in length, or if a cluster id is negative or
        cannot be converted to an integer.
    '''
    classes = np.ravel(np.asarray(classes))
    clusters = np.ravel(np.asarray(clusters))
    if classes.shape != clusters.shape:
        raise ValueError(
            "classes and clusters must have the same length "
            "(got {} and {})".format(classes.size, clusters.size))

    # np.unique returns the sorted labels plus, for every item, the column
    # index of its label, replacing a per-cell LabelEncoder.transform call.
    all_classes, class_idx = np.unique(classes, return_inverse=True)
    all_clusters = np.unique(clusters)

    # Cluster ids are used verbatim as row indices; they may arrive as
    # strings, so convert them (this raises ValueError if they are not ints).
    cluster_idx = clusters.astype(int)
    if cluster_idx.size and cluster_idx.min() < 0:
        # A negative id would silently index from the end of the matrix.
        raise ValueError("cluster ids must be non-negative integers")

    n_rows = int(cluster_idx.max()) + 1 if cluster_idx.size else 0
    counts_matrix = np.zeros((n_rows, len(all_classes)))
    # np.add.at accumulates correctly when a (row, col) pair repeats, unlike
    # fancy-index assignment with +=.
    np.add.at(counts_matrix, (cluster_idx, np.ravel(class_idx)), 1)
    return (counts_matrix, all_classes, all_clusters)

count_matrix, all_classes, all_clusters = generate_count_matrix(classes, clusters)

# Rows of count_matrix are indexed by cluster id (0..max id). When ids are not
# contiguous there are more rows than distinct clusters, so the y axis is
# sized and labelled from the matrix itself rather than from all_clusters.
n_rows, n_cols = count_matrix.shape

fig, ax = plt.subplots()
fig.set_size_inches(10, n_rows + 1)
im = ax.imshow(count_matrix)

ax.set_xticks(np.arange(n_cols))
ax.set_yticks(np.arange(n_rows))

ax.set_xticklabels(all_classes)
ax.set_yticklabels(list(range(n_rows)))

fig.colorbar(im, ax=ax)

plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# Annotate every cell with its count. Cells in the bright upper half of the
# colormap get black text, the rest white, so every number stays readable.
for i in range(n_rows):
    for j in range(n_cols):
        value = count_matrix[i, j]
        ax.text(j, i, int(value),
                ha="center", va="center",
                color="k" if im.norm(value) > 0.5 else "w")


ax.set_title("Count Matrix")
fig.tight_layout()
fig.savefig(OUTFILE, bbox_inches='tight')
plt.close(fig)