#!/usr/bin/env python3

# Copyright 2016   Vimal Manohar
#           2016   Johns Hopkins University (author: Daniel Povey)
# Apache 2.0

from __future__ import print_function
import argparse
import logging
import sys
from collections import defaultdict

"""
This script reads and writes the 'ctm-edits' file that is
produced by get_ctm_edits.py.

It modifies the ctm-edits so that non-scored words
are not counted as errors: for instance, if there are things like
[COUGH] and [NOISE] in the transcript, deletions, insertions and
substitutions involving them are allowed, and we modify the reference
to correspond to the hypothesis.

The non-scored words are supplied via the required <non-scored-words-file>
argument, which contains one word per line (see
steps/cleanup/get_nonscored_words.py, which generates such a list).  A
reasonable guess of the non-scored words can be worked out from the <lang>
directory that corresponds to how you decoded the data, using
phones/align_lexicon.int: the non-scored words are those which have a
single-word pronunciation that maps to a silence phone, and words.txt gives
their written form.

Non-scored words that were deleted (i.e. they were in the ref but not the
hyp) are simply removed from the ctm.  For non-scored words that
were inserted or substituted, we change the reference word to match the
hyp word, but instead of marking the operation as 'cor' (correct), we
mark it as 'fix' (fixed), so that it will not be positively counted as a correct
word for purposes of finding the optimal segment boundaries.

e.g.
<file-id> <channel> <start-time> <duration> <hyp-word> <conf> <ref-word> <edit-type>
[note: the <channel> will always be 1].

AJJacobs_2007P-0001605-0003029 1 0 0.09 <eps> 1.0 <eps> sil
AJJacobs_2007P-0001605-0003029 1 0.09 0.15 i 1.0 i cor
AJJacobs_2007P-0001605-0003029 1 0.24 0.25 thought 1.0 thought cor
AJJacobs_2007P-0001605-0003029 1 0.49 0.14 i'd 1.0 i'd cor
AJJacobs_2007P-0001605-0003029 1 0.63 0.22 tell 1.0 tell cor
AJJacobs_2007P-0001605-0003029 1 0.85 0.11 you 1.0 you cor
AJJacobs_2007P-0001605-0003029 1 0.96 0.05 a 1.0 a cor
AJJacobs_2007P-0001605-0003029 1 1.01 0.24 little 1.0 little cor
AJJacobs_2007P-0001605-0003029 1 1.25 0.5 about 1.0 about cor
AJJacobs_2007P-0001605-0003029 1 1.75 0.48 [UH] 1.0 [UH] cor
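
For example (a hypothetical continuation of the lines above, illustrating the
edits described earlier), an insertion of a non-scored word such as

AJJacobs_2007P-0001605-0003029 1 2.23 0.30 [NOISE] 1.0 <eps> ins

would have its reference word changed to match the hyp word and be marked 'fix':

AJJacobs_2007P-0001605-0003029 1 2.23 0.30 [NOISE] 1.0 [NOISE] fix

whereas a deletion of a non-scored word (hyp-word '<eps>', edit-type 'del')
would simply be removed from the output.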
"""

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s [%(filename)s:%(lineno)s - '
                              '%(funcName)s - %(levelname)s ] %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)


parser = argparse.ArgumentParser(
    description = "This program modifies the reference in the ctm-edits which "
    "is output by steps/cleanup/internal/get_ctm_edits.py, to allow insertions, deletions and "
    "substitutions of non-scored words, and [if --allow-repetitions=true], "
    "duplications of single words or pairs of scored words (to account for dysfluencies "
    "that were not transcribed).  Note: deletions and substitutions of non-scored words "
    "after the reference is corrected, will be marked as operation 'fix' rather than "
    "'cor' (correct) so that the downstream processing knows that this was not in "
    "the original reference.  Also by defaults tags non-scored words as such when "
    "they are correct; see the --tag-non-scored option.")

parser.add_argument("--verbose", type = int, default = 1,
                    choices=[0,1,2,3],
                    help = "Verbose level, higher = more verbose output")
parser.add_argument("--allow-repetitions", type = str, default = 'true',
                    choices=['true','false'],
                    help = "If true, allow repetitions in the transcript of one or "
                    "two-word sequences: for instance if the ref says 'i' but "
                    "the hyp says 'i i', or the ref says 'but then' and the hyp says "
                    "'but then but then', fix the reference accordingly.  Intervening "
                    "non-scored words are allowed between the repetitions.  These "
                    "fixes will be marked as 'cor', not as 'fix', since there is "
                    "generally no way to tell which repetition was the 'real' one "
                    "(and since we're generally confident that such things were "
                    "actually uttered).")
parser.add_argument("non_scored_words_in", metavar = "<non-scored-words-file>",
                    help="Filename of file containing a list of non-scored words, "
                    "one per line. See steps/cleanup/get_nonscored_words.py.")
parser.add_argument("ctm_edits_in", metavar = "<ctm-edits-in>",
                    help = "Filename of input ctm-edits file. "
                    "Use /dev/stdin for standard input.")
parser.add_argument("ctm_edits_out", metavar = "<ctm-edits-out>",
                    help = "Filename of output ctm-edits file. "
                    "Use /dev/stdout for standard output.")

args = parser.parse_args()



def ReadNonScoredWords(non_scored_words_file):
    global non_scored_words
    try:
        f = open(non_scored_words_file, encoding='utf-8')
    except IOError:
        sys.exit("modify_ctm_edits.py: error opening non-scored-words "
                 "file " + non_scored_words_file)
    for line in f:
        a = line.split()
        if len(a) != 1:
            sys.exit("modify_ctm_edits.py: bad line in non-scored-words "
                     "file {0}: {1}".format(non_scored_words_file, line))
        non_scored_words.add(a[0])
    f.close()
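
# A non-scored-words file simply lists one word per line, e.g. (hypothetical
# contents, using the bracketed words from the example in the docstring above):
#
#   [COUGH]
#   [NOISE]
#   [UH]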



# The ctm-edits file format is as follows [note: file-id is really utterance-id
# in this context].
# <file-id> <channel> <start-time> <duration> <hyp-word> <conf> <ref-word> <edit>
# e.g.:
# AJJacobs_2007P-0001605-0003029 1 0 0.09 <eps> 1.0 <eps> sil
# AJJacobs_2007P-0001605-0003029 1 0.09 0.15 i 1.0 i cor
# ...
# This function processes a single line of ctm-edits input for fixing
# "non-scored" words.  The input 'a' is the split line as an array of fields.
# It modifies 'a' in place and returns the modified array; note that it is
# destructive of its input 'a'.
# If it returns the empty array then the line is to be deleted.
def ProcessLineForNonScoredWords(a):
    global num_lines, num_correct_lines, ref_change_stats
    try:
        assert len(a) == 8
        num_lines += 1
        # we could do:
        # [ file, channel, start, duration, hyp_word, confidence, ref_word, edit_type ] = a
        duration = a[3]
        hyp_word = a[4]
        ref_word = a[6]
        edit_type = a[7]
        if edit_type == 'ins':
            assert ref_word == '<eps>'
            if hyp_word in non_scored_words:
                # insert this non-scored word into the reference.
                ref_change_stats[ref_word + ' -> ' + hyp_word] += 1
                ref_word = hyp_word
                edit_type = 'fix'
        elif edit_type == 'del':
            assert hyp_word == '<eps>' and float(duration) == 0.0
            if ref_word in non_scored_words:
                ref_change_stats[ref_word + ' -> ' + hyp_word] += 1
                return []
        elif edit_type == 'sub':
            assert hyp_word != '<eps>'
            if hyp_word in non_scored_words and ref_word in non_scored_words:
                # we also allow replacing one non-scored word with another.
                ref_change_stats[ref_word + ' -> ' + hyp_word] += 1
                ref_word = hyp_word
                edit_type = 'fix'
        else:
            assert edit_type == 'cor' or edit_type == 'sil'
            num_correct_lines += 1

        a[4] = hyp_word
        a[6] = ref_word
        a[7] = edit_type
        return a

    except Exception:
        logger.error("bad line in ctm-edits input: "
                     "{0}".format(a))
        raise RuntimeError

# This function processes the split lines of one utterance (as a
# list of lists of fields), to allow repetitions of words, so if the
# reference says 'i' but the hyp says 'i i', or the ref says
# 'you know' and the hyp says 'you know you know', we change the
# ref to match.
# It returns the modified list-of-lists [but note that the input
# is actually modified].
def ProcessUtteranceForRepetitions(split_lines_of_utt):
    global non_scored_words, repetition_stats
    # The array 'selected_line_indexes' will contain the indexes of the
    # selected elements of 'split_lines_of_utt'.  Consider split_line =
    # split_lines_of_utt[i].  If the hyp and ref words in split_line are both
    # either '<eps>' or non-scored words, we discard the index i.
    # Otherwise we put i into selected_line_indexes.
    selected_line_indexes = []
    # selected_edits will contain, for each element of selected_line_indexes, the
    # corresponding edit_type from the original utterance previous to
    # this function call ('cor', 'ins', etc.).
    #
    # As a special case, if there was a substitution ('sub') where the
    # reference word was a non-scored word and the hyp word was a real word,
    # we mark it in this array as 'ins', because for purposes of this algorithm
    # it behaves the same as an insertion.
    #
    # Whenever we do any operation that will change the reference, we change
    # all the selected_edits in the array to None so that they won't match
    # any further operations.
    selected_edits = []
    # selected_hyp_words will contain, for each element of selected_line_indexes, the
    # corresponding hyp_word.
    selected_hyp_words = []

    for i in range(len(split_lines_of_utt)):
        split_line = split_lines_of_utt[i]
        hyp_word = split_line[4]
        ref_word = split_line[6]
        # keep_this_line will be True if we are going to keep this line in the
        # 'selected lines' for further processing of repetitions.  We only
        # eliminate lines involving non-scored words or epsilon in both hyp
        # and reference position
        # [note: epsilon in hyp position for non-empty segments indicates
        #  optional-silence, and it does make sense to make this 'invisible',
        #  just like non-scored words, for the purposes of this code.]
        keep_this_line = True
        if (hyp_word == '<eps>' or hyp_word in non_scored_words) and \
           (ref_word == '<eps>' or ref_word in non_scored_words):
            keep_this_line = False
        if keep_this_line:
            selected_line_indexes.append(i)
            edit_type = split_line[7]
            if edit_type == 'sub' and ref_word in non_scored_words:
                assert not hyp_word in non_scored_words
                # For purposes of this algorithm, substitution of, say,
                # '[COUGH]' by 'hello' behaves like an insertion of 'hello',
                # since we're willing to remove the '[COUGH]' from the
                # transcript.
                edit_type = 'ins'
            selected_edits.append(edit_type)
            selected_hyp_words.append(hyp_word)

    # indexes_to_fix will be a list of indexes into 'selected_line_indexes'
    # where we plan to fix the ref to match the hyp.
    indexes_to_fix = []

    # This loop scans for, and fixes, two-word insertions that follow,
    # or precede, the corresponding correct words.
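    # As a (hypothetical) worked example: if four consecutive selected hyp
    # words are [ 'but', 'then', 'but', 'then' ] with edits
    # [ 'cor', 'cor', 'ins', 'ins' ], the last two of those lines will get
    # their reference word set to the hyp word ('but', 'then') and be marked
    # 'cor', treating them as an untranscribed repetition of 'but then'.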
    for i in range(0, len(selected_line_indexes) - 3):
        this_indexes = selected_line_indexes[i:i+4]
        this_hyp_words = selected_hyp_words[i:i+4]

        if this_hyp_words[0] == this_hyp_words[2] and \
           this_hyp_words[1] == this_hyp_words[3] and \
           this_hyp_words[0] != this_hyp_words[1]:
            # if the hyp words were of the form [ 'a', 'b', 'a', 'b' ]...
            this_edits = selected_edits[i:i+4]
            if this_edits == [ 'cor', 'cor', 'ins', 'ins' ] or \
                    this_edits == [ 'ins', 'ins', 'cor', 'cor' ]:
                if this_edits[0] == 'cor':
                    indexes_to_fix += [ i+2, i+3 ]
                else:
                    indexes_to_fix += [ i, i+1 ]

                # the next line prevents this region of the text being used
                # in any further edits.
                selected_edits[i:i+4] = [ None, None, None, None ]
                word_pair = this_hyp_words[0] + ' ' + this_hyp_words[1]
                # e.g. word_pair = 'hi there'
                # add 2 because these stats are counted in words.
                repetition_stats[word_pair] += 2

    # This loop scans for, and fixes, one-word insertions that follow,
    # or precede, the corresponding correct words.
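    # E.g. (hypothetically), hyp words [ 'i', 'i' ] with edits [ 'cor', 'ins' ]
    # (or [ 'ins', 'cor' ]): the inserted copy gets its reference word set to
    # 'i' and is marked 'cor'.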
    for i in range(0, len(selected_line_indexes) - 1):
        this_indexes = selected_line_indexes[i:i+2]
        this_hyp_words = selected_hyp_words[i:i+2]

        if this_hyp_words[0] == this_hyp_words[1]:
            # if the hyp words were of the form [ 'a', 'a' ]...
            this_edits = selected_edits[i:i+2]
            if this_edits == [ 'cor', 'ins' ] or this_edits == [ 'ins', 'cor' ]:
                if this_edits[0] == 'cor':
                    indexes_to_fix.append(i+1)
                else:
                    indexes_to_fix.append(i)
                repetition_stats[this_hyp_words[0]] += 1
                # the next line prevents this region of the text being used
                # in any further edits.
                selected_edits[i:i+2] = [ None, None ]

    for i in indexes_to_fix:
        j = selected_line_indexes[i]
        split_line = split_lines_of_utt[j]
        ref_word = split_line[6]
        hyp_word = split_line[4]
        assert ref_word == '<eps>' or ref_word in non_scored_words
        # we replace reference with the decoded word, which will be a
        # repetition.
        split_line[6] = hyp_word
        split_line[7] = 'cor'

    return split_lines_of_utt


# note: split_lines_of_utt is a list of lists, one per line, each containing the
# sequence of fields.
# Returns the same format of data after processing.
def ProcessUtterance(split_lines_of_utt):
    new_split_lines_of_utt = []
    for split_line in split_lines_of_utt:
        new_split_line = ProcessLineForNonScoredWords(split_line)
        if new_split_line != []:
            new_split_lines_of_utt.append(new_split_line)
    if args.allow_repetitions == 'true':
        new_split_lines_of_utt = ProcessUtteranceForRepetitions(new_split_lines_of_utt)
    return new_split_lines_of_utt


def ProcessData():
    try:
        f_in = open(args.ctm_edits_in, encoding='utf-8')
    except IOError:
        sys.exit("modify_ctm_edits.py: error opening ctm-edits input "
                 "file {0}".format(args.ctm_edits_in))
    try:
        f_out = open(args.ctm_edits_out, 'w', encoding='utf-8')
    except IOError:
        sys.exit("modify_ctm_edits.py: error opening ctm-edits output "
                 "file {0}".format(args.ctm_edits_out))
    num_lines_processed = 0


    # Most of what we're doing in the lines below is splitting the input lines
    # and grouping them per utterance, before giving them to ProcessUtterance()
    # and then printing the modified lines.
    first_line = f_in.readline()
    if first_line == '':
        sys.exit("modify_ctm_edits.py: empty input")
    split_pending_line = first_line.split()
    if len(split_pending_line) == 0:
        sys.exit("modify_ctm_edits.py: bad input line " + first_line)
    cur_utterance = split_pending_line[0]
    split_lines_of_cur_utterance = []

    while True:
        if len(split_pending_line) == 0 or split_pending_line[0] != cur_utterance:
            split_lines_of_cur_utterance = ProcessUtterance(split_lines_of_cur_utterance)
            for split_line in split_lines_of_cur_utterance:
                print(' '.join(split_line), file = f_out)
            split_lines_of_cur_utterance = []
            if len(split_pending_line) == 0:
                break
            else:
                cur_utterance = split_pending_line[0]

        split_lines_of_cur_utterance.append(split_pending_line)
        next_line = f_in.readline()
        split_pending_line = next_line.split()
        if len(split_pending_line) == 0:
            if next_line != '':
                sys.exit("modify_ctm_edits.py: got an empty or whitespace input line")
    try:
        f_out.close()
    except IOError:
        sys.exit("modify_ctm_edits.py: error closing ctm-edits output "
                 "(broken pipe or full disk?)")

def PrintNonScoredStats():
    if args.verbose < 1:
        return
    if num_lines == 0:
        print("modify_ctm_edits.py: processed no input.", file = sys.stderr)
        return
    num_lines_modified = sum(ref_change_stats.values())
    num_incorrect_lines = num_lines - num_correct_lines
    percent_lines_incorrect = '%.2f' % (num_incorrect_lines * 100.0 / num_lines)
    percent_modified = '%.2f' % (num_lines_modified * 100.0 / num_lines)
    if num_incorrect_lines > 0:
        percent_of_incorrect_modified = '%.2f' % (num_lines_modified * 100.0 /
                                                  num_incorrect_lines)
    else:
        percent_of_incorrect_modified = float('nan')
    print("modify_ctm_edits.py: processed {0} lines of ctm ({1}% of which incorrect), "
          "of which {2} were changed fixing the reference for non-scored words "
          "({3}% of lines, or {4}% of incorrect lines)".format(
            num_lines, percent_lines_incorrect, num_lines_modified,
            percent_modified, percent_of_incorrect_modified),
          file = sys.stderr)

    keys = sorted(ref_change_stats.keys(), reverse=True,
                  key = lambda x: ref_change_stats[x])
    num_keys_to_print = 40 if args.verbose >= 2 else 10

    print("modify_ctm_edits.py: most common edits (as percentages "
          "of all such edits) are:\n" +
          ('\n'.join([ '%s [%.2f%%]' % (k, ref_change_stats[k]*100.0/num_lines_modified)
                     for k in keys[0:num_keys_to_print]]))
          + '\n...'if num_keys_to_print < len(keys) else '',
          file = sys.stderr)


def PrintRepetitionStats():
    if args.verbose < 1 or sum(repetition_stats.values()) == 0:
        return
    num_lines_modified = sum(repetition_stats.values())
    num_incorrect_lines = num_lines - num_correct_lines
    percent_lines_incorrect = '%.2f' % (num_incorrect_lines * 100.0 / num_lines)
    percent_modified = '%.2f' % (num_lines_modified * 100.0 / num_lines)
    if num_incorrect_lines > 0:
        percent_of_incorrect_modified = '%.2f' % (num_lines_modified * 100.0 /
                                                  num_incorrect_lines)
    else:
        percent_of_incorrect_modified = float('nan')
    print("modify_ctm_edits.py: processed {0} lines of ctm ({1}% of which incorrect), "
          "of which {2} were changed fixing the reference for repetitions ({3}% of "
          "lines, or {4}% of incorrect lines)".format(
            num_lines, percent_lines_incorrect, num_lines_modified,
            percent_modified, percent_of_incorrect_modified),
          file = sys.stderr)

    keys = sorted(repetition_stats.keys(), reverse=True,
                  key = lambda x: repetition_stats[x])
    num_keys_to_print = 40 if args.verbose >= 2 else 10

    print("modify_ctm_edits.py: most common repetitions inserted into reference (as percentages "
          "of all words fixed in this way) are:\n" +
          ('\n'.join([ '%s [%.2f%%]' % (k, repetition_stats[k]*100.0/num_lines_modified)
                     for k in keys[0:num_keys_to_print]]))
          + '\n...' if num_keys_to_print < len(keys) else '',
          file = sys.stderr)


non_scored_words = set()
ReadNonScoredWords(args.non_scored_words_in)

num_lines = 0
num_correct_lines = 0
# ref_change_stats will be a map from a string like
# 'foo -> bar' to an integer count; it keeps track of how much we changed
# the reference.
ref_change_stats = defaultdict(int)
# repetition_stats will be a map from strings like
# 'a', or 'a b' (the repeated strings), to an integer count; like
# ref_change_stats, it keeps track of how many changes we made
# in allowing repetitions.
repetition_stats = defaultdict(int)

ProcessData()
PrintNonScoredStats()
PrintRepetitionStats()