egs/wsj/s5/utils/lang/make_kn_lm.py
#!/usr/bin/env python3

# Copyright 2016  Johns Hopkins University (Author: Daniel Povey)
#           2018  Ruizhe Huang
# Apache 2.0.

# This is an implementation of Kneser-Ney smoothed language model estimation,
# computed in the same way as srilm.  It is a back-off, unmodified version of
# Kneser-Ney smoothing, which produces the same results as the following
# srilm command (as an example):
#
# $ ngram-count -order 4 -kn-modify-counts-at-end -ukndiscount -gt1min 0 -gt2min 0 -gt3min 0 -gt4min 0 \
#       -text corpus.txt -lm lm.arpa
#
# The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py
# The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html

import sys
import os
import re
import io
import math
import argparse
from collections import Counter, defaultdict

parser = argparse.ArgumentParser(description="""
    Generate a Kneser-Ney language model in ARPA format.
    By default, it reads the corpus from standard input and writes the model to standard output.
    """)
parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram")
parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models")
parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level")
args = parser.parse_args()

default_encoding = "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
# Need to be very careful about the use of strip() and split()
# in this case, because there is a latin-1 whitespace character
# (nbsp) which is part of the unicode encoding range.
# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
strip_chars = " \t\r\n"
whitespace = re.compile("[ \t]+")


class CountsForHistory:
    # This class (which is more like a struct) stores the counts seen in a
    # particular history-state.  It is used inside class NgramCounts.
    # It really does the job of a dict from int to float, but it also
    # keeps track of the total count.
    def __init__(self):
        # 'word_to_count' maps each predicted word to its raw count;
        # 'word_to_context' maps each predicted word to the set of distinct
        # words seen immediately before this history (used later for the
        # Kneser-Ney "modified counts").
        self.word_to_count = defaultdict(int)
        self.word_to_context = defaultdict(set)  # using a set to count the number of unique contexts
        self.word_to_f = dict()  # discounted probability
        self.word_to_bow = dict()  # back-off weight
        self.total_count = 0

    def words(self):
        return self.word_to_count.keys()

    def __str__(self):
        # e.g. returns ' total=12: 3->4, 4->6, -1->2'
        return ' total={0}: {1}'.format(
            str(self.total_count),
            ', '.join(['{0} -> {1}'.format(word, count)
                       for word, count in self.word_to_count.items()]))

    def add_count(self, predicted_word, context_word, count):
        assert count >= 0

        self.total_count += count
        self.word_to_count[predicted_word] += count
        if context_word is not None:
            self.word_to_context[predicted_word].add(context_word)


class NgramCounts:
    # A note on data-structure.  Firstly, all words are represented as
    # integers.  We store n-gram counts as an array, indexed by (history-length
    # == n-gram order minus one) (note: python calls arrays "lists") of dicts
    # from histories to counts, where histories are arrays of integers and
    # "counts" are dicts from integer to float.
    # For instance, when accumulating the 4-gram count for the '8' in the
    # sequence '5 6 7 8', we'd do as follows: self.counts[3][[5,6,7]][8] += 1.0
    # where the [3] indexes an array, the [[5,6,7]] indexes a dict, and
    # the [8] indexes a dict.
    def __init__(self, ngram_order, bos_symbol='<s>', eos_symbol='</s>'):
        assert ngram_order >= 2

        self.ngram_order = ngram_order
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        self.counts = []
        for n in range(ngram_order):
            self.counts.append(defaultdict(lambda: CountsForHistory()))

        self.d = []  # list of discounting factor for each order of ngram

    # adds a raw count (called while processing input data).
    # Suppose we see the sequence '6 7 8 9' and ngram_order=4, 'history'
    # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be 1.
    def add_count(self, history, predicted_word, context_word, count):
        self.counts[len(history)][history].add_count(predicted_word, context_word, count)

    # 'line' is a string containing a sequence of integer word-ids.
    # This function adds the un-smoothed counts from this line of text.
    def add_raw_counts_from_line(self, line):
        words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]

        for i in range(len(words)):
            for n in range(1, self.ngram_order + 1):
                if i + n > len(words):
                    break
                ngram = words[i: i + n]
                predicted_word = ngram[-1]
                history = tuple(ngram[: -1])
                if i == 0 or n == self.ngram_order:
                    context_word = None
                else:
                    context_word = words[i - 1]

                self.add_count(history, predicted_word, context_word, 1)

    def add_raw_counts_from_standard_input(self):
        lines_processed = 0
        infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)  # byte stream as input
        for line in infile:
            line = line.strip(strip_chars)
            if line == '':
                break
            self.add_raw_counts_from_line(line)
            lines_processed += 1
        if lines_processed == 0 or args.verbose > 0:
            print("make_kn_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)

    def add_raw_counts_from_file(self, filename):
        lines_processed = 0
        with open(filename, encoding=default_encoding) as fp:
            for line in fp:
                line = line.strip(strip_chars)
                if line == '':
                    break
                self.add_raw_counts_from_line(line)
                lines_processed += 1
        if lines_processed == 0 or args.verbose > 0:
            print("make_kn_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)

    def cal_discounting_constants(self):
        # For each order N of N-grams, we calculate the discounting constant
        # D_N = n1_N / (n1_N + 2 * n2_N), where n1_N and n2_N are the numbers
        # of unique N-grams with count 1 and count 2, respectively
        # (counts-of-counts).  This constant is used similarly to absolute
        # discounting.
        # The result is stored in self.d, a list of floats indexed by history
        # length, i.e. self.d[N - 1] = D_N.

        self.d = [0]  # for the lowest order, i.e. 1-gram, we do not need to discount, thus the constant is 0.
        # This is a special case: we currently assume having seen all vocabulary words in the
        # dictionary, but perhaps this is not the case for some other scenarios.
        for n in range(1, self.ngram_order):
            this_order_counts = self.counts[n]
            n1 = 0
            n2 = 0
            for hist, counts_for_hist in this_order_counts.items():
                stat = Counter(counts_for_hist.word_to_count.values())
                n1 += stat[1]
                n2 += stat[2]
            assert n1 + 2 * n2 > 0
            self.d.append(n1 * 1.0 / (n1 + 2 * n2))

    def cal_f(self):
        # f(a_z) is a probability distribution of word sequence a_z.
        # Typically f(a_z) is discounted to be less than the ML estimate so we have
        # some leftover probability for the z words unseen in the context (a_).
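        #
        # Notation (following the srilm ngram-discount(7) page referenced in the
        # header): a_z denotes an N-gram whose first word is a and last word is
        # z, with "_" standing for the words in between; a_ is a_z with the last
        # word removed, and _z is a_z with the first word removed.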
        #
        # f(a_z) = (c(a_z) - D0) / c(a_)    ;; for highest order N-grams
        # f(_z)  = (n(*_z) - D1) / n(*_*)   ;; for lower order N-grams

        # highest order N-grams
        n = self.ngram_order - 1
        this_order_counts = self.counts[n]
        for hist, counts_for_hist in this_order_counts.items():
            for w, c in counts_for_hist.word_to_count.items():
                counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count

        # lower order N-grams
        for n in range(0, self.ngram_order - 1):
            this_order_counts = self.counts[n]
            for hist, counts_for_hist in this_order_counts.items():

                n_star_star = 0
                for w in counts_for_hist.word_to_count.keys():
                    n_star_star += len(counts_for_hist.word_to_context[w])

                if n_star_star != 0:
                    for w in counts_for_hist.word_to_count.keys():
                        n_star_z = len(counts_for_hist.word_to_context[w])
                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
                else:  # patterns begin with <s>, they do not have "modified count", so use raw count instead
                    for w in counts_for_hist.word_to_count.keys():
                        n_star_z = counts_for_hist.word_to_count[w]
                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count

    def cal_bow(self):
        # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram.
        # Thus, two sorts of ngrams do not have a bow:
        # 1) highest order ngram
        # 2) ngrams ending in </s>
        #
        # bow(a_) = (1 - Sum_Z1 f(a_z)) / (1 - Sum_Z1 f(_z))
        # Note that Z1 is the set of all words with c(a_z) > 0

        # highest order N-grams
        n = self.ngram_order - 1
        this_order_counts = self.counts[n]
        for hist, counts_for_hist in this_order_counts.items():
            for w in counts_for_hist.word_to_count.keys():
                counts_for_hist.word_to_bow[w] = None

        # lower order N-grams
        for n in range(0, self.ngram_order - 1):
            this_order_counts = self.counts[n]
            for hist, counts_for_hist in this_order_counts.items():
                for w in counts_for_hist.word_to_count.keys():
                    if w == self.eos_symbol:
                        counts_for_hist.word_to_bow[w] = None
                    else:
                        a_ = hist + (w,)

                        assert len(a_) < self.ngram_order
                        assert a_ in self.counts[len(a_)].keys()

                        a_counts_for_hist = self.counts[len(a_)][a_]

                        sum_z1_f_a_z = 0
                        for u in a_counts_for_hist.word_to_count.keys():
                            sum_z1_f_a_z += a_counts_for_hist.word_to_f[u]

                        sum_z1_f_z = 0
                        _ = a_[1:]
                        _counts_for_hist = self.counts[len(_)][_]
                        for u in a_counts_for_hist.word_to_count.keys():  # Should be careful here: what is Z1
                            sum_z1_f_z += _counts_for_hist.word_to_f[u]

                        counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z)

    def print_raw_counts(self, info_string):
        # these are useful for debug.
        print(info_string)
        res = []
        for this_order_counts in self.counts:
            for hist, counts_for_hist in this_order_counts.items():
                for w in counts_for_hist.word_to_count.keys():
                    ngram = " ".join(hist) + " " + w
                    ngram = ngram.strip(strip_chars)
                    res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]))
        res.sort(reverse=True)
        for r in res:
            print(r)

    def print_modified_counts(self, info_string):
        # these are useful for debug.
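        # Note: the "modified count" of an n-gram here is the number of distinct
        # words observed immediately before it (a continuation count); n-grams
        # with no recorded context (e.g. those starting with <s>) fall back to
        # their raw count below.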
        print(info_string)
        res = []
        for this_order_counts in self.counts:
            for hist, counts_for_hist in this_order_counts.items():
                for w in counts_for_hist.word_to_count.keys():
                    ngram = " ".join(hist) + " " + w
                    ngram = ngram.strip(strip_chars)

                    modified_count = len(counts_for_hist.word_to_context[w])
                    raw_count = counts_for_hist.word_to_count[w]

                    if modified_count == 0:
                        res.append("{0}\t{1}".format(ngram, raw_count))
                    else:
                        res.append("{0}\t{1}".format(ngram, modified_count))
        res.sort(reverse=True)
        for r in res:
            print(r)

    def print_f(self, info_string):
        # these are useful for debug.
        print(info_string)
        res = []
        for this_order_counts in self.counts:
            for hist, counts_for_hist in this_order_counts.items():
                for w in counts_for_hist.word_to_count.keys():
                    ngram = " ".join(hist) + " " + w
                    ngram = ngram.strip(strip_chars)

                    f = counts_for_hist.word_to_f[w]
                    if f == 0:  # f(<s>) is always 0
                        f = 1e-99

                    res.append("{0}\t{1}".format(ngram, math.log(f, 10)))
        res.sort(reverse=True)
        for r in res:
            print(r)

    def print_f_and_bow(self, info_string):
        # these are useful for debug.
        print(info_string)
        res = []
        for this_order_counts in self.counts:
            for hist, counts_for_hist in this_order_counts.items():
                for w in counts_for_hist.word_to_count.keys():
                    ngram = " ".join(hist) + " " + w
                    ngram = ngram.strip(strip_chars)

                    f = counts_for_hist.word_to_f[w]
                    if f == 0:  # f(<s>) is always 0
                        f = 1e-99
                    bow = counts_for_hist.word_to_bow[w]

                    if bow is None:
                        res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
                    else:
                        res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10)))
        res.sort(reverse=True)
        for r in res:
            print(r)

    def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')):
        # print as ARPA format.

        print('\\data\\', file=fout)
        for hist_len in range(self.ngram_order):
            # print the number of n-grams.
            print('ngram {0}={1}'.format(
                hist_len + 1,
                sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])),
                file=fout)

        print('', file=fout)

        for hist_len in range(self.ngram_order):
            print('\\{0}-grams:'.format(hist_len + 1), file=fout)

            this_order_counts = self.counts[hist_len]
            for hist, counts_for_hist in this_order_counts.items():
                for word in counts_for_hist.word_to_count.keys():
                    ngram = hist + (word,)
                    prob = counts_for_hist.word_to_f[word]
                    bow = counts_for_hist.word_to_bow[word]

                    if prob == 0:  # f(<s>) is always 0
                        prob = 1e-99

                    line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram))
                    if bow is not None:
                        line += '\t{0}'.format('%.7f' % math.log10(bow))
                    print(line, file=fout)
            print('', file=fout)
        print('\\end\\', file=fout)


if __name__ == "__main__":

    ngram_counts = NgramCounts(args.ngram_order)

    if args.text is None:
        ngram_counts.add_raw_counts_from_standard_input()
    else:
        assert os.path.isfile(args.text)
        ngram_counts.add_raw_counts_from_file(args.text)

    ngram_counts.cal_discounting_constants()
    ngram_counts.cal_f()
    ngram_counts.cal_bow()

    if args.lm is None:
        ngram_counts.print_as_arpa()
    else:
        with open(args.lm, 'w', encoding=default_encoding) as f:
            ngram_counts.print_as_arpa(fout=f)
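# Example usage (an illustrative sketch; 'corpus.txt' and 'lm.arpa' are
# placeholder names, matching the srilm example in the header comment):
#
#   # read the corpus from stdin and write the ARPA model to stdout
#   cat corpus.txt | python3 make_kn_lm.py -ngram-order 3 > lm.arpa
#
#   # or name the corpus and output file explicitly
#   python3 make_kn_lm.py -ngram-order 4 -text corpus.txt -lm lm.arpa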