Blame view
egs/wsj/s5/steps/cleanup/internal/get_pron_stats.py
11.2 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
#!/usr/bin/env python

# Copyright 2016  Xiaohui Zhang
# Apache 2.0.

from __future__ import print_function
from __future__ import division
import argparse
import sys
import warnings

# Collect pronunciation stats from a ctm_prons.txt file of the form output
# by steps/cleanup/debug_lexicon.sh.  This input file has lines of the form:
# utt_id word phone1 phone2 .. phoneN
# e.g.
# foo-bar123-342 hello h eh l l ow
# (and this script does require that lines from the same utterance be ordered in
# order of time).
# The output of this program is word pronunciation stats of the form:
# count word phone1 .. phoneN
# e.g.:
# 24.0 hello h ax l l ow
# This program uses various heuristics to account for the fact that the input
# ctm_prons.txt file may not always be well aligned.  As a result of some of
# these heuristics the counts will not always be integers.


def GetArgs():
    """Parse the command line and return the argparse namespace.

    The namespace is passed through CheckArgs(), so it also carries the
    opened *_handle attributes for every input/output file.
    """
    parser = argparse.ArgumentParser(
        description="Accumulate pronunciation statistics from "
        "a ctm_prons.txt file.",
        epilog="See steps/cleanup/debug_lexicon.sh for example")

    parser.add_argument(
        "ctm_prons_file", metavar="<ctm-prons-file>", type=str,
        help="File containing word-pronunciation alignments obtained from a ctm file; "
        "it represents phonetic decoding results, aligned with word boundaries "
        "obtained from forced alignments; "
        "each line must be <utt_id> <word> <phones>")
    parser.add_argument(
        "silence_file", metavar="<silphone-file>", type=str,
        help="File containing a list of silence phones.")
    parser.add_argument(
        "optional_silence_file", metavar="<optional_silence>", type=str,
        help="File containing the optional silence phone. We'll be replacing "
        "empty prons by this, "
        "because empty prons would cause a problem for lattice word alignment.")
    parser.add_argument(
        "non_scored_words_file", metavar="<non-scored-words-file>", type=str,
        help="File containing a list of non-scored words.")
    parser.add_argument(
        "stats_file", metavar="<stats-file>", type=str,
        help="Write accumulated statistics to this file; each line represents "
        "how many times a specific word-pronunciation pair appears in the "
        "phonetic decoding results (ctm_prons_file); "
        "each line is <count> <word> <phones>")

    # Echo the full command line to stderr before parsing.
    print(' '.join(sys.argv), file=sys.stderr)

    args = parser.parse_args()
    args = CheckArgs(args)
    return args


def CheckArgs(args):
    """Open every file named in args and attach the handles to args.

    "-" selects stdin for ctm_prons_file and stdout for stats_file; the
    other files are always opened by name.  Returns the same namespace.
    """
    if args.ctm_prons_file == "-":
        args.ctm_prons_file_handle = sys.stdin
    else:
        args.ctm_prons_file_handle = open(args.ctm_prons_file)

    args.non_scored_words_file_handle = open(args.non_scored_words_file)
    args.silence_file_handle = open(args.silence_file)
    args.optional_silence_file_handle = open(args.optional_silence_file)

    if args.stats_file == "-":
        args.stats_file_handle = sys.stdout
    else:
        args.stats_file_handle = open(args.stats_file, "w")
    return args


def ReadEntries(file_handle):
    """Return the set of stripped, non-empty lines read from file_handle.

    Blank lines are skipped: an empty string in e.g. the silence-phone set
    would match every pronunciation in suffix checks, and an empty optional
    silence phone would yield empty prons.
    """
    entries = set()
    for line in file_handle:
        entry = line.strip()
        if entry:
            entries.add(entry)
    return entries

# Basically, this function generates an "info" list from a ctm_prons file.
# Each entry in the list represents the pronunciation candidate(s) of a word.
# For each non-<eps> word, the entry is a list: [utt_id, word, set(pronunciation_candidates)]. e.g:
# [911Mothers_2010W-0010916-0012901-1, other, set('AH DH ER', 'AH DH ER K AH N')]
# For each <eps>, we split the phones it aligns to into two parts: "nonsil_left",
# which includes phones before the first silphone, and "nonsil_right", which includes
# phones after the last silphone. For example, for <eps> : 'V SIL B AH SIL',
# nonsil_left is 'V' and nonsil_right is empty ''. After processing an <eps> entry
After processing an <eps> entry # in ctm_prons, we put it in "info" as an entry: [utt_id, word, nonsil_right] # only if it's nonsil_right segment is not empty, which may be used when processing # the next word. # # Normally, one non-<eps> word is only aligned to one pronounciation candidate. However # when there is a preceding/following <eps>, like in the following example, we # assume the phones aligned to <eps> should be statistically distributed # to its neighboring words (BTW we assume there are no consecutive <eps> within an utterance.) # Thus we append the "nonsil_left" segment of these phones to the pronounciation # of the preceding word, if the last phone of this pronounciation is not a silence phone, # Similarly we can add a pron candidate to the following word. # # For example, for the following part of a ctm_prons file: # 911Mothers_2010W-0010916-0012901-1 other AH DH ER # 911Mothers_2010W-0010916-0012901-1 <eps> K AH N SIL B # 911Mothers_2010W-0010916-0012901-1 because IH K HH W AA Z AH # 911Mothers_2010W-0010916-0012901-1 <eps> V SIL # 911Mothers_2010W-0010916-0012901-1 when W EH N # 911Mothers_2010W-0010916-0012901-1 people P IY P AH L # 911Mothers_2010W-0010916-0012901-1 <eps> SIL # 911Mothers_2010W-0010916-0012901-1 heard HH ER # 911Mothers_2010W-0010916-0012901-1 <eps> D # 911Mothers_2010W-0010916-0012901-1 that SIL DH AH T # 911Mothers_2010W-0010916-0012901-1 my M AY # # The corresponding segment in the "info" list is: # [911Mothers_2010W-0010916-0012901-1, other, set('AH DH ER', 'AH DH ER K AH N')] # [911Mothers_2010W-0010916-0012901-1, <eps>, 'B' # [911Mothers_2010W-0010916-0012901-1, because, set('IH K HH W AA Z AH', 'B IH K HH W AA Z AH', 'IH K HH W AA Z AH V', 'B IH K HH W AA Z AH V')] # [911Mothers_2010W-0010916-0012901-1, when, set('W EH N')] # [911Mothers_2010W-0010916-0012901-1, people, set('P IY P AH L')] # [911Mothers_2010W-0010916-0012901-1, <eps>, 'D'] # [911Mothers_2010W-0010916-0012901-1, that, set('SIL DH AH T')] # 
[911Mothers_2010W-0010916-0012901-1, my, set('M AY')] # # Then we accumulate pronouciation stats from "info". Basically, for each occurence # of a word, each pronounciation candidate gets equal soft counts. e.g. In the above # example, each pron candidate of "because" gets a count of 1/4. The stats is stored # in a dictionary (word, pron) : count. def GetStatsFromCtmProns(silphones, optional_silence, non_scored_words, ctm_prons_file_handle): info = [] for line in ctm_prons_file_handle.readlines(): splits = line.strip().split() utt = splits[0] word = splits[1] phones = splits[2:] if phones == []: phones = [optional_silence] # extract the nonsil_left and nonsil_right segments, and then try to # append nonsil_left to the pron candidates of preceding word, getting # extended pron candidates. # Note: the ctm_pron file may have cases like: # KevinStone_2010U-0024782-0025580-1 [UH] EH # KevinStone_2010U-0024782-0025580-1 fda F T # KevinStone_2010U-0024782-0025580-1 [NOISE] IY EY # which means non-scored-words (except oov symbol <unk>/<UNK>) behaves like <eps>. # So we apply the same merging method in these cases. if word == '<eps>' or (word in non_scored_words and word != '<unk>' and word != '<UNK>'): nonsil_left = [] nonsil_right = [] for phone in phones: if phone in silphones: break nonsil_left.append(phone) for phone in reversed(phones): if phone in silphones: break nonsil_right.insert(0, phone) # info[-1][0] is the utt_id of the last entry if len(nonsil_left) > 0 and len(info) > 0 and utt == info[-1][0]: # pron_ext is a set of extended pron candidates. pron_ext = set() # info[-1][2] is the set of pron candidates of the last entry. for pron in info[-1][2]: # skip generating the extended pron candidate if # the pron ends with a silphone. 
ends_with_sil = False for sil in silphones: if pron.endswith(sil): ends_with_sil = True if not ends_with_sil: pron_ext.add(pron+" "+" ".join(nonsil_left)) if isinstance(info[-1][2], set): info[-1][2] = info[-1][2].union(pron_ext) if len(nonsil_right) > 0: info.append([utt, word, " ".join(nonsil_right)]) else: prons = set() prons.add(" ".join(phones)) # If there's a preceding <eps>/non_scored_words (which means the third field is a string rather than a set of strings), # we append it's nonsil_right segment to the pron candidates of the current word. if len(info) > 0 and utt == info[-1][0] and isinstance(info[-1][2], str) and (phones == [] or phones[0] not in silphones): # info[-1][2] is the nonsil_right segment of the phones aligned to the last <eps>/non_scored_words. prons.add(info[-1][2]+' '+" ".join(phones)) info.append([utt, word, prons]) stats = {} for utt, word, prons in info: # If the prons is not a set, the current word must be <eps> or an non_scored_word, # where we just left the nonsil_right part as prons. if isinstance(prons, set) and len(prons) > 0: count = 1.0 / float(len(prons)) for pron in prons: phones = pron.strip().split() # post-processing: remove all begining/trailing silence phones. # we allow only candidates that either consist of a single silence # phone, or the silence phones are inside non-silence phones. 
if len(phones) > 1: begin = 0 for phone in phones: if phone in silphones: begin += 1 else: break if begin == len(phones): begin -= 1 phones = phones[begin:] if len(phones) == 1: break end = len(phones) for phone in reversed(phones): if phone in silphones: end -= 1 else: break phones = phones[:end] phones = " ".join(phones) stats[(word, phones)] = stats.get((word, phones), 0) + count return stats def WriteStats(stats, file_handle): for word_pron, count in stats.items(): print('{0} {1} {2}'.format(count, word_pron[0], word_pron[1]), file=file_handle) file_handle.close() def Main(): args = GetArgs() silphones = ReadEntries(args.silence_file_handle) non_scored_words = ReadEntries(args.non_scored_words_file_handle) optional_silence = ReadEntries(args.optional_silence_file_handle) stats = GetStatsFromCtmProns(silphones, optional_silence.pop(), non_scored_words, args.ctm_prons_file_handle) WriteStats(stats, args.stats_file_handle) if __name__ == "__main__": Main() |