# plot-character.py (2.14 KB)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import argparse
from os.path import isfile
from volia.data_io import read_features, read_lst


def _parse_args():
    """Parse and validate the command-line arguments.

    Validation failures are reported through ``parser.error`` (usage message,
    exit status 2) instead of ``assert``, which is stripped under ``python -O``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Plot points with color for each character")
    parser.add_argument("--features", type=str, help="features file path")
    parser.add_argument("--utt2char", type=str, help="utt2char file path")
    parser.add_argument("--sublist", type=str, default=None, help="white list of ids to take into account")
    parser.add_argument("--outfile", default="out.pdf", type=str,
                        help="output figure path (format inferred from the extension)")
    parser.add_argument("--title", default="Example of plot", type=str, help="Specify the title")
    args = parser.parse_args()

    if not args.features:
        parser.error("Need to specify features option")
    if not args.utt2char:
        parser.error("Need to specify utt2char option file")
    if not isfile(args.features):
        parser.error("Features path should point to a file")
    if not isfile(args.utt2char):
        parser.error("utt2char path should point to a file")
    if args.sublist is not None and not isfile(args.sublist):
        parser.error("sublist path should point to a file")
    return args


def _check_ids_present(ids, mapping, what):
    """Exit with a readable message if some of ``ids`` are not keys of ``mapping``.

    Args:
        ids: ids that are about to be looked up in ``mapping``.
        mapping: the mapping the ids are looked up in.
        what: name of the source file, used only in the error message.

    Raises:
        SystemExit: if at least one id is missing; up to five are listed.
    """
    missing = [id_ for id_ in ids if id_ not in mapping]
    if missing:
        preview = ", ".join(str(id_) for id_ in missing[:5])
        raise SystemExit(f"{len(missing)} id(s) not found in the {what} file, e.g.: {preview}")


def main():
    """Plot the first two feature dimensions of each id, one color per character.

    Every id gets one point, and every character gets one legend entry. The
    figure is written to ``--outfile``.
    """
    args = _parse_args()

    id_to_features = read_features(args.features)

    if args.sublist is not None:
        print("Using sublist")
        # Only the keys of the sublist file matter: they act as a white list.
        ids = list(read_lst(args.sublist).keys())
    else:
        ids = list(id_to_features.keys())

    utt2char = read_lst(args.utt2char)

    # Fail early with clear messages instead of an opaque KeyError/ValueError.
    if not ids:
        raise SystemExit("No ids to plot")
    _check_ids_present(ids, id_to_features, "features")
    _check_ids_present(ids, utt2char, "utt2char")

    # np.vstack always returns at least a 2-D array: one row per id.
    features = np.vstack([id_to_features[id_] for id_ in ids])
    if features.shape[1] < 2:
        raise SystemExit("Features need at least two dimensions to be plotted")

    # The first field of each utt2char entry is used as the character label.
    # NOTE(review): presumably read_lst maps id -> list of fields; confirm.
    characters_list = [utt2char[id_][0] for id_ in ids]
    print("Number of characters: ", len(np.unique(characters_list)))

    # Only the first two feature dimensions are plotted; any others are ignored.
    df = pd.DataFrame(dict(
        x=features[:, 0],
        y=features[:, 1],
        character=characters_list))

    fig, ax = plt.subplots()
    for character, group in df.groupby('character'):
        ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=character)
    ax.set_title(args.title)
    ax.legend()
    # Save the figure we created (not whatever pyplot considers current), then free it.
    fig.savefig(args.outfile)
    plt.close(fig)
    print("Your plot is saved well (no check of this affirmation)")


if __name__ == "__main__":
    main()