Commit 44060889b7cb301838ffdb16f60110216e857ba3

Authored by Mathias
1 parent e403ed5fb6
Exists in master

allow the user to plot points on space (2d or 3d), coloring these with respect t…

…o the character label

Showing 1 changed file with 62 additions and 0 deletions Side-by-side Diff

scripts/plot/plot-character.py
  1 +
  2 +import matplotlib.pyplot as plt
  3 +import numpy as np
  4 +import pandas as pd
  5 +import argparse
  6 +from os.path import isfile
  7 +from volia.data_io import read_features, read_lst
  8 +
  9 +
  10 +if __name__ == "__main__":
  11 + # Argparse
  12 + parser = argparse.ArgumentParser(description="Plot points with color for each character")
  13 + parser.add_argument("--features", type=str, help="features file path")
  14 + parser.add_argument("--utt2char", type=str, help="char2utt file path")
  15 + parser.add_argument("--sublist", type=str, default=None, help="white list of ids to take into account")
  16 + parser.add_argument("--outfile", default="out.pdf", type=str, help="")
  17 + parser.add_argument("--title", default="Example of plot", type=str, help="Specify the title")
  18 + args = parser.parse_args()
  19 +
  20 + # List of assertions
  21 + assert args.features, "Need to specify features option"
  22 + assert args.utt2char, "Need to specify char2utt option file"
  23 + assert isfile(args.features), "Features path should point to a file"
  24 + assert isfile(args.utt2char), "char2utt path should point to a file"
  25 + if args.sublist is not None:
  26 + assert isfile(args.sublist), "sublist path should point to a file"
  27 +
  28 +
  29 + id_to_features = read_features(args.features)
  30 +
  31 + ids = []
  32 + if args.sublist is not None:
  33 + print("Using sublist")
  34 + list_ids = read_lst(args.sublist)
  35 + ids = [ key for key in list_ids.keys() ]
  36 + else:
  37 + ids = [ key for key in id_to_features.keys() ]
  38 +
  39 + utt2char = read_lst(args.utt2char)
  40 +
  41 + features = [ id_to_features[id_] for id_ in ids ]
  42 + features = np.vstack(features)
  43 +
  44 + characters_list = [ utt2char[id_][0] for id_ in ids ]
  45 +
  46 + features_T = features.transpose()
  47 + print("Number of characters: ", len(np.unique(characters_list)))
  48 + df = pd.DataFrame(dict(
  49 + x=features_T[0],
  50 + y=features_T[1],
  51 + character=characters_list))
  52 +
  53 + groups = df.groupby('character')
  54 +
  55 + # Plot
  56 + fig, ax = plt.subplots()
  57 +
  58 + for character, group in groups:
  59 + p = ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=character)
  60 + ax.legend()
  61 + plt.savefig(args.outfile)
  62 + print("Your plot is saved well (no check of this affirmation)")