Blame view
egs/wsj/s5/steps/segmentation/internal/merge_targets.py
8.33 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
#!/usr/bin/env python3 # Copyright 2017 Vimal Manohar # Apache 2.0 """ This script merges targets created from multiple sources (systems) into single targets matrices. Usage: merge_targets.py [options] <pasted-targets> <out-targets> e.g.: paste-feats scp:targets1.scp scp:targets2.scp ark,t:- | merge_targets.py --dim=3 - - | copy-feats ark,t:- ark:- <pasted-targets> is matrix archive with matrices corresponding to targets from multiple sources appended together using paste-feats. The column dimension is num-sources * dim, which dim is specified by --dim option. """ import argparse import logging import numpy as np import sys sys.path.insert(0, 'steps') import libs.common as common_lib logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) handler = logging.StreamHandler() handler.setLevel(logging.INFO) formatter = logging.Formatter("%(asctime)s [%(pathname)s:%(lineno)s - " "%(funcName)s - %(levelname)s ] %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) def get_args(): parser = argparse.ArgumentParser( description=""" This script merges targets created from multiple sources (systems) into single targets matrices. Usage: merge_targets.py [options] <pasted-targets> <out-targets> e.g.: paste-feats scp:targets1.scp scp:targets2.scp ark,t:- | merge_targets.py --dim=3 - - | copy-feats ark,t:- ark:- """, formatter_class=argparse.RawTextHelpFormatter) parser.add_argument("--weights", type=str, default="", help="A comma-separated list of weights corresponding " "to each targets source being combined. " "Weights will be normalized internally to sum-to-one.") parser.add_argument("--dim", type=int, default=3, help="Number of columns corresponding to each " "target matrix") parser.add_argument("--remove-mismatch-frames", type=str, default=False, choices=["true", "false"], action=common_lib.StrToBoolAction, help="If true, the mismatch frames are removed by " "setting targets to 0 in the following cases: " "a) If none of the sources have a column with value " "> 0.5 " "b) If two sources have columns with value > 0.5, but " "they occur at different indexes e.g. silence prob is " "> 0.5 for the targets from alignment, and speech prob " "> 0.5 for the targets from decoding.") parser.add_argument("pasted_targets", type=str, help="Input target matrices with columns appended " "together using paste-feats. Its column dimension is " "num-sources * dim, which dim is specified by --dim " "option.") parser.add_argument("out_targets", type=str, help="Output target matrices") args = parser.parse_args() if args.weights != "": args.weights = [float(x) for x in args.weights.split(",")] weights_sum = sum(args.weights) args.weights = [x / weights_sum for x in args.weights] else: args.weights = None return args def should_remove_frame(row, dim): """Returns True if the frame needs to be removed. Input: row -- a list of values (of dimension num-sources x dim) corresponding to the targets for one of the frames dim -- Usually 3. The number of sources can be computed as the len(row) / dim. The frame is determined to be removed in the following cases: 1) None of the values > 0.5. 2) More than one source has best value >= 0.5, but at different indexes in the source. e.g. [ 1 0 0 0.6 0 0.4 0 0 0 ] # kept because 1 and 0.6 are both > 0.5 # at the same class namely 0 # source[0] = [ 1 0 0 ] # source[1] = [ 0.6 0 0.4 ] # source[2] = [ 0 0 0 ] e.g. [ 0 0 0 0.4 0 0.6 1 0 0 ] # removed because source[1] has best value # 0.6 > 0.5 at class 2 and source[2] has # best value 1 > 0.5 at class 0. # source[0] = [ 0 0 0 ] # source[1] = [ 0.4 0 0.6 ] # source[2] = [ 0 0 0 ] """ assert len(row) % dim == 0 num_sources = len(row) // dim max_idx = np.argmax(row) max_val = row[max_idx] if max_val < 0.5: # All the values < 0.5. So we are not confident of any sources. # Remove frame. return True best_source = max_idx // dim best_class = max_idx % dim confident_in_source = [] # List of length num_sources # Element 'i' is 1, # if the best value for the source 'i' is > 0.5 best_values_for_source = [] # Element 'i' is a pair (value, class), # where 'class' is argmax over the scores # corresponding to the source 'i' and # 'value' is the corresponding score. for source_idx in range(num_sources): idx = np.argmax(row[(source_idx * dim): ((source_idx+1) * dim)]) val = row[source_idx * dim + idx] confident_in_source.append(bool(val > 0.5)) best_values_for_source.append((val, idx)) if sum(confident_in_source) == 1: # We are confident in only one source. Keep frame. return False for source_idx in range(num_sources): if source_idx == best_source: assert confident_in_source[source_idx] continue if not confident_in_source[source_idx]: continue else: # We are confident in a source other than the 'best_source'. # If it's index is different from the 'best_class', then it is # a mismatch and the frame must be removed. val, idx = best_values_for_source[source_idx] assert val > 0.5 if idx != best_class: return True return False def run(args): num_done = 0 with common_lib.smart_open(args.pasted_targets) as targets_reader, \ common_lib.smart_open(args.out_targets, 'w') as targets_writer: for key, mat in common_lib.read_mat_ark(targets_reader): mat = np.matrix(mat) if mat.shape[1] % args.dim != 0: raise RuntimeError( "For utterance {utt} in {f}, num-columns {nc} " "is not a multiple of dim {dim}" "".format(utt=key, f=args.pasted_targets.name, nc=mat.shape[1], dim=args.dim)) num_sources = mat.shape[1] // args.dim out_mat = np.matrix(np.zeros([mat.shape[0], args.dim])) if args.remove_mismatch_frames: for n in range(mat.shape[0]): if should_remove_frame(mat[n, :].getA()[0], args.dim): out_mat[n, :] = np.zeros([1, args.dim]) else: for i in range(num_sources): out_mat[n, :] += ( mat[n, (i * args.dim) : ((i+1) * args.dim)] * (1.0 if args.weights is None else args.weights[i])) else: # Just interpolate the targets for i in range(num_sources): out_mat += ( mat[:, (i * args.dim) : ((i+1) * args.dim)] * (1.0 if args.weights is None else args.weights[i])) common_lib.write_matrix_ascii(targets_writer, out_mat.tolist(), key=key) num_done += 1 logger.info("Merged {num_done} target matrices" "".format(num_done=num_done)) if num_done == 0: raise RuntimeError def main(): args = get_args() try: run(args) except Exception: raise if __name__ == '__main__': main() |