diff --git a/scripts/plot/plot-character.py b/scripts/plot/plot-character.py new file mode 100644 index 0000000..bfb98d7 --- /dev/null +++ b/scripts/plot/plot-character.py @@ -0,0 +1,62 @@ + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import argparse +from os.path import isfile +from volia.data_io import read_features, read_lst + + +if __name__ == "__main__": + # Argparse + parser = argparse.ArgumentParser(description="Plot points with color for each character") + parser.add_argument("--features", type=str, help="features file path") + parser.add_argument("--utt2char", type=str, help="char2utt file path") + parser.add_argument("--sublist", type=str, default=None, help="white list of ids to take into account") + parser.add_argument("--outfile", default="out.pdf", type=str, help="") + parser.add_argument("--title", default="Example of plot", type=str, help="Specify the title") + args = parser.parse_args() + + # List of assertions + assert args.features, "Need to specify features option" + assert args.utt2char, "Need to specify char2utt option file" + assert isfile(args.features), "Features path should point to a file" + assert isfile(args.utt2char), "char2utt path should point to a file" + if args.sublist is not None: + assert isfile(args.sublist), "sublist path should point to a file" + + + id_to_features = read_features(args.features) + + ids = [] + if args.sublist is not None: + print("Using sublist") + list_ids = read_lst(args.sublist) + ids = [ key for key in list_ids.keys() ] + else: + ids = [ key for key in id_to_features.keys() ] + + utt2char = read_lst(args.utt2char) + + features = [ id_to_features[id_] for id_ in ids ] + features = np.vstack(features) + + characters_list = [ utt2char[id_][0] for id_ in ids ] + + features_T = features.transpose() + print("Number of characters: ", len(np.unique(characters_list))) + df = pd.DataFrame(dict( + x=features_T[0], + y=features_T[1], + character=characters_list)) + + groups = df.groupby('character') + + # Plot + fig, ax = plt.subplots() + + for character, group in groups: + p = ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=character) + ax.legend() + plt.savefig(args.outfile) + print("Your plot is saved well (no check of this affirmation)")