Blame view
egs/wsj/s5/steps/cleanup/internal/taint_ctm_edits.py
10.9 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 |
#!/usr/bin/env python3 # Copyright 2016 Vimal Manohar # 2016 Johns Hopkins University (author: Daniel Povey) # Apache 2.0 from __future__ import print_function import sys, operator, argparse, os from collections import defaultdict import io sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf8") # This script reads and writes the 'ctm-edits' file that is # produced by get_ctm_edits.py. # # It is to be applied after modify_ctm_edits.py. Its function is to add, in # certain circumstances, an optional extra field with the word 'tainted' to the # ctm-edits format, e.g an input line like: # # AJJacobs_2007P-0001605-0003029 1 0 0.09 <eps> 1.0 <eps> sil # might become: # AJJacobs_2007P-0001605-0003029 1 0 0.09 <eps> 1.0 <eps> sil tainted # # It also deletes certain lines, representing deletions, from the ctm (if they # were next to taintable lines... their presence could then be inferred from the # 'tainted' flag). # # You should interpret the 'tainted' flag as "we're not sure what's going on here; # don't trust this." # # One of the problem this script is trying to solve is that if we have errors # that are adjacent to silence or non-scored words # it's not at all clear whether the silence or non-scored words were really such, # or might have contained actual words. # Also, if we have words in the reference that were realized as '<unk>' in the # hypothesis, and they are adjacent to errors, it's almost always the case # that the '<unk>' doesn't really correspond to the word in the reference, so # we mark these as 'tainted'. # # The rule for tainting is quite simple; see the code. parser = argparse.ArgumentParser( description = "This program modifies the ctm-edits format to identify " "silence and 'fixed' non-scored-word lines, and lines where the hyp is " "<unk> and the reference is a real but OOV word, where there is a relatively " "high probability that something is going wrong so we shouldn't trust " "this line. It adds the field 'tainted' to such " "lines. Lines in the ctm representing deletions from the reference will " "be removed if they have 'tainted' adjacent lines (since it won't be clear " "where such reference words were really realized, if at all). " "See comments at the top of the script for more information.") parser.add_argument("--verbose", type = int, default = 1, choices=[0,1,2,3], help = "Verbose level, higher = more verbose output") parser.add_argument("--remove-deletions", type=str, default="true", choices=["true", "false"], help = "Remove deletions next to taintable lines") parser.add_argument("ctm_edits_in", metavar = "<ctm-edits-in>", help = "Filename of input ctm-edits file. " "Use /dev/stdin for standard input.") parser.add_argument("ctm_edits_out", metavar = "<ctm-edits-out>", help = "Filename of output ctm-edits file. " "Use /dev/stdout for standard output.") args = parser.parse_args() args.remove_deletions = bool(args.remove_deletions == "true") # This function is the core of the program, that does the tainting and # removes some lines representing deletions. # split_lines_of_utt is a list of lists, one per line, each containing the # sequence of fields. Returns the same format of data after processing to add # the 'tainted' field. Note: this function is destructive of its input; the # input will not have the same value afterwards. def ProcessUtterance(split_lines_of_utt, remove_deletions=True): global num_lines_of_type, num_tainted_lines, \ num_del_lines_giving_taint, num_sub_lines_giving_taint, \ num_ins_lines_giving_taint # work out whether each line is taintable [i.e. silence or fix or unk replacing # real-word]. taintable = [ False ] * len(split_lines_of_utt) for i in range(len(split_lines_of_utt)): edit_type = split_lines_of_utt[i][7] if edit_type == 'sil' or edit_type == 'fix': taintable[i] = True elif edit_type == 'cor' and split_lines_of_utt[i][4] != split_lines_of_utt[i][6]: # this is the case when <unk> replaces a real word that was out of # the vocabulary; we mark it as correct because such words do # translate to <unk> if we don't have a pronunciations. However we # don't have good confidence that the alignments of such words are # accurate if they are adjacent to errors. taintable[i] = True for i in range(len(split_lines_of_utt)): edit_type = split_lines_of_utt[i][7] num_lines_of_type[edit_type] += 1 if edit_type == 'del' or edit_type == 'sub' or edit_type == 'ins': tainted_an_adjacent_line = False # First go backwards tainting lines j = i - 1 while j >= 0 and taintable[j]: tainted_an_adjacent_line = True if len(split_lines_of_utt[j]) == 8: num_tainted_lines += 1 split_lines_of_utt[j].append('tainted') j -= 1 # Next go forwards tainting lines j = i + 1 while j < len(split_lines_of_utt) and taintable[j]: tainted_an_adjacent_line = True if len(split_lines_of_utt[j]) == 8: num_tainted_lines += 1 split_lines_of_utt[j].append('tainted') j += 1 if tainted_an_adjacent_line: if edit_type == 'del': if remove_deletions: split_lines_of_utt[i][7] = 'remove-this-line' num_del_lines_giving_taint += 1 elif edit_type == 'sub': num_sub_lines_giving_taint += 1 else: num_ins_lines_giving_taint += 1 new_split_lines_of_utt = [] for i in range(len(split_lines_of_utt)): if (not remove_deletions or split_lines_of_utt[i][7] != 'remove-this-line'): new_split_lines_of_utt.append(split_lines_of_utt[i]) return new_split_lines_of_utt def ProcessData(): try: f_in = open(args.ctm_edits_in, encoding="utf8") except: sys.exit("taint_ctm_edits.py: error opening ctm-edits input " "file {0}".format(args.ctm_edits_in)) try: f_out = open(args.ctm_edits_out, 'w', encoding="utf8") except: sys.exit("taint_ctm_edits.py: error opening ctm-edits output " "file {0}".format(args.ctm_edits_out)) num_lines_processed = 0 # Most of what we're doing in the lines below is splitting the input lines # and grouping them per utterance, before giving them to ProcessUtterance() # and then printing the modified lines. first_line = f_in.readline() if first_line == '': sys.exit("taint_ctm_edits.py: empty input") split_pending_line = first_line.split() if len(split_pending_line) == 0: sys.exit("taint_ctm_edits.py: bad input line " + first_line) cur_utterance = split_pending_line[0] split_lines_of_cur_utterance = [] while True: if len(split_pending_line) == 0 or split_pending_line[0] != cur_utterance: split_lines_of_cur_utterance = ProcessUtterance( split_lines_of_cur_utterance, args.remove_deletions) for split_line in split_lines_of_cur_utterance: print(' '.join(split_line), file = f_out) split_lines_of_cur_utterance = [] if len(split_pending_line) == 0: break else: cur_utterance = split_pending_line[0] split_lines_of_cur_utterance.append(split_pending_line) next_line = f_in.readline() split_pending_line = next_line.split() if len(split_pending_line) == 0: if next_line != '': sys.exit("taint_ctm_edits.py: got an empty or whitespace input line") try: f_out.close() except: sys.exit("taint_ctm_edits.py: error closing ctm-edits output " "(broken pipe or full disk?)") def PrintNonScoredStats(): if args.verbose < 1: return if num_lines == 0: print("taint_ctm_edits.py: processed no input.", file = sys.stderr) num_lines_modified = sum(ref_change_stats.values()) num_incorrect_lines = num_lines - num_correct_lines percent_lines_incorrect= '%.2f' % (num_incorrect_lines * 100.0 / num_lines) percent_modified = '%.2f' % (num_lines_modified * 100.0 / num_lines); percent_of_incorrect_modified = '%.2f' % (num_lines_modified * 100.0 / num_incorrect_lines) print("taint_ctm_edits.py: processed {0} lines of ctm ({1}% of which incorrect), " "of which {2} were changed fixing the reference for non-scored words " "({3}% of lines, or {4}% of incorrect lines)".format( num_lines, percent_lines_incorrect, num_lines_modified, percent_modified, percent_of_incorrect_modified), file = sys.stderr) keys = sorted(list(ref_change_stats.keys()), reverse=True, key = lambda x: ref_change_stats[x]) num_keys_to_print = 40 if args.verbose >= 2 else 10 print("taint_ctm_edits.py: most common edits (as percentages " "of all such edits) are: " + (' '.join([ '%s [%.2f%%]' % (k, ref_change_stats[k]*100.0/num_lines_modified) for k in keys[0:num_keys_to_print]])) + ' ...'if num_keys_to_print < len(keys) else '', file = sys.stderr) def PrintStats(): tot_lines = sum(num_lines_of_type.values()) if args.verbose < 1 or tot_lines == 0: return print("taint_ctm_edits.py: processed {0} input lines, whose edit-types were: ".format(tot_lines) + ', '.join([ '%s = %.2f%%' % (k, num_lines_of_type[k] * 100.0 / tot_lines) for k in sorted(list(num_lines_of_type.keys()), reverse = True, key = lambda k: num_lines_of_type[k]) ]), file = sys.stderr) del_giving_taint_percent = num_del_lines_giving_taint * 100.0 / tot_lines sub_giving_taint_percent = num_sub_lines_giving_taint * 100.0 / tot_lines ins_giving_taint_percent = num_ins_lines_giving_taint * 100.0 / tot_lines tainted_lines_percent = num_tainted_lines * 100.0 / tot_lines print("taint_ctm_edits.py: as a percentage of all lines, (%.2f%%, %.2f%%, %.2f%%) were " "(deletions, substitutions, insertions) that tainted adjacent lines. %.2f%% of all " "lines were tainted." % (del_giving_taint_percent, sub_giving_taint_percent, ins_giving_taint_percent, tainted_lines_percent), file = sys.stderr) # num_lines_of_type will map from line-type ('cor', 'sub', etc.) to count. num_lines_of_type = defaultdict(int) num_tainted_lines = 0 num_del_lines_giving_taint = 0 num_sub_lines_giving_taint = 0 num_ins_lines_giving_taint = 0 ProcessData() PrintStats() |