egs/wsj/s5/steps/cleanup/internal/make_one_biased_lm.py

  #!/usr/bin/env python3
  
  # Copyright 2016  Johns Hopkins University (Author: Daniel Povey)
  # Apache 2.0.
  
  from __future__ import print_function
  from __future__ import division
  import sys
  import argparse
  import math
  from collections import defaultdict
  
  import io
  sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf8")
  sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf8")
  sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf8")
  
  parser = argparse.ArgumentParser(description="""
  This script creates a biased language model suitable for alignment and
  data-cleanup purposes.   It reads (possibly multiple) lines of integerized text
  from the input and writes a text-form FST of a backoff language model to
  the standard output, to be piped into fstcompile.""")
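
  # Typical usage is a sketch like the following (the file names and the integer
  # passed to --word-disambig-symbol are hypothetical and depend on your setup);
  # the script reads integerized text on stdin and writes FST text that can be
  # piped into fstcompile:
  #   cat utterance_text.int | steps/cleanup/internal/make_one_biased_lm.py \
  #     --word-disambig-symbol=1000 --ngram-order=4 | fstcompile > biased_lm.fst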
  
  parser.add_argument("--word-disambig-symbol", type = int, required = True,
                      help = "Integer corresponding to the disambiguation "
                      "symbol (normally #0) for backoff arcs")
  parser.add_argument("--ngram-order", type = int, default = 4,
                      choices = [2,3,4,5,6,7],
                      help = "Maximum order of n-gram to use (but see also "
                      "--min-lm-state-count; the effective order may be less).")
  parser.add_argument("--min-lm-state-count", type = int, default = 10,
                      help = "Minimum count below which we will completely "
                      "discount an LM-state (if it is of order > 2, i.e. "
                      "history-length > 1).")
  parser.add_argument("--top-words", type = str,
                      help = "File containing frequent words and probabilities to be added into "
                      "the language model, with lines in the format '<integer-id-of-word> <prob>'. "
                      "These probabilities will be added to the probabilities in the unigram "
                      "backoff state and then renormalized; this option allows you to introduce "
                      "common words to the LM with specified probabilities.")
  parser.add_argument("--discounting-constant", type = float, default = 0.3,
                      help = "Discounting constant D for standard (unmodified) Kneser-Ney; "
                      "must be strictly between 0 and 1.  A value closer to 0 will give "
                      "you a more-strongly-biased LM.")
  parser.add_argument("--verbose", type = int, default = 0,
                      choices=[0,1,2,3,4,5], help = "Verbose level")
  
  args = parser.parse_args()
  
  if args.verbose >= 1:
      print(' '.join(sys.argv), file = sys.stderr)
  
  
  
  
  class NgramCounts(object):
      ## A note on data-structure.
      ## Firstly, all words are represented as integers.
      ## We store n-gram counts as an array, indexed by (history-length == n-gram order minus one)
      ## (note: python calls arrays "lists") of dicts from histories to counts, where
      ## histories are tuples of integers and "counts" are dicts from integer to float.
      ## For instance, when accumulating the 4-gram count for the '8' in the sequence '5 6 7 8',
      ## we'd do as follows:
      ##  self.counts[3][(5,6,7)][8] += 1.0
      ## where the [3] indexes an array, the [(5,6,7)] indexes a dict (the history must
      ## be a tuple, since lists cannot be dict keys), and the [8] indexes a dict.
      def __init__(self, ngram_order):
          self.ngram_order = ngram_order
          # Integerized counts will never contain negative numbers, so
          # inside this program, we use -3 and -2 for the BOS and EOS symbols
          # respectively.
          # Note: it's actually important that the bos-symbol is the most negative;
          # it helps ensure that we print the state with left-context <s> first
          # when we print the FST, and this means that the start-state will have
          # the correct value.
          self.bos_symbol = -3
          self.eos_symbol = -2
          # backoff_symbol is kind of a pseudo-word, it's used in keeping track of
          # the backoff counts in each state.
          self.backoff_symbol = -1
          self.counts = []
          for n in range(ngram_order):
              # The 'lambda: defaultdict(float)' is an anonymous function taking
              # no arguments that returns a new defaultdict(float).
              # If we index self.counts[n][history] for a history-length n < ngram_order
              # and a previously unseen history, it will create a new defaultdict
              # that defaults to 0.0 [since the function float() will return 0.0].
              # This means that we can index self.counts without worrying about
              # undefined values.
              self.counts.append(defaultdict(lambda: defaultdict(float)))
  
      # adds a raw count (called while processing input data).
      # Suppose we see the sequence '6 7 8 9' and ngram_order=4, 'history'
      # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
      # 1.0.
      def AddCount(self, history, predicted_word, count):
          self.counts[len(history)][history][predicted_word] += count
  
      # 'line' is a string containing a sequence of integer word-ids.
      # This function adds the un-smoothed counts from this line of text.
      def AddRawCountsFromLine(self, line):
          try:
              words = [self.bos_symbol] + [ int(x) for x in line.split() ] + [self.eos_symbol]
          except ValueError:
              sys.exit("make_one_biased_lm.py: bad input line {0} (expected a sequence "
                       "of integers)".format(line))
  
          for n in range(1, len(words)):
              predicted_word = words[n]
              history_start = max(0, n + 1 - self.ngram_order)
              history = tuple(words[history_start:n])
              self.AddCount(history, predicted_word, 1.0)
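
      # Worked example (a sketch, not executed): with ngram_order=4 the input
      # line "6 7" becomes the word sequence [-3, 6, 7, -2] (BOS ... EOS), and
      # AddRawCountsFromLine() accumulates
      #   self.counts[1][(-3,)][6]       += 1.0
      #   self.counts[2][(-3, 6)][7]     += 1.0
      #   self.counts[3][(-3, 6, 7)][-2] += 1.0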
  
      def AddRawCountsFromStandardInput(self):
          lines_processed = 0
          while True:
              line = sys.stdin.readline()
              if line == '':
                  break
              self.AddRawCountsFromLine(line)
              lines_processed += 1
          if lines_processed == 0 or args.verbose > 0:
              print("make_one_biased_lm.py: processed {0} lines of input".format(
                      lines_processed), file = sys.stderr)
  
  
      # This function returns a dict from history (as a tuple of integers of
      # length > 1, ignoring lower-order histories), to the total count of this
      # history state plus all history-states which back off to this history state.
      # It's used inside CompletelyDiscountLowCountStates().
      def GetHistToTotalCount(self):
          ans = defaultdict(float)
          for n in range(2, self.ngram_order):
              for hist, word_to_count in self.counts[n].items():
                  total_count = sum(word_to_count.values())
                  while len(hist) >= 2:
                      ans[hist] += total_count
                      hist = hist[1:]
          return ans
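
      # Illustration with hypothetical counts: if counts[3][(5, 6, 7)] sums to 2.0
      # and counts[2][(6, 7)] sums to 3.0, this returns
      #   { (5, 6, 7): 2.0, (6, 7): 5.0 }
      # because the 2.0 in the higher-order state (5, 6, 7), which backs off to
      # (6, 7), is included in the total for (6, 7).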
  
  
      # This function will completely discount the counts in any LM-states of
      # order > 2 (i.e. history-length > 1) that have total count below
      # 'min_count'; when computing the total counts, we include higher-order
      # LM-states that would back off to 'this' lm-state, in the total.
      def CompletelyDiscountLowCountStates(self, min_count):
          hist_to_total_count = self.GetHistToTotalCount()
          for n in reversed(list(range(2, self.ngram_order))):
              this_order_counts = self.counts[n]
              to_delete = []
              for hist in this_order_counts.keys():
                  if hist_to_total_count[hist] < min_count:
                      # we need to completely back off this count.
                      word_to_count = this_order_counts[hist]
                      # mark this key for deleting
                      to_delete.append(hist)
                      backoff_hist = hist[1:]  # this will be a tuple not a list.
                      for word, count in word_to_count.items():
                          self.AddCount(backoff_hist, word, count)
              for hist in to_delete:
                  del this_order_counts[hist]
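
      # Continuing the hypothetical example from GetHistToTotalCount() with
      # min_count=10: both (5, 6, 7) (total 2.0) and (6, 7) (total 5.0) fall below
      # the threshold, so the counts of (5, 6, 7) are first folded into (6, 7),
      # and then the counts of (6, 7) (including those just folded in) are folded
      # into the history-length-1 state (7,), which is never completely discounted.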
  
      # This backs off the counts according to Kneser-Ney (unmodified,
      # with interpolation).
      def ApplyBackoff(self, D):
          assert D > 0.0 and D < 1.0
          for n in reversed(list(range(1, self.ngram_order))):
              this_order_counts = self.counts[n]
              for hist, word_to_count in this_order_counts.items():
                  backoff_hist = hist[1:]
                  backoff_word_to_count = self.counts[n-1][backoff_hist]
                  this_discount_total = 0.0
                  for word in word_to_count:
                      assert word_to_count[word] >= 1.0
                      word_to_count[word] -= D
                      this_discount_total += D
                      # The following line increments the Kneser-Ney context count
                      # (the number of distinct histories in which 'word' appears)
                      # for the next-lower order.
                      backoff_word_to_count[word] += 1.0
                  word_to_count[self.backoff_symbol] += this_discount_total
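
      # Worked example with hypothetical counts and D=0.3: a state whose counts
      # are { 8: 2.0, 9: 1.0 } becomes { 8: 1.7, 9: 0.7, backoff_symbol: 0.6 },
      # and in its backoff state the counts of 8 and 9 are each incremented by 1.0.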
  
  
      # This function prints out to stderr the n-gram counts stored in this
      # object; it's used for debugging.
      def Print(self, info_string):
          print(info_string, file=sys.stderr)
          # these are useful for debug.
          total = 0.0
          total_excluding_backoff = 0.0
          for this_order_counts in self.counts:
              for hist, word_to_count in this_order_counts.items():
                  this_total_count = sum(word_to_count.values())
                  print('{0}: total={1} '.format(hist, this_total_count),
                        end='', file=sys.stderr)
                  print(' '.join(['{0} -> {1} '.format(word, count)
                                  for word, count in word_to_count.items() ]),
                        file = sys.stderr)
                  total += this_total_count
                  total_excluding_backoff += this_total_count
                  if self.backoff_symbol in word_to_count:
                      total_excluding_backoff -= word_to_count[self.backoff_symbol]
          print('total count = {0}, excluding discount = {1}'.format(
                  total, total_excluding_backoff), file = sys.stderr)
  
      def AddTopWords(self, top_words_file):
          empty_history = ()
          word_to_count = self.counts[0][empty_history]
          total = sum(word_to_count.values())
          try:
              f = open(top_words_file, mode='r', encoding='utf-8')
          except:
              sys.exit("make_one_biased_lm.py: error opening top-words file: "
                       "--top-words=" + top_words_file)
          while True:
              line = f.readline()
              if line == '':
                  break
              try:
                  [ word_index, prob ] = line.split()
                  word_index = int(word_index)
                  prob = float(prob)
                  assert word_index > 0 and prob > 0.0
                  word_to_count[word_index] += prob * total
              except Exception as e:
                  sys.exit("make_one_biased_lm.py: could not make sense of the "
                           "line '{0}' in top-words file: {1} ".format(line, str(e)))
          f.close()
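
      # Example of the --top-words file format (the word-ids and probabilities
      # below are made up):
      #   10 0.05
      #   25 0.02
      # If the unigram state's total count were 100.0, this would add 5.0 to the
      # count of word 10 and 2.0 to word 25; renormalization happens implicitly
      # later, when GetProb() divides by the state's new total.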
  
  
      def GetTotalCountMap(self):
          # This function, called from PrintAsFst, returns a map from
          # history to the total-count for that state.
          total_count_map = dict()
          for n in range(0, self.ngram_order):
              for hist, word_to_count in self.counts[n].items():
                  total_count_map[hist] = sum(word_to_count.values())
          return total_count_map
  
      def GetHistToStateMap(self):
          # This function, called from PrintAsFst, returns a map from
          # history to integer FST-state.
          hist_to_state = dict()
          fst_state_counter = 0
          for n in range(0, self.ngram_order):
              for hist in self.counts[n].keys():
                  hist_to_state[hist] = fst_state_counter
                  fst_state_counter += 1
          return hist_to_state
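
      # Note: the empty history () (the unigram state) is always assigned state 0,
      # and each stored history of length 1, 2, ... then gets the next integer; the
      # actual numbering does not matter, since the FST start state is simply the
      # source state of the first arc printed by PrintAsFst().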
  
      def GetProb(self, hist, word, total_count_map):
          total_count = total_count_map[hist]
          word_to_count = self.counts[len(hist)][hist]
          prob = float(word_to_count[word]) / total_count
          if len(hist) > 0 and word != self.backoff_symbol:
              prob_in_backoff = self.GetProb(hist[1:], word, total_count_map)
              backoff_prob = float(word_to_count[self.backoff_symbol]) / total_count
              prob += backoff_prob * prob_in_backoff
          return prob
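
      # Worked example with hypothetical discounted counts: in state (7,) suppose
      # word_to_count == { 8: 1.7, backoff_symbol: 0.3 } (total 2.0) and the
      # unigram state gives GetProb((), 8) == 0.25.  Then
      #   GetProb((7,), 8) = 1.7/2.0 + (0.3/2.0) * 0.25 = 0.8875
      # i.e. the discounted direct estimate interpolated with the backed-off one.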
  
      # This function prints the estimated language model as an FST.
      def PrintAsFst(self, word_disambig_symbol):
          # n is the history-length (== order minus one).  We iterate over the
          # history-lengths in the order 1, 0, 2, 3, ..., and for history-length 1
          # we visit the histories in sorted order.  Putting history-length 1 first
          # and sorting its histories ensures that the bigram state with <s> as the
          # left context is printed first, so it becomes the FST's start state
          # (note: self.bos_symbol is the most negative symbol).
  
          # History will map from history (as a tuple) to integer FST-state.
          hist_to_state = self.GetHistToStateMap()
          total_count_map = self.GetTotalCountMap()
  
          for n in [ 1, 0 ] + list(range(2, self.ngram_order)):
              this_order_counts = self.counts[n]
              # For history-length 1, make sure the keys are sorted.
              keys = this_order_counts.keys() if n != 1 else sorted(this_order_counts.keys())
              for hist in keys:
                  word_to_count = this_order_counts[hist]
                  this_fst_state = hist_to_state[hist]
  
                  for word in word_to_count.keys():
                      # work out this_cost.  Costs in OpenFst are negative logs.
                      this_cost = -math.log(self.GetProb(hist, word, total_count_map))
  
                      if word > 0: # a real word.
                          next_hist = hist + (word,)  # appending tuples
                          while next_hist not in hist_to_state:
                              next_hist = next_hist[1:]
                          next_fst_state = hist_to_state[next_hist]
                          print(this_fst_state, next_fst_state, word, word,
                                this_cost)
                      elif word == self.eos_symbol:
                          # print final-prob for this state.
                          print(this_fst_state, this_cost)
                      else:
                          assert word == self.backoff_symbol
                          backoff_fst_state = hist_to_state[hist[1:]]
                          print(this_fst_state, backoff_fst_state,
                                word_disambig_symbol, 0, this_cost)
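
      # The printed lines follow OpenFst's text format: word arcs are
      # "src dest in-label out-label cost", final states are "state cost", and
      # backoff arcs carry the disambiguation symbol on the input side with
      # epsilon (0) on the output side.  A hypothetical fragment:
      #   0 3 6 6 0.105
      #   0 1 1000 0 1.386
      #   3 2.303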
  
  
  ngram_counts = NgramCounts(args.ngram_order)
  ngram_counts.AddRawCountsFromStandardInput()
  
  if args.verbose >= 3:
      ngram_counts.Print("Raw counts:")
  ngram_counts.CompletelyDiscountLowCountStates(args.min_lm_state_count)
  if args.verbose >= 3:
      ngram_counts.Print("Counts after discounting low-count states:")
  ngram_counts.ApplyBackoff(args.discounting_constant)
  if args.verbose >= 3:
      ngram_counts.Print("Counts after applying Kneser-Ney discounting:")
  if args.top_words is not None:
      ngram_counts.AddTopWords(args.top_words)
      if args.verbose >= 3:
          ngram_counts.Print("Counts after adding top-words:")
  ngram_counts.PrintAsFst(args.word_disambig_symbol)
  
  
  # test command:
  # (echo 6 7 8 4; echo 7 8 9; echo 7 8) | ./make_one_biased_lm.py --word-disambig-symbol=1000 --min-lm-state-count=2 --verbose=3 --top-words=<(echo 1 0.5; echo 2 0.25)