Commit 8774b988e43462ca1a216c5a2514b688eada570c

Authored by Quillot Mathias
1 parent 81260862fd
Exists in master

New scatter plot module

Showing 1 changed file with 71 additions and 0 deletions Inline Diff

File was created 1 import argparse
2 import matplotlib.pyplot as plt
3 import numpy as np
4 import pandas as pd
5
6 from core.data import read_features, read_lst, read_labels
7 from utils import SubCommandRunner
8
9
def scatter_plot(features: str, labels: str, outfile: str):
    """Generate a simple 2D scatter plot with one point series per label.

    Mainly used to visualise data that has already been projected with
    t-SNE or a similar algorithm.

    Args:
        features (str): Features file, passed to ``read_features``. Only the
            first two components of each feature vector are plotted, even if
            the vectors have more dimensions.
        labels (str): Labels file, passed to ``read_labels``.
        outfile (str): Output file the figure is saved to.

    NOTE(review): the command line defaults ``--labels`` to ``None``; whether
    ``read_labels`` accepts ``None`` is not visible here -- confirm.
    """
    # Read from the `features` parameter rather than the script-level `args`
    # namespace, so the function also works when imported and called directly.
    id_to_features = read_features(features)
    ids = list(id_to_features)
    utt2label = read_labels(labels)

    # One row per id, in the same order as `ids`.
    matrix = np.vstack([id_to_features[id_] for id_ in ids])

    # Only the first element of each label entry is used as the class label.
    labels_list = [utt2label[id_][0] for id_ in ids]

    print("Number of labels: ", len(np.unique(labels_list)))
    df = pd.DataFrame({
        "x": matrix[:, 0],
        "y": matrix[:, 1],
        "label": labels_list,
    })

    # Plot: one marker-only series per label so the legend lists each label.
    fig, ax = plt.subplots()
    for label, group in df.groupby('label'):
        ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=label)
    ax.legend()

    fig.savefig(outfile)
    # pyplot keeps figures alive until they are closed explicitly; release
    # this one so repeated calls do not accumulate open figures.
    plt.close(fig)
    print("Your plot is saved well (no check of this affirmation)")
47
48
49
if __name__ == '__main__':

    # Main parser
    parser = argparse.ArgumentParser(description="")
    # A sub-command is mandatory: without one, `args.which` would not exist
    # and the dispatch at the end of this block would fail with an
    # AttributeError instead of a readable usage message.
    subparsers = parser.add_subparsers(title="action", dest="which")
    subparsers.required = True

    # with label
    parser_scatter = subparsers.add_parser("scatter")
    parser_scatter.add_argument("features", type=str, help="define the main features file")
    parser_scatter.add_argument("--labels", default=None, type=str, help="specify the labels of each utterance/element")
    parser_scatter.add_argument("--outfile", default="./out.pdf", type=str, help="Specify the output file (better in pdf)")
    parser_scatter.set_defaults(which="scatter")

    # Parse
    args = parser.parse_args()

    # Run commands: map each sub-command name to the function implementing it.
    runner = SubCommandRunner({
        "scatter": scatter_plot,
    })

    # The parsed options are handed to the runner as a dict; "which" is named
    # in `remove` because it only selects the command.
    runner.run(args.which, vars(args), remove="which")