plot.py
2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from core.data import read_features, read_lst, read_labels
from utils import SubCommandRunner
def scatter_plot(features: str, labels: str, outfile: str):
    """Generate a simple 2-D scatter plot, one colour/group per label.

    Mainly used to visualise data that was projected to two or three
    dimensions (e.g. with t-SNE). Only the first two feature dimensions
    are drawn.

    Args:
        features (str): Features file (2d or 3d vectors), loaded with
            ``read_features``.
        labels (str): Labels file, loaded with ``read_labels``. May be
            ``None`` (the CLI default); every point is then drawn as a
            single unlabelled group and no legend is added.
        outfile (str): Output image path passed to ``savefig``.
    """
    # Use the ``features`` parameter, not the CLI's global ``args``, so the
    # function also works when called directly (and honours its argument).
    id_to_features = read_features(features)
    ids = list(id_to_features)

    # One row per element; columns 0 and 1 become the x/y coordinates.
    points = np.vstack([id_to_features[id_] for id_ in ids])

    labels_list = None
    if labels is not None:
        utt2label = read_labels(labels)
        # NOTE(review): only the first entry of each label value is used,
        # as in the original code — confirm against read_labels' format.
        labels_list = [utt2label[id_][0] for id_ in ids]
        print("Number of labels: ", len(np.unique(labels_list)))

    # Plot
    fig, ax = plt.subplots()
    if labels_list is None:
        ax.plot(points[:, 0], points[:, 1], marker='o', linestyle='', ms=1)
    else:
        df = pd.DataFrame(dict(
            x=points[:, 0],
            y=points[:, 1],
            label=labels_list))
        for label, group in df.groupby('label'):
            ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=label)
        ax.legend()
    fig.savefig(outfile)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print("Your plot is saved well (no check of this affirmation)")
if __name__ == '__main__':
    # Main parser
    parser = argparse.ArgumentParser(description="")
    # Store the sub-command name in ``which`` and make it mandatory, so that
    # running without a sub-command prints a usage error instead of failing
    # later with AttributeError on ``args.which``.
    subparsers = parser.add_subparsers(title="action", dest="which")
    subparsers.required = True

    # scatter with labels
    parser_scatter = subparsers.add_parser("scatter")
    parser_scatter.add_argument("features", type=str, help="define the main features file")
    parser_scatter.add_argument("--labels", default=None, type=str, help="specify the labels of each utterance/element")
    parser_scatter.add_argument("--outfile", default="./out.pdf", type=str, help="Specify the output file (better in pdf)")
    parser_scatter.set_defaults(which="scatter")

    # Parse (``args`` is kept at module level on purpose).
    args = parser.parse_args()

    # Run commands: dispatch on the sub-command name. ``remove="which"``
    # presumably drops the dispatch key before the remaining parsed options
    # are passed on as keyword arguments — verify in utils.SubCommandRunner.
    runner = SubCommandRunner({
        "scatter": scatter_plot,
    })
    runner.run(args.which, args.__dict__, remove="which")