Commit 70fadae5790d879b6539f10698fb6f897851d080
1 parent
62fc82e59a
Exists in
master
Simple comment change
Showing 1 changed file with 1 addition and 1 deletion Inline Diff
volia/plot.py
1 | import argparse | 1 | import argparse |
2 | import matplotlib.pyplot as plt | 2 | import matplotlib.pyplot as plt |
3 | import numpy as np | 3 | import numpy as np |
4 | import pandas as pd | 4 | import pandas as pd |
5 | 5 | ||
6 | from core.data import read_features, read_lst, read_labels | 6 | from core.data import read_features, read_lst, read_labels |
7 | from utils import SubCommandRunner | 7 | from utils import SubCommandRunner |
8 | 8 | ||
9 | 9 | ||
def scatter_plot(features: str, labels: str, outfile: str):
    """Generate a simple scatter plot of labelled features.

    Mainly intended for visualising data that has already been reduced
    to a few dimensions (for instance with t-SNE or a similar algorithm).
    Only the first two feature dimensions are plotted, and every distinct
    label is drawn as its own series with a legend entry.

    Args:
        features (str): Features file in 2d or 3d, passed to
            ``read_features``.
        labels (str): Labels file, passed to ``read_labels``.
        outfile (str): Output file the figure is saved to.
    """
    # Use the ``features`` parameter, not the module-level ``args`` object:
    # ``args`` only exists when this file is executed as a script, so the
    # function would raise NameError when imported and called directly.
    id_to_features = read_features(features)
    ids = list(id_to_features.keys())
    # NOTE(review): the CLI defaults ``--labels`` to None; whether
    # ``read_labels`` accepts None is not visible from here -- confirm.
    utt2label = read_labels(labels)

    # Stack one row per id, in the same order as ``ids``, so rows and
    # labels stay aligned.
    feature_matrix = np.vstack([id_to_features[id_] for id_ in ids])

    # Only the first element of each label entry is used.
    labels_list = [utt2label[id_][0] for id_ in ids]

    print("Number of labels: ", len(np.unique(labels_list)))
    df = pd.DataFrame(dict(
        x=feature_matrix[:, 0],
        y=feature_matrix[:, 1],
        label=labels_list))

    # Plot
    fig, ax = plt.subplots()
    try:
        for label, group in df.groupby('label'):
            # Markers only (no connecting line), one series per label.
            ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=label)
        ax.legend()
        fig.savefig(outfile)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    print("Your plot is saved well (no check of this affirmation)")
47 | 47 | ||
48 | 48 | ||
49 | 49 | ||
if __name__ == '__main__':

    # Main parser
    parser = argparse.ArgumentParser(description="")
    subparsers = parser.add_subparsers(title="action")

    # scatter with labels
    parser_scatter = subparsers.add_parser("scatter")
    parser_scatter.add_argument("features", type=str, help="define the main features file")
    parser_scatter.add_argument("--labels", default=None, type=str, help="specify the labels of each utterance/element")
    parser_scatter.add_argument("--outfile", default="./out.pdf", type=str, help="Specify the output file (better in pdf)")
    # "which" records the selected sub-command so it can be dispatched below.
    parser_scatter.set_defaults(which="scatter")

    # Parse
    args = parser.parse_args()

    # "which" is only set by a sub-parser's defaults; when no action is
    # given on the command line it is missing, so fail with a usage error
    # instead of an AttributeError further down.
    if not hasattr(args, "which"):
        parser.error("missing action (choose from: scatter)")

    # Run commands
    runner = SubCommandRunner({
        "scatter" : scatter_plot
    })

    # The parsed arguments are handed over as a dict; "which" is named as
    # the entry to remove since it is not a parameter of the sub-command.
    runner.run(args.which, args.__dict__, remove="which")