Commit 40650f20d7dedad02fc01bb39966cb01f045f91a
1 parent
d87f627177
Exists in
master
Preparing interactive scatter plot in order to generate cool visualizations.
Showing 1 changed file with 22 additions and 1 deletions Inline Diff
volia/plot.py
1 | import argparse | 1 | import argparse |
2 | import matplotlib.pyplot as plt | 2 | import matplotlib.pyplot as plt |
3 | import numpy as np | 3 | import numpy as np |
4 | import pandas as pd | 4 | import pandas as pd |
5 | 5 | ||
6 | from core.data import read_features, read_lst, read_labels | 6 | from core.data import read_features, read_lst, read_labels |
7 | from utils import SubCommandRunner | 7 | from utils import SubCommandRunner |
8 | 8 | ||
9 | 9 | ||
10 | def scatter_plot(features: str, labels: str, outfile: str): | 10 | def scatter_plot(features: str, labels: str, outfile: str): |
11 | """Generate a simple scatter plot. Mainly used for | 11 | """Generate a simple scatter plot. Mainly used for |
12 | data visualisation processed with tsne or algorithm like | 12 | data visualisation processed with tsne or algorithm like |
13 | this later. | 13 | this later. |
14 | 14 | ||
15 | Args: | 15 | Args: |
16 | features (str): Features file in 2d or 3d | 16 | features (str): Features file in 2d or 3d |
17 | labels (str): Labels file | 17 | labels (str): Labels file |
18 | outfile (str) : output file | 18 | outfile (str) : output file |
19 | """ | 19 | """ |
20 | 20 | ||
21 | id_to_features = read_features(args.features) | 21 | id_to_features = read_features(args.features) |
22 | ids = [ key for key in id_to_features.keys() ] | 22 | ids = [ key for key in id_to_features.keys() ] |
23 | utt2label = read_labels(labels) | 23 | utt2label = read_labels(labels) |
24 | 24 | ||
25 | features = [ id_to_features[id_] for id_ in ids ] | 25 | features = [ id_to_features[id_] for id_ in ids ] |
26 | features = np.vstack(features) | 26 | features = np.vstack(features) |
27 | 27 | ||
28 | labels_list = [ utt2label[id_][0] for id_ in ids ] | 28 | labels_list = [ utt2label[id_][0] for id_ in ids ] |
29 | 29 | ||
30 | features_T = features.transpose() | 30 | features_T = features.transpose() |
31 | print("Number of labels: ", len(np.unique(labels_list))) | 31 | print("Number of labels: ", len(np.unique(labels_list))) |
32 | df = pd.DataFrame(dict( | 32 | df = pd.DataFrame(dict( |
33 | x=features_T[0], | 33 | x=features_T[0], |
34 | y=features_T[1], | 34 | y=features_T[1], |
35 | label=labels_list)) | 35 | label=labels_list)) |
36 | 36 | ||
37 | groups = df.groupby('label') | 37 | groups = df.groupby('label') |
38 | 38 | ||
39 | # Plot | 39 | # Plot |
40 | fig, ax = plt.subplots() | 40 | fig, ax = plt.subplots() |
41 | 41 | ||
42 | for label, group in groups: | 42 | for label, group in groups: |
43 | p = ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=label) | 43 | p = ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=label) |
44 | ax.legend() | 44 | ax.legend() |
45 | plt.savefig(outfile) | 45 | plt.savefig(outfile) |
46 | print("Your plot is saved well (no check of this affirmation)") | 46 | print("Your plot is saved well (no check of this affirmation)") |
47 | 47 | ||
48 | 48 | ||
49 | def interactive_scatter_plot(features: str, labels: str, outdir: str): | ||
50 | """Generate an interactive scatter plot in 3D. Mainly used for | ||
51 | data visualisation processed with tsne or algorithm like | ||
52 | this later. This visualization is generated in Web files. | ||
49 | 53 | ||
54 | Args: | ||
55 | features (str): Features file in 2d or 3d | ||
56 | labels (str): Labels file | ||
57 | outdir (str) : output directory where Web files are saved | ||
58 | """ | ||
59 | pass | ||
60 | |||
61 | |||
50 | if __name__ == '__main__': | 62 | if __name__ == '__main__': |
51 | 63 | ||
52 | # Main parser | 64 | # Main parser |
53 | parser = argparse.ArgumentParser(description="") | 65 | parser = argparse.ArgumentParser(description="") |
54 | subparsers = parser.add_subparsers(title="action") | 66 | subparsers = parser.add_subparsers(title="action") |
55 | 67 | ||
56 | # scatter with labels | 68 | # scatter with labels |
57 | parser_scatter = subparsers.add_parser("scatter") | 69 | parser_scatter = subparsers.add_parser("scatter") |
58 | parser_scatter.add_argument("features", type=str, help="define the main features file") | 70 | parser_scatter.add_argument("features", type=str, help="define the main features file") |
59 | parser_scatter.add_argument("--labels", default=None, type=str, help="specify the labels of each utterance/element") | 71 | parser_scatter.add_argument("--labels", default=None, type=str, help="specify the labels of each utterance/element") |
60 | parser_scatter.add_argument("--outfile", default="./out.pdf", type=str, help="Specify the output file (better in pdf)") | 72 | parser_scatter.add_argument("--outfile", default="./out.pdf", type=str, help="Specify the output file (better in pdf)") |
61 | parser_scatter.set_defaults(which="scatter") | 73 | parser_scatter.set_defaults(which="scatter") |
62 | 74 | ||
75 | |||
76 | # interactive scatter | ||
77 | parser_interactive_scatter = subparsers.add_parser("interactive_scatter") | ||
78 | parser_interactive_scatter.add_argument("features", type=str, help="features files with only 3D will be converted into csv") | ||
79 | parser_scatter.add_argument("--labels", default=None, type=str, help="Specify the labels of each utterance/element") | ||
80 | parser_scatter.add_argument("--outdir", default=".out", type=str, help="output directoy where static web and data files are saved.") | ||
81 | parser_scatter.set_defaults(which="interactive_scatter") | ||
82 | |||
63 | # Parse | 83 | # Parse |
64 | args = parser.parse_args() | 84 | args = parser.parse_args() |
65 | 85 | ||
66 | # Run commands | 86 | # Run commands |
67 | runner = SubCommandRunner({ | 87 | runner = SubCommandRunner({ |
68 | "scatter" : scatter_plot | 88 | "scatter" : scatter_plot, |
89 | "interactive_scatter" : interactive_scatter_plot | ||
69 | }) | 90 | }) |
70 | 91 | ||
71 | runner.run(args.which, args.__dict__, remove="which") | 92 | runner.run(args.which, args.__dict__, remove="which") |