Commit 8774b988e43462ca1a216c5a2514b688eada570c
1 parent
81260862fd
Exists in
master
New scatter plot module
Showing 1 changed file with 71 additions and 0 deletions Inline Diff
volia/plot.py
File was created | 1 | import argparse | |
2 | import matplotlib.pyplot as plt | ||
3 | import numpy as np | ||
4 | import pandas as pd | ||
5 | |||
6 | from core.data import read_features, read_lst, read_labels | ||
7 | from utils import SubCommandRunner | ||
8 | |||
9 | |||
10 | def scatter_plot(features: str, labels: str, outfile: str): | ||
11 | """Generate a simple scatter plot. Mainly used for | ||
12 | data visualisation processed with tsne or algorithm like | ||
13 | this later. | ||
14 | |||
15 | Args: | ||
16 | features (str): Features file in 2d or 3d | ||
17 | labels (str): Labels file | ||
18 | outfile (str) : output file | ||
19 | """ | ||
20 | |||
21 | id_to_features = read_features(args.features) | ||
22 | ids = [ key for key in id_to_features.keys() ] | ||
23 | utt2label = read_labels(labels) | ||
24 | |||
25 | features = [ id_to_features[id_] for id_ in ids ] | ||
26 | features = np.vstack(features) | ||
27 | |||
28 | labels_list = [ utt2label[id_][0] for id_ in ids ] | ||
29 | |||
30 | features_T = features.transpose() | ||
31 | print("Number of labels: ", len(np.unique(labels_list))) | ||
32 | df = pd.DataFrame(dict( | ||
33 | x=features_T[0], | ||
34 | y=features_T[1], | ||
35 | label=labels_list)) | ||
36 | |||
37 | groups = df.groupby('label') | ||
38 | |||
39 | # Plot | ||
40 | fig, ax = plt.subplots() | ||
41 | |||
42 | for label, group in groups: | ||
43 | p = ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=label) | ||
44 | ax.legend() | ||
45 | plt.savefig(outfile) | ||
46 | print("Your plot is saved well (no check of this affirmation)") | ||
47 | |||
48 | |||
49 | |||
50 | if __name__ == '__main__': | ||
51 | |||
52 | # Main parser | ||
53 | parser = argparse.ArgumentParser(description="") | ||
54 | subparsers = parser.add_subparsers(title="action") | ||
55 | |||
56 | # with label | ||
57 | parser_scatter = subparsers.add_parser("scatter") | ||
58 | parser_scatter.add_argument("features", type=str, help="define the main features file") | ||
59 | parser_scatter.add_argument("--labels", default=None, type=str, help="specify the labels of each utterance/element") | ||
60 | parser_scatter.add_argument("--outfile", default="./out.pdf", type=str, help="Specify the output file (better in pdf)") | ||
61 | parser_scatter.set_defaults(which="scatter") | ||
62 | |||
63 | # Parse | ||
64 | args = parser.parse_args() | ||
65 | |||
66 | # Run commands | ||
67 | runner = SubCommandRunner({ | ||
68 | "scatter" : scatter_plot | ||
69 | }) | ||
70 | |||
71 | runner.run(args.which, args.__dict__, remove="which") |