egs/wsj/s5/utils/lang/limit_arpa_unk_history.py
  #!/usr/bin/env python3
  
  # Copyright 2018    Armin Oliya
  # Apache 2.0.
  
  '''
This script takes an existing ARPA language model and limits the <unk> history
to make it suitable for downstream <unk> modeling.
It is meant for the case where you don't have access to the original text
corpus that was used to create the LM. If you do, you can use pocolm with
the option --limit-unk-history=true instead.
Limiting the <unk> history keeps the graph compact after adding the unk model.
  '''
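
# For illustration, with "<unk>" as the OOV dictionary entry, the script
# rewrites ARPA lines like these (hypothetical probabilities, not taken
# from a real model):
#
#   -1.52  every <unk> day          -> removed entirely (<unk> in the history)
#   -4.09  every <unk>      -0.64   -> backoff weight dropped, leaving:
#   -4.09  every <unk>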
  
  import argparse
  import io
  import re
  import sys
  from collections import defaultdict
  
  
  parser = argparse.ArgumentParser(
    description='''This script takes an existing ARPA language model
      and limits the <unk> history to make it suitable
      for downstream <unk> modeling.
      It supports up to 5-grams.''',
      usage='''utils/lang/limit_arpa_unk_history.py
      <oov-dict-entry> <input-arpa >output-arpa''',
      epilog='''E.g.: gunzip -c src.arpa.gz |
      utils/lang/limit_arpa_unk_history.py "<unk>" | gzip -c >dest.arpa.gz''')
  
  parser.add_argument(
      'oov_dict_entry',
      help='oov identifier, for example "<unk>"', type=str)
  args = parser.parse_args()
  
  
  def get_ngram_stats(old_lm_lines):
      ngram_counts = defaultdict(int)
  
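    # The \data\ section at the top of an ARPA file lists one
    # "ngram N=count" line per order, e.g. (hypothetical counts):
    #
    #   \data\
    #   ngram 1=20003
    #   ngram 2=500124
    #   ngram 3=973491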
      for i in range(10):
          g = re.search(r"ngram (\d)=(\d+)", old_lm_lines[i])
          if g:
              ngram_counts[int(g.group(1))] = int(g.group(2))
  
      if len(ngram_counts) == 0:
          sys.exit("""Couldn't get counts per ngram section.
              The input doesn't seem to be a valid ARPA language model.""")
  
    # the highest n-gram order listed in the header
    max_ngrams = max(ngram_counts.keys())
    # the first ngram_counts[1] lines (the header plus most of the 1-gram
    # section) contain nothing this script needs to edit, so they are
    # copied through untouched
    skip_rows = ngram_counts[1]
  
      if max_ngrams > 5:
          sys.exit("This script supports up to 5-gram language models.")
  
      return max_ngrams, skip_rows, ngram_counts
  
  
  def find_and_replace_unks(old_lm_lines, max_ngrams, skip_rows):
    # maps n-gram order to the (negative) change in line count for that
    # section, so that the header counts can be fixed up afterwards
    ngram_diffs = defaultdict(int)
    whitespace_pattern = re.compile(r"[ \t]+")
    # matches n-grams in which the OOV word is followed by another word,
    # i.e. appears in the history, e.g. "-1.52  every <unk> day"
    unk_pattern = re.compile(
        r"[0-9.-]+(?:\s\S+){1,3}\s" + re.escape(args.oov_dict_entry) +
        r"\s(?!-[0-9]+\.[0-9]+).*")
    # matches n-grams that end with the OOV word plus a backoff weight,
    # e.g. "-4.09  every <unk>  -0.64"
    backoff_pattern = re.compile(
        r"[0-9.-]+(?:\s\S+){1,3}\s" + re.escape(args.oov_dict_entry) +
        r"\s-[0-9]+\.[0-9]+")
      passed_2grams, last_ngram = False, False
      unk_row_count, backoff_row_count = 0, 0
  
      print("Upadting the language model .. ", file=sys.stderr)
      new_lm_lines = old_lm_lines[:skip_rows]
  
      for i in range(skip_rows, len(old_lm_lines)):
              line = old_lm_lines[i].strip(" \t\r
  ")
  
              if "\{}-grams:".format(3) in line:
                  passed_2grams = True
              if "\{}-grams:".format(max_ngrams) in line:
                  last_ngram = True
  
              for i in range(max_ngrams):
                  if "\{}-grams:".format(i+1) in line:
                      ngram = i+1
  
              # remove any n-gram states of the form: foo <unk> -> X
              # that is, any n-grams of order > 2 where <unk>
              # is the second-to-last word
              # here we skip 1-gram and 2-gram sections of arpa
  
              if passed_2grams:
                  g_unk = unk_pattern.search(line)
                  if g_unk:
                      ngram_diffs[ngram] = ngram_diffs[ngram] - 1
                      unk_row_count += 1
                      continue
  
              # remove backoff probability from the lines that end with <unk>
              # for example, the -0.64 in -4.09 every <unk> -0.64
              # here we skip the last n-gram section because it
              # doesn't include backoff probabilities
  
              if not last_ngram:
                  g_backoff = backoff_pattern.search(line)
                  if g_backoff:
                      updated_row = whitespace_pattern.split(g_backoff.group(0))[:-1]
                      updated_row = updated_row[0] + \
                          "\t" + " ".join(updated_row[1:]) + "
  "
                      new_lm_lines.append(updated_row)
                      backoff_row_count += 1
                      continue
  
              new_lm_lines.append(line+"
  ")
  
      print("Removed {} lines including {} as second-to-last term.".format(
          unk_row_count, args.oov_dict_entry), file=sys.stderr)
      print("Removed backoff probabilties from {} lines.".format(
          backoff_row_count), file=sys.stderr)
  
      return new_lm_lines, ngram_diffs
  
  
  def read_old_lm():
      print("Reading ARPA LM frome input stream .. ", file=sys.stderr)
  
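    # latin-1 maps every possible byte value, so decoding never fails even
    # if the ARPA file contains non-UTF-8 bytes; write_new_lm writes the
    # result back out with the same encoding.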
      with io.TextIOWrapper(
              sys.stdin.buffer,
              encoding="latin-1") as input_stream:
          old_lm_lines = input_stream.readlines()
  
      return old_lm_lines
  
  
  def write_new_lm(new_lm_lines, ngram_counts, ngram_diffs):
    '''Update the n-gram counts that go in the header of the ARPA LM.'''
  
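    # e.g. if 120 trigrams were removed, ngram_diffs[3] is -120 and a header
    # line "ngram 3=973491" becomes "ngram 3=973371" (hypothetical counts)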
      for i in range(10):
          g = re.search(r"ngram (\d)=(\d+)", new_lm_lines[i])
          if g:
              n = int(g.group(1))
              if n in ngram_diffs:
                  # ngram_diffs contains negative values
                  new_num_ngrams = ngram_counts[n] + ngram_diffs[n]
                  new_lm_lines[i] = "ngram {}={}
  ".format(
                      n, new_num_ngrams)
  
      with io.TextIOWrapper(
              sys.stdout.buffer,
              encoding="latin-1") as output_stream:
          output_stream.writelines(new_lm_lines)
  
  
  def main():
      old_lm_lines = read_old_lm()
    max_ngrams, skip_rows, ngram_counts = get_ngram_stats(old_lm_lines)
      new_lm_lines, ngram_diffs = find_and_replace_unks(
          old_lm_lines, max_ngrams, skip_rows)
      write_new_lm(new_lm_lines, ngram_counts, ngram_diffs)
  
  
  if __name__ == "__main__":
      main()