Commit 70fadae5790d879b6539f10698fb6f897851d080

Authored by Quillot Mathias
1 parent 62fc82e59a
Exists in master

Simple comment change

Showing 1 changed file with 1 additions and 1 deletions Inline Diff

1 import argparse 1 import argparse
2 import matplotlib.pyplot as plt 2 import matplotlib.pyplot as plt
3 import numpy as np 3 import numpy as np
4 import pandas as pd 4 import pandas as pd
5 5
6 from core.data import read_features, read_lst, read_labels 6 from core.data import read_features, read_lst, read_labels
7 from utils import SubCommandRunner 7 from utils import SubCommandRunner
8 8
9 9
10 def scatter_plot(features: str, labels: str, outfile: str): 10 def scatter_plot(features: str, labels: str, outfile: str):
11 """Generate a simple scatter plot. Mainly used for 11 """Generate a simple scatter plot. Mainly used for
12 data visualisation processed with tsne or algorithm like 12 data visualisation processed with tsne or algorithm like
13 this later. 13 this later.
14 14
15 Args: 15 Args:
16 features (str): Features file in 2d or 3d 16 features (str): Features file in 2d or 3d
17 labels (str): Labels file 17 labels (str): Labels file
18 outfile (str) : output file 18 outfile (str) : output file
19 """ 19 """
20 20
21 id_to_features = read_features(args.features) 21 id_to_features = read_features(args.features)
22 ids = [ key for key in id_to_features.keys() ] 22 ids = [ key for key in id_to_features.keys() ]
23 utt2label = read_labels(labels) 23 utt2label = read_labels(labels)
24 24
25 features = [ id_to_features[id_] for id_ in ids ] 25 features = [ id_to_features[id_] for id_ in ids ]
26 features = np.vstack(features) 26 features = np.vstack(features)
27 27
28 labels_list = [ utt2label[id_][0] for id_ in ids ] 28 labels_list = [ utt2label[id_][0] for id_ in ids ]
29 29
30 features_T = features.transpose() 30 features_T = features.transpose()
31 print("Number of labels: ", len(np.unique(labels_list))) 31 print("Number of labels: ", len(np.unique(labels_list)))
32 df = pd.DataFrame(dict( 32 df = pd.DataFrame(dict(
33 x=features_T[0], 33 x=features_T[0],
34 y=features_T[1], 34 y=features_T[1],
35 label=labels_list)) 35 label=labels_list))
36 36
37 groups = df.groupby('label') 37 groups = df.groupby('label')
38 38
39 # Plot 39 # Plot
40 fig, ax = plt.subplots() 40 fig, ax = plt.subplots()
41 41
42 for label, group in groups: 42 for label, group in groups:
43 p = ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=label) 43 p = ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=label)
44 ax.legend() 44 ax.legend()
45 plt.savefig(outfile) 45 plt.savefig(outfile)
46 print("Your plot is saved well (no check of this affirmation)") 46 print("Your plot is saved well (no check of this affirmation)")
47 47
48 48
49 49
50 if __name__ == '__main__': 50 if __name__ == '__main__':
51 51
52 # Main parser 52 # Main parser
53 parser = argparse.ArgumentParser(description="") 53 parser = argparse.ArgumentParser(description="")
54 subparsers = parser.add_subparsers(title="action") 54 subparsers = parser.add_subparsers(title="action")
55 55
56 # with label 56 # scatter with labels
57 parser_scatter = subparsers.add_parser("scatter") 57 parser_scatter = subparsers.add_parser("scatter")
58 parser_scatter.add_argument("features", type=str, help="define the main features file") 58 parser_scatter.add_argument("features", type=str, help="define the main features file")
59 parser_scatter.add_argument("--labels", default=None, type=str, help="specify the labels of each utterance/element") 59 parser_scatter.add_argument("--labels", default=None, type=str, help="specify the labels of each utterance/element")
60 parser_scatter.add_argument("--outfile", default="./out.pdf", type=str, help="Specify the output file (better in pdf)") 60 parser_scatter.add_argument("--outfile", default="./out.pdf", type=str, help="Specify the output file (better in pdf)")
61 parser_scatter.set_defaults(which="scatter") 61 parser_scatter.set_defaults(which="scatter")
62 62
63 # Parse 63 # Parse
64 args = parser.parse_args() 64 args = parser.parse_args()
65 65
66 # Run commands 66 # Run commands
67 runner = SubCommandRunner({ 67 runner = SubCommandRunner({
68 "scatter" : scatter_plot 68 "scatter" : scatter_plot
69 }) 69 })
70 70
71 runner.run(args.which, args.__dict__, remove="which") 71 runner.run(args.which, args.__dict__, remove="which")