Blame view

egs/wsj/s5/steps/dict/prune_pron_candidates.py 4.77 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
  #!/usr/bin/env python
  
  # Copyright 2016  Xiaohui Zhang
  # Apache 2.0.
  
  from __future__ import print_function
  from __future__ import division
  from collections import defaultdict
  import argparse
  import sys
  import math
  
  def GetArgs():
      """Build the command-line parser, parse sys.argv and open the files.

      The full command line is echoed to stderr before parsing. The parsed
      arguments are then passed through CheckArgs, which attaches the opened
      input/output handles to them.

      Returns:
          argparse.Namespace: the parsed arguments (r, pron_stats,
          ref_lexicon, pruned_prons) plus whatever CheckArgs adds.
      """
      # Adjacent string literals are concatenated verbatim, so every piece
      # below ends with a space to keep words from running together.
      parser = argparse.ArgumentParser(
          description="Prune pronunciation candidates based on soft-counts "
          "from lattice-alignment outputs, and a reference lexicon. "
          "Basically, for each word we sort all pronunciation candidates "
          "according to their soft-counts, and then select the top r * N "
          "candidates (for words in the reference lexicon, N = # pron "
          "variants given by the reference lexicon; for OOV words, N = avg. "
          "# pron variants per word in the reference lexicon). "
          "r is a user-specified constant, like 2.",
          epilog="See steps/dict/learn_lexicon_greedy.sh for example")

      parser.add_argument("--r", type=float, default=2.0,
                          help="A user-specified ratio parameter which "
                          "determines how many pronunciation candidates we "
                          "want to keep for each word.")
      parser.add_argument("pron_stats", metavar="<pron-stats>", type=str,
                          help="File containing soft-counts of all "
                          "pronunciation candidates; each line must be "
                          "<soft-counts> <word> <phones>")
      parser.add_argument("ref_lexicon", metavar="<ref-lexicon>", type=str,
                          help="Reference lexicon file, where we obtain # pron "
                          "variants for each word, based on which we prune "
                          "the pron candidates.")
      parser.add_argument("pruned_prons", metavar="<pruned-prons>", type=str,
                          help="Output file in lexicon format, to which we "
                          "write the prons to be pruned away from the "
                          "pron_stats file; '-' means stdout.")

      print(' '.join(sys.argv), file=sys.stderr)

      args = parser.parse_args()
      args = CheckArgs(args)

      return args
  
  def CheckArgs(args):
      """Open the files named by the parsed arguments and attach the handles.

      Adds three attributes to ``args``:
        pron_stats_handle   -- args.pron_stats opened for reading
        ref_lexicon_handle  -- args.ref_lexicon opened for reading
        pruned_prons_handle -- sys.stdout when args.pruned_prons is "-",
                               otherwise args.pruned_prons opened for writing

      Returns the same (mutated) args object.
      """
      # Tuple elements are evaluated left to right, so the stats file is
      # opened before the lexicon file, exactly as before.
      args.pron_stats_handle, args.ref_lexicon_handle = (
          open(args.pron_stats), open(args.ref_lexicon))
      write_to_stdout = args.pruned_prons == "-"
      args.pruned_prons_handle = (
          sys.stdout if write_to_stdout else open(args.pruned_prons, "w"))
      return args
  
  def ReadStats(pron_stats_handle):
      """Read pronunciation soft-count statistics.

      Every non-blank line has the form ``<soft-count> <word> [<phones>...]``.

      Args:
          pron_stats_handle: an open text handle (anything with readlines()).

      Returns:
          defaultdict(list) mapping word -> list of (phones, count) tuples,
          where phones is the space-joined phone sequence and count is a
          float. Each list is sorted by count in ascending order, so the
          highest-count candidate sits at the end of the list.

      Raises:
          Exception: if a non-blank line has fewer than two fields.
          ValueError: if the first field is not a number.
      """
      stats = defaultdict(list)
      for raw_line in pron_stats_handle.readlines():
          fields = raw_line.strip().split()
          if not fields:
              continue  # blank line
          if len(fields) < 2:
              raise Exception('Invalid format of line ' + raw_line
                              + ' in stats file.')
          soft_count = float(fields[0])
          stats[fields[1]].append((' '.join(fields[2:]), soft_count))

      # Ascending by count: callers pop() from the end to get the best one.
      for candidates in stats.values():
          candidates.sort(key=lambda pron_and_count: pron_and_count[1])
      return stats
  
  def ReadLexicon(ref_lexicon_handle):
      """Read a reference lexicon into a word -> set-of-prons mapping.

      Accepts lines of either form:
        <word> <phones...>
        <word> <number> <phones...>
      The second form is detected by the second field parsing as a float
      (presumably a pronunciation probability); that field is then skipped.
      Previously the second field was always dropped, which silently lost
      the first phone of every plain ``<word> <phones...>`` entry.

      Args:
          ref_lexicon_handle: an open text handle to iterate over.

      Returns:
          defaultdict(set) mapping word -> set of space-joined phone strings.

      Raises:
          Exception: if a non-blank line has fewer than two fields.
      """
      ref_lexicon = defaultdict(set)
      for line in ref_lexicon_handle:
          splits = line.strip().split()
          if len(splits) == 0:
              continue
          if len(splits) < 2:
              raise Exception('Invalid format of line ' + line
                                  + ' in lexicon file.')
          word = splits[0]
          # Keep the try minimal: only the numeric probe can raise.
          try:
              float(splits[1])
              has_numeric_field = True
          except ValueError:
              has_numeric_field = False
          phones = ' '.join(splits[2:] if has_numeric_field else splits[1:])
          ref_lexicon[word].add(phones)
      return ref_lexicon
  
  def PruneProns(args, stats, ref_lexicon):
      """Write out the pronunciation candidates that should be pruned away.

      Each word gets a keep budget:
        - words in ref_lexicon: args.r * (# reference prons of the word);
        - other words: args.r * ceil(avg # prons per reference word).
      Candidates are popped from the end of ``stats[word]`` (expected to be
      sorted by ascending count, so the highest-count one comes first). Each
      popped candidate that is not a reference pron of the word uses one unit
      of budget; reference prons are popped without using budget. Popping
      stops once the budget is reached or the list is empty.

      Every candidate left in ``stats[word]`` that is not a reference pron is
      then printed as "<word> <phones>" to args.pruned_prons_handle.

      Note: the lists in ``stats`` are mutated (kept candidates are popped).

      Args:
          args: namespace providing ``r`` (float) and ``pruned_prons_handle``.
          stats: mapping word -> list of (phones, count), ascending by count.
          ref_lexicon: mapping word -> set of reference phone strings.

      Raises:
          ValueError: if ref_lexicon is empty, since the average # pron
              variants per word would be undefined (previously this surfaced
              as a bare ZeroDivisionError).
      """
      # Average # pron variants per word in the reference lexicon.
      num_words_ref = len(ref_lexicon)
      if num_words_ref == 0:
          raise ValueError('Reference lexicon is empty; cannot compute the '
                           'average # pron variants per word.')
      num_prons_ref = sum(len(prons) for prons in ref_lexicon.values())
      avg_variants_counts_ref = math.ceil(float(num_prons_ref) / float(num_words_ref))

      for word, entry in stats.items():
          # .get() avoids inserting OOV words into a defaultdict.
          ref_prons = ref_lexicon.get(word)
          if ref_prons is not None:
              variants_counts = args.r * len(ref_prons)
          else:
              variants_counts = args.r * avg_variants_counts_ref
          num_variants = 0
          while num_variants < variants_counts and entry:
              pron, _ = entry.pop()
              # Reference prons are always kept and don't consume the budget.
              if ref_prons is None or pron not in ref_prons:
                  num_variants += 1

      for word, entry in stats.items():
          ref_prons = ref_lexicon.get(word)
          for pron, _ in entry:
              if ref_prons is None or pron not in ref_prons:
                  print('{0} {1}'.format(word, pron), file=args.pruned_prons_handle)
  
  def Main():
      """Script entry point.

      Parses the arguments, reads the reference lexicon and the pron stats,
      writes the pruned prons, and finally closes the file handles attached
      to args. sys.stdout (used when <pruned-prons> is "-") is left open.
      """
      args = GetArgs()
      try:
          ref_lexicon = ReadLexicon(args.ref_lexicon_handle)
          stats = ReadStats(args.pron_stats_handle)
          PruneProns(args, stats, ref_lexicon)
      finally:
          # Close explicitly so the output file is flushed even on error.
          for handle in (args.ref_lexicon_handle, args.pron_stats_handle,
                         args.pruned_prons_handle):
              if handle is not sys.stdout:
                  handle.close()
  
  if __name__ == '__main__':
      # Run the pruning pipeline only when executed as a script, not on import.
      Main()