merge_targets.py
8.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
#!/usr/bin/env python3
# Copyright 2017 Vimal Manohar
# Apache 2.0
"""
This script merges targets created from multiple sources (systems) into
single targets matrices.
Usage: merge_targets.py [options] <pasted-targets> <out-targets>
e.g.: paste-feats scp:targets1.scp scp:targets2.scp ark,t:- | merge_targets.py --dim=3 - - | copy-feats ark,t:- ark:-
<pasted-targets> is matrix archive with matrices corresponding to
targets from multiple sources appended together using paste-feats.
The column dimension is num-sources * dim, where dim is specified by the
--dim option.
"""
import argparse
import logging
import numpy as np
import sys
sys.path.insert(0, 'steps')
import libs.common as common_lib
# Module-level logging setup: records at INFO level and above are sent to a
# stream handler, each one tagged with its timestamp, source file and line,
# function name and level.
formatter = logging.Formatter(
    "%(asctime)s [%(pathname)s:%(lineno)s - %(funcName)s - %(levelname)s ] "
    "%(message)s")

handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
handler.setFormatter(formatter)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
def get_args():
    """Parse and post-process the command-line arguments.

    Returns:
        The argparse namespace with attributes ``weights``, ``dim``,
        ``remove_mismatch_frames``, ``pasted_targets`` and ``out_targets``.
        ``weights`` is None when --weights was not given; otherwise it is a
        list of floats normalized to sum to one.

    Exits through ``parser.error`` (usage message, non-zero status) when
    --weights cannot be parsed as numbers or the weights sum to zero.
    """
    parser = argparse.ArgumentParser(
        description="""
This script merges targets created from multiple sources (systems) into
single targets matrices.
Usage: merge_targets.py [options] <pasted-targets> <out-targets>
e.g.: paste-feats scp:targets1.scp scp:targets2.scp ark,t:- | merge_targets.py --dim=3 - - | copy-feats ark,t:- ark:-
""",
        formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument("--weights", type=str, default="",
                        help="A comma-separated list of weights corresponding "
                        "to each targets source being combined. "
                        "Weights will be normalized internally to sum-to-one.")
    parser.add_argument("--dim", type=int, default=3,
                        help="Number of columns corresponding to each "
                        "target matrix")
    # NOTE(review): presumably common_lib.StrToBoolAction converts the
    # "true"/"false" string into a bool -- confirm against libs.common.
    parser.add_argument("--remove-mismatch-frames", type=str, default=False,
                        choices=["true", "false"],
                        action=common_lib.StrToBoolAction,
                        help="If true, the mismatch frames are removed by "
                        "setting targets to 0 in the following cases:\n"
                        "a) If none of the sources have a column with value "
                        "> 0.5\n"
                        "b) If two sources have columns with value > 0.5, but "
                        "they occur at different indexes e.g. silence prob is "
                        "> 0.5 for the targets from alignment, and speech prob "
                        "> 0.5 for the targets from decoding.")

    parser.add_argument("pasted_targets", type=str,
                        help="Input target matrices with columns appended "
                        "together using paste-feats. Its column dimension is "
                        "num-sources * dim, which dim is specified by --dim "
                        "option.")
    parser.add_argument("out_targets", type=str,
                        help="Output target matrices")

    args = parser.parse_args()

    if args.weights == "":
        args.weights = None
        return args

    try:
        weights = [float(x) for x in args.weights.split(",")]
    except ValueError:
        parser.error("--weights must be a comma-separated list of numbers; "
                     "got '{0}'".format(args.weights))

    weights_sum = sum(weights)
    if weights_sum == 0:
        # Normalizing would divide by zero; report it as a usage error
        # instead of crashing with a bare ZeroDivisionError.
        parser.error("--weights must not sum to zero; "
                     "got '{0}'".format(args.weights))

    args.weights = [x / weights_sum for x in weights]
    return args
def should_remove_frame(row, dim):
    """Return True if the frame needs to be removed.

    Args:
        row: a sequence of num-sources * dim values holding the targets of
            one frame; source 'i' occupies row[i * dim:(i + 1) * dim].
        dim: the number of values (classes) per source, usually 3. The
            number of sources is len(row) / dim.

    Returns:
        True if the frame must be removed, False if it is kept.

    A source is "confident" if its best (largest) value is > 0.5.
    The frame is removed in the following cases:
        1) No source is confident, i.e. none of the values is > 0.5
           (this includes an empty row).
        2) More than one source is confident, but their best values are at
           different indexes (classes) within the source.

    e.g. [ 1 0 0 0.6 0 0.4 0 0 0 ]  # kept because 1 and 0.6 are both > 0.5
                                    # at the same class, namely 0
                                    # source[0] = [ 1 0 0 ]
                                    # source[1] = [ 0.6 0 0.4 ]
                                    # source[2] = [ 0 0 0 ]
    e.g. [ 0 0 0 0.4 0 0.6 1 0 0 ]  # removed because source[1] has best
                                    # value 0.6 > 0.5 at class 2 and
                                    # source[2] has best value 1 > 0.5 at
                                    # class 0.
                                    # source[0] = [ 0 0 0 ]
                                    # source[1] = [ 0.4 0 0.6 ]
                                    # source[2] = [ 1 0 0 ]
    """
    assert len(row) % dim == 0
    num_sources = len(row) // dim
    if num_sources == 0:
        # Nothing to be confident about.
        return True

    # One row per source, one column per class.
    scores = np.asarray(row, dtype=float).reshape(num_sources, dim)
    best_classes = scores.argmax(axis=1)
    best_values = scores.max(axis=1)

    # Use the same strict threshold everywhere, so that a best value of
    # exactly 0.5 is consistently treated as "not confident".
    confident = best_values > 0.5
    if not confident.any():
        # We are not confident of any source. Remove frame.
        return True

    # Keep the frame only if all the confident sources agree on the class.
    confident_classes = best_classes[confident]
    return bool((confident_classes != confident_classes[0]).any())
def run(args):
    """Merge every pasted targets matrix and write the merged matrices.

    For each (key, matrix) read from args.pasted_targets, the matrix columns
    are split into num-sources consecutive groups of args.dim columns. The
    groups are summed, each scaled by its entry in args.weights (or by 1.0
    when args.weights is None), into a matrix with args.dim columns. If
    args.remove_mismatch_frames is set, the rows for which
    should_remove_frame() returns True are set to zero. The result is
    written to args.out_targets under the same key.

    Args:
        args: the namespace returned by get_args().

    Raises:
        RuntimeError: if a matrix has a number of columns that is not a
            multiple of args.dim, if the number of weights differs from the
            number of sources, or if no matrix was merged at all.
    """
    num_done = 0

    with common_lib.smart_open(args.pasted_targets) as targets_reader, \
            common_lib.smart_open(args.out_targets, 'w') as targets_writer:
        for key, mat in common_lib.read_mat_ark(targets_reader):
            # NOTE(review): assumes read_mat_ark yields each matrix as a
            # list of rows of numbers -- confirm against libs.common.
            mat = np.atleast_2d(np.asarray(mat, dtype=float))
            num_cols = mat.shape[1]

            if num_cols % args.dim != 0:
                # args.pasted_targets is a plain string (a path or "-"),
                # so it is formatted directly.
                raise RuntimeError(
                    "For utterance {utt} in {f}, num-columns {nc} "
                    "is not a multiple of dim {dim}"
                    "".format(utt=key, f=args.pasted_targets,
                              nc=num_cols, dim=args.dim))
            num_sources = num_cols // args.dim

            if args.weights is not None and len(args.weights) != num_sources:
                # The weights were normalized over all of them, so a count
                # mismatch would give an IndexError or un-normalized output.
                raise RuntimeError(
                    "For utterance {utt} in {f}, got {nw} weights "
                    "but {ns} sources"
                    "".format(utt=key, f=args.pasted_targets,
                              nw=len(args.weights), ns=num_sources))

            # Interpolate the targets of all the sources.
            out_mat = np.zeros([mat.shape[0], args.dim])
            for i in range(num_sources):
                weight = 1.0 if args.weights is None else args.weights[i]
                out_mat += mat[:, (i * args.dim):((i + 1) * args.dim)] * weight

            if args.remove_mismatch_frames:
                # Zero out the frames on which the sources disagree.
                for n, row in enumerate(mat):
                    if should_remove_frame(row, args.dim):
                        out_mat[n, :] = 0.0

            common_lib.write_matrix_ascii(targets_writer, out_mat.tolist(),
                                          key=key)
            num_done += 1

    logger.info("Merged {num_done} target matrices"
                "".format(num_done=num_done))

    if num_done == 0:
        raise RuntimeError(
            "Failed to merge any target matrices from {f}"
            "".format(f=args.pasted_targets))
def main():
    """Command-line entry point: parse the arguments and merge the targets.

    Any exception raised while parsing the arguments or merging the targets
    propagates unchanged to the caller.
    """
    run(get_args())
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()