Blame view

egs/wsj/s5/steps/dict/prons_to_lexicon.py 8.29 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
  #!/usr/bin/env python
  
  # Copyright 2016  Vimal Manohar
  #           2016  Xiaohui Zhang
  # Apache 2.0.
  
  # we're using python 3.x style print but want it to work in python 2.x,
  from __future__ import print_function
  from collections import defaultdict
  import argparse
  import sys
  
  class StrToBoolAction(argparse.Action):
      """ A custom action to convert bools from shell format i.e., true/false
          to python format i.e., True/False """
      def __call__(self, parser, namespace, values, option_string=None):
          if values == "true":
              setattr(namespace, self.dest, True)
          elif values == "false":
              setattr(namespace, self.dest, False)
          else:
              raise Exception("Unknown value {0} for --{1}".format(values, self.dest))
  
  def GetArgs():
      parser = argparse.ArgumentParser(description = "Converts pronunciation statistics (from phonetic decoding or g2p) "
                                       "into a lexicon for. We prune the pronunciations "
                                       "based on a provided stats file, and optionally filter out entries which are present "
                                       "in a filter lexicon.",
                                       epilog = "e.g. steps/dict/prons_to_lexicon.py --min-prob=0.4 \\"
                                       "--filter-lexicon=exp/tri3_lex_0.4_work/phone_decode/filter_lexicon.txt \\"
                                       "exp/tri3_lex_0.4_work/phone_decode/prons.txt \\"
                                       "exp/tri3_lex_0.4_work/lexicon_phone_decoding.txt"
                                       "See steps/dict/learn_lexicon_greedy.sh for examples in detail.")
  
      parser.add_argument("--set-sum-to-one", type = str, default = False,
                          action = StrToBoolAction, choices = ["true", "false"],
                          help = "If normalize lexicon such that the sum of "
                          "probabilities is 1.")
      parser.add_argument("--set-max-to-one", type = str, default = True,
                          action = StrToBoolAction, choices = ["true", "false"],
                          help = "If normalize lexicon such that the max "
                          "probability is 1.")
      parser.add_argument("--top-N", type = int, default = 0,
                          help = "If non-zero, we just take the top N pronunciations (according to stats/pron-probs) for each word.")
      parser.add_argument("--min-prob", type = float, default = 0.1,
                          help = "Remove pronunciation with probabilities less "
                          "than this value after normalization.")
      parser.add_argument("--filter-lexicon", metavar='<filter-lexicon>', type = str, default = '',
                          help = "Exclude entries in this filter lexicon from the output lexicon."
                          "each line must be <word> <phones>")
      parser.add_argument("stats_file", metavar='<stats-file>', type = str,
                          help = "Input lexicon file containing pronunciation statistics/probs in the first column."
                          "each line must be <counts> <word> <phones>")
      parser.add_argument("out_lexicon", metavar='<out-lexicon>', type = str,
                          help = "Output lexicon.")
  
      print (' '.join(sys.argv), file = sys.stderr)
  
      args = parser.parse_args()
      args = CheckArgs(args)
  
      return args
  
  def CheckArgs(args):
      if args.stats_file == "-":
          args.stats_file_handle = sys.stdin
      else:
          args.stats_file_handle = open(args.stats_file)
  
      if args.filter_lexicon is not '':
          if args.filter_lexicon == "-":
              args.filter_lexicon_handle = sys.stdout
          else:
              args.filter_lexicon_handle = open(args.filter_lexicon)
      
      if args.out_lexicon == "-":
          args.out_lexicon_handle = sys.stdout
      else:
          args.out_lexicon_handle = open(args.out_lexicon, "w")
  
      if args.set_max_to_one == args.set_sum_to_one:
          raise Exception("Cannot have both "
              "set-max-to-one and set-sum-to-one as true or false.")
  
      return args
  
  def ReadStats(args):
      lexicon = {}
      word_count = {}
      for line in args.stats_file_handle:
          splits = line.strip().split()
          if len(splits) < 3:
              continue
  
          word = splits[1]
          count = float(splits[0])
          phones = ' '.join(splits[2:])
  
          lexicon[(word, phones)] = lexicon.get((word, phones), 0) + count
          word_count[word] = word_count.get(word, 0) + count
  
      return [lexicon, word_count]
  
  def ReadLexicon(lexicon_file_handle):
      lexicon = set()
      if lexicon_file_handle:
          for line in lexicon_file_handle.readlines():
              splits = line.strip().split()
              if len(splits) == 0:
                  continue
              if len(splits) < 2:
                  raise Exception('Invalid format of line ' + line
                                      + ' in lexicon file.')
              word = splits[0]
              phones = ' '.join(splits[1:])
              lexicon.add((word, phones))
      return lexicon
  
  def ConvertWordCountsToProbs(args, lexicon, word_count):
      word_probs = {}
      for entry, count in lexicon.iteritems():
          word = entry[0]
          phones = entry[1]
          prob = float(count) / float(word_count[word])
          if word in word_probs:
              word_probs[word].append((phones, prob))
          else:
              word_probs[word] = [(phones, prob)]
  
      return word_probs
  
  def ConvertWordProbsToLexicon(word_probs):
      lexicon = {}
      for word, entry in word_probs.iteritems():
          for x in entry:
              lexicon[(word, x[0])] = lexicon.get((word,x[0]), 0) + x[1]
      return lexicon
  
  def NormalizeLexicon(lexicon, set_max_to_one = True,
                       set_sum_to_one = False, min_prob = 0):
      word_probs = {}
      for entry, prob in lexicon.iteritems():
          t = word_probs.get(entry[0], (0,0))
          word_probs[entry[0]] = (t[0] + prob, max(t[1], prob))
  
      for entry, prob in lexicon.iteritems():
          if set_max_to_one:
              prob = prob / word_probs[entry[0]][1]
          elif set_sum_to_one:
              prob = prob / word_probs[entry[0]][0]
          if prob < min_prob:
              prob = 0
          lexicon[entry] = prob
  
  def TakeTopN(lexicon, top_N):
      lexicon_reshaped = defaultdict(list) 
      lexicon_pruned = {}
      for entry, prob in lexicon.iteritems():
          lexicon_reshaped[entry[0]].append([entry[1], prob])
      for word in lexicon_reshaped:
          prons = lexicon_reshaped[word]
          sorted_prons = sorted(prons, reverse=True, key=lambda prons: prons[1])
          for i in range(len(sorted_prons)):
              if i >= top_N:
                  lexicon[(word, sorted_prons[i][0])] = 0
          
  def WriteLexicon(args, lexicon, filter_lexicon):
      words = set()
      num_removed = 0
      num_filtered = 0
      for entry, prob in lexicon.iteritems():
          if prob == 0:
              num_removed += 1
              continue
          if entry in filter_lexicon:
              num_filtered += 1
              continue
          words.add(entry[0])
          print("{0} {1}".format(entry[0], entry[1]),
                  file = args.out_lexicon_handle)
      print ("Before pruning, the total num. pronunciations is: {}".format(len(lexicon)), file=sys.stderr)
      print ("Removed {0} pronunciations by setting min_prob {1}".format(num_removed, args.min_prob), file=sys.stderr)
      print ("Filtered out {} pronunciations in the filter lexicon.".format(num_filtered), file=sys.stderr)
      num_prons_from_phone_decoding = len(lexicon) - num_removed - num_filtered
      print ("Num. pronunciations in the output lexicon, which solely come from phone decoding"
             "is {0}. num. words is {1}".format(num_prons_from_phone_decoding, len(words)), file=sys.stderr)
  
  def Main():
      args = GetArgs()
  
      [lexicon, word_count] = ReadStats(args)
  
      word_probs = ConvertWordCountsToProbs(args, lexicon, word_count)
  
      lexicon = ConvertWordProbsToLexicon(word_probs)
      filter_lexicon = set()
      if args.filter_lexicon is not '':
          filter_lexicon = ReadLexicon(args.filter_lexicon_handle)
      if args.top_N > 0:
          TakeTopN(lexicon, args.top_N)
      else:
          NormalizeLexicon(lexicon, set_max_to_one = args.set_max_to_one,
                           set_sum_to_one = args.set_sum_to_one,
                           min_prob = args.min_prob)
      WriteLexicon(args, lexicon, filter_lexicon)
      args.out_lexicon_handle.close()
  
  if __name__ == "__main__":
      Main()