plot.py
3.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import argparse
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from core.data import read_features, read_lst, read_labels
from utils import SubCommandRunner
def scatter_plot(features: str, labels: str, outfile: str):
    """Generate a simple 2D scatter plot with one point group per label.

    Mainly used to visualise data that has already been processed with
    t-SNE or a similar algorithm. Only the first two feature columns are
    plotted, even when the features have more dimensions.

    Args:
        features (str): Features file in 2d or 3d, loaded with
            ``read_features``.
        labels (str): Labels file, loaded with ``read_labels``.
        outfile (str): Output file the figure is saved to.
    """
    # Read from the ``features`` parameter rather than the script-level
    # ``args`` global, so the function also works when it is imported and
    # called directly.
    id_to_features = read_features(features)
    ids = list(id_to_features.keys())
    utt2label = read_labels(labels)

    # One row per id, in the same order as ``ids``.
    feature_matrix = np.vstack([id_to_features[id_] for id_ in ids])
    # Only the first element of each label entry is used as the group key.
    labels_list = [utt2label[id_][0] for id_ in ids]
    print("Number of labels: ", len(np.unique(labels_list)))

    columns = feature_matrix.transpose()
    df = pd.DataFrame(dict(
        x=columns[0],
        y=columns[1],
        label=labels_list))

    # Plot each label group as its own series so it gets a legend entry.
    fig, ax = plt.subplots()
    for label, group in df.groupby('label'):
        ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=label)
    ax.legend()
    plt.savefig(outfile)
    print("Your plot is saved well (no check of this affirmation)")
def interactive_scatter_plot(features: str, labels: str, outdir: str):
    """Placeholder for an interactive 3D scatter plot exported as Web files.

    Intended for visualising data processed with t-SNE or a similar
    algorithm. The function is not implemented yet: it ignores its
    arguments, produces no output and returns ``None``.

    Args:
        features (str): Features file in 2d or 3d
        labels (str): Labels file
        outdir (str): Output directory where the Web files are to be saved
    """
    return None
if __name__ == '__main__':
    # Command-line entry point: one sub-command per plot type. Each
    # sub-parser stores its own name in ``which`` so the runner can
    # dispatch to the matching function.

    # Main parser
    parser = argparse.ArgumentParser(description="")
    subparsers = parser.add_subparsers(title="action")

    # scatter with labels
    parser_scatter = subparsers.add_parser("scatter")
    parser_scatter.add_argument("features", type=str, help="define the main features file")
    parser_scatter.add_argument("--labels", default=None, type=str, help="specify the labels of each utterance/element")
    parser_scatter.add_argument("--outfile", default="./out.pdf", type=str, help="Specify the output file (better in pdf)")
    parser_scatter.set_defaults(which="scatter")

    # interactive scatter
    # These options must be registered on the interactive sub-parser.
    # Registering them on ``parser_scatter`` declares ``--labels`` twice
    # (argparse rejects the duplicate option string) and overwrites the
    # ``which`` default of the "scatter" command.
    parser_interactive_scatter = subparsers.add_parser("interactive_scatter")
    parser_interactive_scatter.add_argument("features", type=str, help="features files with only 3D will be converted into csv")
    parser_interactive_scatter.add_argument("--labels", default=None, type=str, help="Specify the labels of each utterance/element")
    parser_interactive_scatter.add_argument("--outdir", default=".out", type=str, help="output directoy where static web and data files are saved.")
    parser_interactive_scatter.set_defaults(which="interactive_scatter")

    # Parse
    args = parser.parse_args()

    # Without a sub-command no ``which`` default is set; exit with a usage
    # message instead of failing on the attribute access below.
    if not hasattr(args, "which"):
        parser.error("a sub-command is required: scatter or interactive_scatter")

    # Run commands
    # NOTE(review): presumably the runner drops the "which" key and passes
    # the remaining arguments to the selected function - confirm in utils.
    runner = SubCommandRunner({
        "scatter": scatter_plot,
        "interactive_scatter": interactive_scatter_plot,
    })
    runner.run(args.which, args.__dict__, remove="which")