Commit 40650f20d7dedad02fc01bb39966cb01f045f91a

Authored by quillotm
1 parent d87f627177
Exists in master

Preparing interactive scatter plot in order to generate cool visualizations.

Showing 1 changed file with 22 additions and 1 deletion Inline Diff

1 import argparse 1 import argparse
2 import matplotlib.pyplot as plt 2 import matplotlib.pyplot as plt
3 import numpy as np 3 import numpy as np
4 import pandas as pd 4 import pandas as pd
5 5
6 from core.data import read_features, read_lst, read_labels 6 from core.data import read_features, read_lst, read_labels
7 from utils import SubCommandRunner 7 from utils import SubCommandRunner
8 8
9 9
def scatter_plot(features: str, labels: str, outfile: str):
    """Generate a simple 2D scatter plot with one series per label.

    Mainly used to visualise data that was already processed with t-SNE
    or a similar algorithm.

    Args:
        features (str): Features file in 2d or 3d. Only the first two
            columns of the stacked feature matrix are plotted.
        labels (str): Labels file.
        outfile (str): Output file the figure is saved to.
    """
    # Read from the ``features`` parameter rather than the command-line
    # namespace, which only exists when the module is run as a script.
    id_to_features = read_features(features)
    # Fix the id order once so feature rows and labels stay aligned.
    ids = list(id_to_features)
    utt2label = read_labels(labels)

    feature_matrix = np.vstack([id_to_features[id_] for id_ in ids])

    # Only the first entry of each label record is used for grouping.
    labels_list = [utt2label[id_][0] for id_ in ids]

    # Transpose so that row 0 holds the x values and row 1 the y values.
    features_T = feature_matrix.transpose()
    print("Number of labels: ", len(np.unique(labels_list)))
    df = pd.DataFrame(dict(
        x=features_T[0],
        y=features_T[1],
        label=labels_list))

    # Plot: one marker-only series per label so each gets a legend entry.
    fig, ax = plt.subplots()
    for label, group in df.groupby('label'):
        ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=label)
    ax.legend()
    plt.savefig(outfile)
    print("Your plot is saved well (no check of this affirmation)")
47 47
48 48
def interactive_scatter_plot(features: str, labels: str, outdir: str):
    """Placeholder for the interactive 3D scatter plot.

    Intended to visualise data processed with tsne or a similar
    algorithm, with the visualization generated as Web files.
    Not implemented yet: the function currently does nothing.

    Args:
        features (str): Features file in 2d or 3d
        labels (str): Labels file
        outdir (str): output directory where Web files are saved
    """
    return None
60
61
def build_parser():
    """Build the command-line parser with one sub-command per plot type.

    Each sub-parser stores its own name in ``which`` so the caller can
    dispatch on it.

    Returns:
        argparse.ArgumentParser: the configured main parser.
    """
    # Main parser
    parser = argparse.ArgumentParser(description="")
    subparsers = parser.add_subparsers(title="action")

    # scatter with labels
    parser_scatter = subparsers.add_parser("scatter")
    parser_scatter.add_argument("features", type=str, help="define the main features file")
    parser_scatter.add_argument("--labels", default=None, type=str, help="specify the labels of each utterance/element")
    parser_scatter.add_argument("--outfile", default="./out.pdf", type=str, help="Specify the output file (better in pdf)")
    parser_scatter.set_defaults(which="scatter")

    # interactive scatter
    # Every option below must be registered on the interactive sub-parser:
    # adding "--labels" a second time to the "scatter" sub-parser makes
    # argparse raise a conflicting-option error at startup.
    parser_interactive_scatter = subparsers.add_parser("interactive_scatter")
    parser_interactive_scatter.add_argument("features", type=str, help="features files with only 3D will be converted into csv")
    parser_interactive_scatter.add_argument("--labels", default=None, type=str, help="Specify the labels of each utterance/element")
    parser_interactive_scatter.add_argument("--outdir", default=".out", type=str, help="output directory where static web and data files are saved.")
    parser_interactive_scatter.set_defaults(which="interactive_scatter")

    return parser


if __name__ == '__main__':

    # Parse
    parser = build_parser()
    args = parser.parse_args()

    # Without a sub-command no ``which`` default is set; report a usage
    # error instead of failing with an AttributeError below.
    if not hasattr(args, "which"):
        parser.error("an action is required: scatter or interactive_scatter")

    # Run commands
    runner = SubCommandRunner({
        "scatter" : scatter_plot,
        "interactive_scatter" : interactive_scatter_plot
    })

    # "which" is only used for dispatch and is removed from the arguments.
    runner.run(args.which, args.__dict__, remove="which")