Blame view

scripts/plot/plot-character.py 2.14 KB
44060889b   Mathias   allow the user to...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
  
  import matplotlib.pyplot as plt
  import numpy as np
  import pandas as pd
  import argparse
  from os.path import isfile
  from volia.data_io import read_features, read_lst
  
  
  if __name__ == "__main__":
      # Script entry point: scatter-plot the first two feature dimensions of
      # every selected id, one legend entry per character, and save the
      # figure to --outfile.

      # Argparse
      parser = argparse.ArgumentParser(description="Plot points with color for each character")
      parser.add_argument("--features", type=str, help="features file path")
      parser.add_argument("--utt2char", type=str, help="utt2char file path")
      parser.add_argument("--sublist", type=str, default=None, help="white list of ids to take into account")
      parser.add_argument("--outfile", default="out.pdf", type=str, help="")
      parser.add_argument("--title", default="Example of plot", type=str, help="Specify the title")
      args = parser.parse_args()

      # Validate inputs with parser.error() instead of assert: asserts are
      # stripped when Python runs with -O, and parser.error() prints the usage
      # message before exiting with a non-zero status.
      if not args.features:
          parser.error("Need to specify features option")
      if not args.utt2char:
          parser.error("Need to specify utt2char option file")
      if not isfile(args.features):
          parser.error("Features path should point to a file")
      if not isfile(args.utt2char):
          parser.error("utt2char path should point to a file")
      if args.sublist is not None and not isfile(args.sublist):
          parser.error("sublist path should point to a file")

      # NOTE(review): read_features / read_lst are assumed to return
      # dict-like objects keyed by id (they are used with .keys() and
      # indexing below) -- confirm against volia.data_io.
      id_to_features = read_features(args.features)

      # Restrict the plot to the white list when one is given; otherwise use
      # every id present in the features file.
      if args.sublist is not None:
          print("Using sublist")
          ids = list(read_lst(args.sublist).keys())
      else:
          ids = list(id_to_features.keys())

      # np.vstack() raises an opaque ValueError on an empty sequence, so
      # report the empty selection explicitly.
      if not ids:
          parser.error("No ids to plot")

      utt2char = read_lst(args.utt2char)

      # One row of features per selected id, in the same order as `ids`.
      features = np.vstack([id_to_features[id_] for id_ in ids])

      # Only the first element of each utt2char entry is used as the label.
      characters_list = [utt2char[id_][0] for id_ in ids]

      # Transpose so that row 0 / row 1 hold the x / y coordinates.
      features_T = features.transpose()
      print("Number of characters: ", len(np.unique(characters_list)))
      df = pd.DataFrame(dict(
          x=features_T[0],
          y=features_T[1],
          character=characters_list))

      groups = df.groupby('character')

      # Plot
      fig, ax = plt.subplots()

      # One plot call per character so that each gets its own color and
      # legend entry.
      for character, group in groups:
          ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=character)

      # Apply the user-supplied title (previously parsed but never used).
      ax.set_title(args.title)
      ax.legend()
      plt.savefig(args.outfile)
      print("Your plot is saved well (no check of this affirmation)")