egs/wsj/s5/utils/lang/internal/arpa2fst_constrained.py
#!/usr/bin/env python

# Copyright 2016  Johns Hopkins University (Author: Daniel Povey)
# Apache 2.0.

from __future__ import print_function
from __future__ import division
import sys
import argparse
import math
from collections import defaultdict

# note, this was originally based

parser = argparse.ArgumentParser(description="""
This script converts an ARPA-format language model to FST format (like the
C++ program arpa2fst), but does so while applying bigram constraints supplied
in a separate file.  The resulting language model will have no unigram state,
and there will be no backoff from the bigram level.  This is useful for
phone-level language models, in order to keep graphs small and impose things
like linguistic constraints on allowable phone sequences.  This script writes
its output to stdout.  It is a text-form FST, suitable for compilation by
fstcompile.""")

parser.add_argument('--disambig-symbol', type=str, default="#0",
                    help='Disambiguation symbol (e.g. #0), that is printed on '
                    'the input side only of backoff arcs (output side would '
                    'be epsilon)')
parser.add_argument('arpa_in', type=str,
                    help='The input ARPA file (must not be gzipped)')
parser.add_argument('allowed_bigrams_in', type=str,
                    help="A file containing the list of allowed bigram pairs. "
                    "Must include pairs like '<s> foo' and 'foo </s>', as well "
                    "as pairs like 'foo bar'.")
parser.add_argument('--verbose', type=int, default=0,
                    choices=[0, 1, 2, 3, 4, 5],
                    help='Verbose level')

args = parser.parse_args()

if args.verbose >= 1:
    print(' '.join(sys.argv), file=sys.stderr)
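
# For illustration only -- a hypothetical invocation (the file names below
# are made up; only the script path and options are real):
#
#   utils/lang/internal/arpa2fst_constrained.py --disambig-symbol='#0' \
#       phone_lm.arpa allowed_bigrams.txt | \
#     fstcompile --isymbols=phones.txt --osymbols=phones.txt > G.fst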

class HistoryState(object):
    def __init__(self):
        # note: neither backoff_prob nor the floats in word_to_prob are in
        # log space.
        self.backoff_prob = 1.0
        # will be a dict from string to float.  the prob is the actual
        # probability of the word, including any probability mass from
        # backoff (they get added together while writing out the arpa, and
        # these probs are read in from the arpa).
        self.word_to_prob = dict()


class ArpaModel(object):
    def __init__(self):
        # self.orders is indexed by history-length [i.e. 0 for unigram,
        # 1 for bigram and so on], and is then a dict indexed by tuples
        # of history-words.  E.g. for trigrams, we'd index it as
        # self.orders[2][('a', 'b')].
        # The value-type of the dict is HistoryState.  E.g. to set the
        # probability of the trigram a b -> c to 0.2, we'd do
        # self.orders[2][('a', 'b')].word_to_prob['c'] = 0.2
        self.orders = []

    def Read(self, arpa_in):
        assert len(self.orders) == 0
        log10 = math.log(10.0)
        if arpa_in == "" or arpa_in == "-":
            arpa_in = "/dev/stdin"
        try:
            f = open(arpa_in, "r")
        except:
            sys.exit("{0}: error opening ARPA file {1}".format(
                sys.argv[0], arpa_in))
        # first read till the \data\ marker.
        while True:
            line = f.readline()
            if line == '':
                sys.exit("{0}: reading {1}, got EOF looking for \\data\\ marker.".format(
                    sys.argv[0], arpa_in))
            if line[0:6] == '\\data\\':
                break
        while True:
            # read, and ignore, the lines like 'ngram 1=1264'...
            line = f.readline()
            if line == '\n' or line == '\r\n':
                break
            if line[0:5] != 'ngram':
                sys.exit("{0}: reading {1}, read something unexpected in header: {2}".format(
                    sys.argv[0], arpa_in, line[:-1]))
            rest = line[5:]
            a = rest.split('=')  # e.g. a = ['1', '1264']
            if len(a) != 2:
                sys.exit("{0}: reading {1}, read something unexpected in header: {2}".format(
                    sys.argv[0], arpa_in, line[:-1]))
            max_order = int(a[0])
        for n in range(max_order):
            # self.orders[n], indexed by history-length (length of the
            # history-vector, == order-1), is a map from history as a tuple
            # of strings, to class HistoryState.
            self.orders.append(defaultdict(lambda: HistoryState()))
        cur_order = 0
        while True:
            line = f.readline()
            if line == '':
                sys.exit("{0}: reading {1}, found EOF while looking for \\end\\ marker.".format(
                    sys.argv[0], arpa_in))
            elif line[0:5] == '\\end\\':
                if len(self.orders) == 0:
                    sys.exit("{0}: reading {1}, read no n-grams.".format(
                        sys.argv[0], arpa_in))
                break
            else:
                cur_order += 1
                expected_line = '\\{0}-grams:'.format(cur_order)
                if expected_line not in line:  # e.g. allow trailing whitespace and newline
                    sys.exit("{0}: reading {1}, expected line {2}, got {3}".format(
                        sys.argv[0], arpa_in, expected_line, line[:-1]))
                if args.verbose >= 2:
                    print("{0}: reading {1}-grams".format(
                        sys.argv[0], cur_order), file=sys.stderr)
                # now read all the n-grams from this order.
                while True:
                    line = f.readline()
                    # the section of n-grams is terminated by a blank line.
                    if line == '\n' or line == '\r\n':
                        break
                    a = line.split()
                    l = len(a)
                    if l != cur_order + 1 and l != cur_order + 2:
                        sys.exit("{0}: reading {1}: in {2}-grams section, got bad line: {3}".format(
                            sys.argv[0], arpa_in, cur_order, line[:-1]))
                    try:
                        prob = math.exp(float(a[0]) * log10)
                        hist = tuple(a[1:cur_order])  # tuple of strings
                        word = a[cur_order]  # a string
                        backoff_prob = (math.exp(float(a[cur_order + 1]) * log10)
                                        if l == cur_order + 2 else None)
                    except Exception as e:
                        sys.exit("{0}: reading {1}: in {2}-grams section, got bad "
                                 "line (exception is: {3}): {4}".format(
                                     sys.argv[0], arpa_in, cur_order,
                                     str(type(e)) + ',' + str(e), line[:-1]))
                    self.orders[cur_order - 1][hist].word_to_prob[word] = prob
                    if backoff_prob is not None:
                        self.orders[cur_order][hist + (word,)].backoff_prob = backoff_prob
        if args.verbose >= 2:
            print("{0}: read {1}-gram model from {2}".format(
                sys.argv[0], cur_order, arpa_in), file=sys.stderr)
        if cur_order < 2:
            # we'd have to have some if-statements in the code to make this
            # work, and I don't want to have to test it.
            sys.exit("{0}: this script does not work when the ARPA language "
                     "model is unigram.".format(sys.argv[0]))
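
    # For reference, a sketch of the ARPA input that Read() expects; the
    # n-grams and values below are made up for illustration.  The first field
    # is the log10 probability and the optional last field is the log10
    # backoff weight:
    #
    #   \data\
    #   ngram 1=3
    #   ngram 2=2
    #
    #   \1-grams:
    #   -0.5  a  -0.3
    #   ...
    #   \2-grams:
    #   -0.2  a b
    #   ...
    #   \end\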

    # Returns the probability of word 'word' in history-state 'hist'.
    # Dies with error if this word is not predicted at all by the LM (not in
    # vocab); backs off to shorter histories if the history-state does not
    # exist.
    def GetProb(self, hist, word):
        assert len(hist) < len(self.orders)
        if len(hist) == 0:
            word_to_prob = self.orders[0][()].word_to_prob
            if word not in word_to_prob:
                sys.exit("{0}: no probability in unigram for word {1}".format(
                    sys.argv[0], word))
            return word_to_prob[word]
        else:
            if hist in self.orders[len(hist)]:
                hist_state = self.orders[len(hist)][hist]
                if word in hist_state.word_to_prob:
                    return hist_state.word_to_prob[word]
                else:
                    return hist_state.backoff_prob * self.GetProb(hist[1:], word)
            else:
                return self.GetProb(hist[1:], word)
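
    # A worked example of the backoff recursion above (values made up): if
    # ('a', 'b') -> 'c' is not an explicit trigram, then
    #   GetProb(('a', 'b'), 'c') == backoff_prob('a', 'b') * GetProb(('b',), 'c'),
    # recursing down to the unigram probability of 'c' if necessary.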

    # This gets the state corresponding to 'hist' in 'hist_to_state', but
    # backs off for us if there is no such state.
    def GetStateForHist(self, hist_to_state, hist):
        if hist in hist_to_state:
            return hist_to_state[hist]
        else:
            if len(hist) <= 1:
                # this would likely be a code error, but possibly an error in
                # the ARPA file
                sys.exit("{0}: error processing histories: history-state {1} "
                         "does not exist.".format(sys.argv[0], hist))
            return self.GetStateForHist(hist_to_state, hist[1:])

    def GetHistToStateMap(self):
        # This function, called from PrintAsFst, returns (hist_to_state,
        # state_to_hist), which map from history (as a tuple of strings) to
        # integer FST-state and vice versa.
        hist_to_state = dict()
        state_to_hist = []
        # Make sure the initial bigram state comes first (and that we have
        # such a state even if it was completely pruned away in the bigram
        # LM... which is unlikely of course).
        hist = ('<s>',)
        hist_to_state[hist] = len(state_to_hist)
        state_to_hist.append(hist)
        # create a bigram state for each of the 'real' words... even if the
        # LM didn't naturally have such bigram states, we'll create them so
        # that we can enforce the bigram constraints supplied in
        # 'bigrams_file' by the user.
        for word in self.orders[0][()].word_to_prob:
            if word != '<s>' and word != '</s>':
                hist = (word,)
                hist_to_state[hist] = len(state_to_hist)
                state_to_hist.append(hist)
        # note: we do not allocate an FST state for the unigram state, because
        # we don't have a unigram state in the output FST, only bigram states;
        # and we don't iterate over bigram histories because we covered them
        # all above; that's why we start 'n' from 2 below instead of from 0.
        for n in range(2, len(self.orders)):
            for hist in self.orders[n].keys():
                # note: hist is a tuple of strings.
                assert hist not in hist_to_state
                hist_to_state[hist] = len(state_to_hist)
                state_to_hist.append(hist)
        return (hist_to_state, state_to_hist)

    # This function prints the estimated language model as an FST.
    # disambig_symbol will be something like '#0' (a symbol introduced to
    # make the result determinizable).
    # bigram_map represents the allowed bigrams (left-word, right-word): it's
    # a map from left-word to a set of right-words (both are strings).
    def PrintAsFst(self, disambig_symbol, bigram_map):
        # hist_to_state maps from history (as a tuple) to integer FST-state.
        (hist_to_state, state_to_hist) = self.GetHistToStateMap()

        # The following 3 things are just for diagnostics.
        normalization_stats = [[0, 0.0] for x in range(len(self.orders))]
        num_ngrams_allowed = 0
        num_ngrams_disallowed = 0

        for state in range(len(state_to_hist)):
            hist = state_to_hist[state]
            hist_len = len(hist)
            assert hist_len > 0
            if hist_len == 1:
                # it's a bigram state...
                context_word = hist[0]
                if context_word not in bigram_map:
                    print("{0}: warning: word {1} appears in ARPA but is not "
                          "listed as a left context in the bigram map".format(
                              sys.argv[0], context_word), file=sys.stderr)
                    continue
                # word_list is a list of words that can follow this word.
                # It must be nonempty.
                word_list = list(bigram_map[context_word])
                normalization_stats[hist_len][0] += 1
                for word in word_list:
                    prob = self.GetProb((context_word,), word)
                    assert prob != 0
                    normalization_stats[hist_len][1] += prob
                    cost = -math.log(prob)
                    if abs(cost) < 0.01 and args.verbose >= 3:
                        print("{0}: warning: very small cost {1} for {2}->{3}".format(
                            sys.argv[0], cost, context_word, word),
                            file=sys.stderr)
                    if word == '</s>':
                        # print the final-prob of this state.
                        print("%d %.3f" % (state, cost))
                    else:
                        next_state = self.GetStateForHist(hist_to_state,
                                                          (context_word, word))
                        print("%d %d %s %s %.3f" % (state, next_state,
                                                    word, word, cost))
            else:
                # it's a higher-order than bigram state.
                assert hist in self.orders[hist_len]
                hist_state = self.orders[hist_len][hist]
                most_recent_word = hist[-1]
                normalization_stats[hist_len][0] += 1
                normalization_stats[hist_len][1] += \
                    sum([self.GetProb(hist, word)
                         for word in bigram_map[most_recent_word]])
                for word, prob in hist_state.word_to_prob.items():
                    cost = -math.log(prob)
                    if word in bigram_map[most_recent_word]:
                        num_ngrams_allowed += 1
                    else:
                        num_ngrams_disallowed += 1
                        continue
                    if word == '</s>':
                        # print the final-prob of this state.
                        print("%d %.3f" % (state, cost))
                    else:
                        next_state = self.GetStateForHist(hist_to_state,
                                                          hist + (word,))
                        print("%d %d %s %s %.3f" % (state, next_state,
                                                    word, word, cost))
                # Now deal with the backoff probability of this state (back
                # off to the lower-order state).
                assert hist in self.orders[hist_len]
                backoff_prob = self.orders[hist_len][hist].backoff_prob
                assert backoff_prob != 0.0
                cost = -math.log(backoff_prob)
                backoff_hist = hist[1:]
                backoff_state = self.GetStateForHist(hist_to_state, backoff_hist)
                # note: we only print the disambig symbol on the input side.
                if args.verbose >= 3 and abs(cost) < 0.001:
                    print("{0}: very low backoff cost {1} for history {2}, state = {3}".format(
                        sys.argv[0], cost, str(hist), state), file=sys.stderr)
                # For hist-states that completely back off (they have no words
                # coming out of them), there is no need to disambiguate, so we
                # can print an epsilon that will later be removed.
                this_disambig_symbol = (disambig_symbol
                                        if len(hist_state.word_to_prob) != 0
                                        else '<eps>')
                print("%d %d %s <eps> %.3f" % (state, backoff_state,
                                               this_disambig_symbol, cost))

        if args.verbose >= 1:
            for hist_len in range(1, len(self.orders)):
                num_states = normalization_stats[hist_len][0]
                avg_prob_sum = (normalization_stats[hist_len][1] / num_states
                                if num_states > 0 else 0.0)
                print("{0}: for {1}-gram states, over {2} states the average "
                      "sum of probs was {3} (would be 1.0 if properly "
                      "normalized).".format(sys.argv[0], hist_len + 1,
                                            num_states, avg_prob_sum),
                      file=sys.stderr)
            if num_ngrams_disallowed != 0:
                print("{0}: for explicit n-grams higher than bigram from the "
                      "ARPA model, {1} were allowed by the bigram constraints "
                      "and {2} were disallowed (we normally expect all or "
                      "almost all of them to be allowed).".format(
                          sys.argv[0], num_ngrams_allowed,
                          num_ngrams_disallowed), file=sys.stderr)
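
# For reference, the text-form FST that PrintAsFst() writes (the states and
# symbols below are made up for illustration): arcs are printed as
# "src dest input-symbol output-symbol cost" and final states as
# "state cost", which is the format fstcompile expects, e.g.:
#
#   0 5 ax ax 2.303
#   5 7 #0 <eps> 0.105
#   7 1.609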

# returns a map which is a dict [indexed by left-hand word] of sets
# [containing the right-hand words].
def ReadBigramMap(bigrams_file):
    ans = defaultdict(lambda: set())
    have_one_bos = False
    have_one_eos = False
    have_one_regular = False
    try:
        f = open(bigrams_file, "r")
    except:
        sys.exit("utils/lang/internal/arpa2fst_constrained.py: error opening "
                 "bigrams file " + bigrams_file)
    while True:
        line = f.readline()
        if line == '':
            break
        a = line.split()
        if len(a) != 2:
            sys.exit("utils/lang/internal/arpa2fst_constrained.py: bad line "
                     "in bigrams file {0} (expect 2 fields): {1}".format(
                         bigrams_file, line[:-1]))
        [word1, word2] = a
        if word1 in ans and word2 in ans[word1]:
            sys.exit("{0}: bigrams file contained duplicate entry: {1} {2}".format(
                sys.argv[0], word1, word2))
        if word2 == '<s>' or word1 == '</s>':
            sys.exit("{0}: bad sequence of BOS/EOS symbols: {1} {2}".format(
                sys.argv[0], word1, word2))
        if word1 == '<s>':
            have_one_bos = True
        elif word2 == '</s>':
            have_one_eos = True
        else:
            have_one_regular = True
        ans[word1].add(word2)
    # check that we saw at least one pair with BOS, one with EOS, and one
    # with neither.
    if len(ans) == 0:
        sys.exit("{0}: no data found in bigrams file {1}".format(
            sys.argv[0], bigrams_file))
    elif not (have_one_bos and have_one_eos and have_one_regular):
        sys.exit("{0}: the bigrams file {1} does not look right "
                 "(make sure BOS and EOS symbols are there)".format(
                     sys.argv[0], bigrams_file))
    return ans


arpa_model = ArpaModel()
arpa_model.Read(args.arpa_in)
bigrams_map = ReadBigramMap(args.allowed_bigrams_in)

if len(args.disambig_symbol.split()) != 1:
    sys.exit("{0}: invalid option --disambig-symbol={1}".format(
        sys.argv[0], args.disambig_symbol))

arpa_model.PrintAsFst(args.disambig_symbol, bigrams_map)
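
# For illustration only -- a tiny allowed-bigrams file might look like this
# (the phone names are made up); each line is a left-word/right-word pair,
# and the file must include <s> and </s> pairs as well as regular pairs:
#
#   <s> sil
#   sil ax
#   ax sil
#   ax </s>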