prune_pron_candidates.py
4.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python
# Copyright 2016 Xiaohui Zhang
# Apache 2.0.
from __future__ import print_function
from __future__ import division
from collections import defaultdict
import argparse
import sys
import math
def GetArgs():
    """Parse the command line and return the validated argument namespace.

    The full command line is echoed to stderr for logging. After parsing,
    the namespace is passed through CheckArgs(), which opens the input and
    output files and attaches the resulting handles to it.

    Returns:
        The argparse namespace returned by CheckArgs().
    """
    # NOTE: adjacent string literals are concatenated verbatim, so every
    # fragment that continues on the next line must end with a space.
    parser = argparse.ArgumentParser(
        description="Prune pronunciation candidates based on soft-counts from "
                    "lattice-alignment outputs, and a reference lexicon. "
                    "Basically, for each word we sort all pronunciation "
                    "candidates according to their soft-counts, and then "
                    "select the top r * N candidates (For words in the "
                    "reference lexicon, N = # pron variants given by the "
                    "reference lexicon; For oov words, N = avg. # pron "
                    "variants per word in the reference lexicon). "
                    "r is a user-specified constant, like 2.",
        epilog="See steps/dict/learn_lexicon_greedy.sh for example")

    parser.add_argument("--r", type=float, default=2.0,
                        help="a user-specified ratio parameter which "
                             "determines how many pronunciation candidates "
                             "we want to keep for each word.")
    parser.add_argument("pron_stats", metavar="<pron-stats>", type=str,
                        help="File containing soft-counts of all "
                             "pronunciation candidates; each line must be "
                             "<soft-counts> <word> <phones>")
    parser.add_argument("ref_lexicon", metavar="<ref-lexicon>", type=str,
                        help="Reference lexicon file, where we obtain # pron "
                             "variants for each word, based on which we "
                             "prune the pron candidates.")
    parser.add_argument("pruned_prons", metavar="<pruned-prons>", type=str,
                        help="A file in lexicon format, which contains prons "
                             "we want to prune away from the pron_stats "
                             "file.")

    # Log the exact invocation before parsing so it is recorded even when
    # argument parsing fails.
    print(' '.join(sys.argv), file=sys.stderr)

    args = parser.parse_args()
    args = CheckArgs(args)
    return args
def CheckArgs(args):
    """Open the files named on the command line and attach the handles.

    Sets three attributes on ``args``:
      * ``pron_stats_handle``   - ``args.pron_stats`` opened for reading
      * ``ref_lexicon_handle``  - ``args.ref_lexicon`` opened for reading
      * ``pruned_prons_handle`` - ``sys.stdout`` when ``args.pruned_prons``
        is "-", otherwise that path opened for writing

    If any open() fails, the handles opened so far are closed before the
    error is re-raised, so a partial failure does not leak file descriptors.

    Args:
        args: namespace with ``pron_stats``, ``ref_lexicon`` and
            ``pruned_prons`` string attributes.

    Returns:
        The same ``args`` object, with the handle attributes added.
    """
    opened = []
    try:
        args.pron_stats_handle = open(args.pron_stats)
        opened.append(args.pron_stats_handle)

        args.ref_lexicon_handle = open(args.ref_lexicon)
        opened.append(args.ref_lexicon_handle)

        if args.pruned_prons == "-":
            # "-" selects standard output instead of a file on disk.
            args.pruned_prons_handle = sys.stdout
        else:
            args.pruned_prons_handle = open(args.pruned_prons, "w")
    except (IOError, OSError):
        for handle in opened:
            handle.close()
        raise
    return args
def ReadStats(pron_stats_handle):
    """Read pronunciation soft-counts, grouped by word.

    Each non-blank line is split on whitespace into
    ``<soft-count> <word> [<phone> ...]``. The phones are re-joined with
    single spaces; a line with only a count and a word yields an empty
    pronunciation string.

    Args:
        pron_stats_handle: iterable of text lines (e.g. an open file).

    Returns:
        A defaultdict(list) mapping each word to a list of
        ``(phones, count)`` tuples sorted by ascending count, so the
        highest-count candidate is last and can be taken with ``pop()``.

    Raises:
        Exception: if a non-blank line has fewer than two fields.
        ValueError: if the first field cannot be parsed as a float.
    """
    stats = defaultdict(list)
    # Iterate the handle directly so the file is streamed line by line
    # instead of being loaded into memory in one go.
    for line in pron_stats_handle:
        splits = line.split()
        if not splits:
            continue
        if len(splits) < 2:
            raise Exception('Invalid format of line ' + line
                            + ' in stats file.')
        count = float(splits[0])
        word = splits[1]
        phones = ' '.join(splits[2:])
        stats[word].append((phones, count))

    # list.sort is stable, so candidates with equal counts keep file order.
    for entry in stats.values():
        entry.sort(key=lambda x: x[1])
    return stats
def ReadLexicon(ref_lexicon_handle):
    """Read a reference lexicon into a word -> set-of-pronunciations map.

    Two line formats are accepted:
      * ``<word> <phone> ...``          (plain lexicon)
      * ``<word> <prob> <phone> ...``   (lexicon with a probability column)

    The second field is treated as a probability, and skipped, only when it
    parses as a float; otherwise it is the first phone of the pronunciation.
    Without that probe the first phone of every entry in a plain lexicon
    would be silently dropped.

    Args:
        ref_lexicon_handle: iterable of text lines (e.g. an open file).

    Returns:
        A defaultdict(set) mapping each word to the set of its
        pronunciations, each a single-space-joined phone string.

    Raises:
        Exception: if a non-blank line has fewer than two fields.
    """
    ref_lexicon = defaultdict(set)
    for line in ref_lexicon_handle:
        splits = line.split()
        if not splits:
            continue
        if len(splits) < 2:
            raise Exception('Invalid format of line ' + line
                            + ' in lexicon file.')
        word = splits[0]
        try:
            # Probe only; the probability value itself is not needed here.
            # NOTE(review): a phone symbol that happens to parse as a float
            # (e.g. "1") would be mistaken for a probability.
            float(splits[1])
        except ValueError:
            phones = ' '.join(splits[1:])
        else:
            phones = ' '.join(splits[2:])
        ref_lexicon[word].add(phones)
    return ref_lexicon
def PruneProns(args, stats, ref_lexicon):
    """Write the pronunciation candidates that should be pruned away.

    For every word in ``stats`` the highest-count candidates are taken off
    the end of its list until ``args.r * N`` candidates that are NOT in the
    reference lexicon have been kept, where N is the number of reference
    pronunciations of the word, or, for words missing from the reference
    lexicon, the ceiling of the average number of pronunciations per
    reference word. Candidates that are in the reference lexicon are taken
    off too but do not count towards the quota. Whatever remains and is not
    a reference pronunciation is printed as ``<word> <pron>``.

    Args:
        args: namespace providing ``r`` (keep ratio) and
            ``pruned_prons_handle`` (writable text handle).
        stats: mapping word -> list of ``(pron, count)`` sorted by ascending
            count. The lists are consumed in place: kept candidates are
            popped off.
        ref_lexicon: mapping word -> collection of reference pronunciations.

    Returns:
        None. Output goes to ``args.pruned_prons_handle``.
    """
    # Average number of pronunciation variants per reference word.
    num_words_ref = len(ref_lexicon)
    num_prons_ref = sum(len(prons) for prons in ref_lexicon.values())
    if num_words_ref > 0:
        avg_variants_counts_ref = math.ceil(
            float(num_prons_ref) / float(num_words_ref))
    else:
        # An empty reference lexicon used to raise ZeroDivisionError here.
        # Fall back to one variant per word and warn.
        # NOTE(review): 1 is a chosen default, not a derived value; confirm
        # this is the desired behaviour for an empty lexicon.
        print('Warning: reference lexicon is empty; assuming 1 pron variant '
              'per word.', file=sys.stderr)
        avg_variants_counts_ref = 1

    for word, entry in stats.items():
        if word in ref_lexicon:
            ref_prons = ref_lexicon[word]
            variants_counts = args.r * len(ref_prons)
        else:
            ref_prons = ()
            variants_counts = args.r * avg_variants_counts_ref

        # Keep (remove from the prune list) the best-scoring candidates;
        # the list is sorted ascending, so pop() yields the highest count.
        num_variants = 0
        while entry and num_variants < variants_counts:
            pron, _ = entry.pop()
            if pron not in ref_prons:
                num_variants += 1

        # Everything left over is pruned, except reference pronunciations,
        # which are never reported as pruned.
        for pron, _ in entry:
            if pron not in ref_prons:
                print('{0} {1}'.format(word, pron),
                      file=args.pruned_prons_handle)
def Main():
    """Run the pruning pipeline end to end.

    Parses the command line, reads the reference lexicon and the
    pronunciation statistics, and writes the pruned pronunciations. The
    file handles attached to the argument namespace are closed afterwards,
    including when a step fails, so the output file is flushed to disk;
    standard output is left open.
    """
    args = GetArgs()
    try:
        ref_lexicon = ReadLexicon(args.ref_lexicon_handle)
        stats = ReadStats(args.pron_stats_handle)
        PruneProns(args, stats, ref_lexicon)
    finally:
        args.pron_stats_handle.close()
        args.ref_lexicon_handle.close()
        # The output handle may be sys.stdout (when "-" was given); closing
        # it would break any later writes by the interpreter or caller.
        if args.pruned_prons_handle is not sys.stdout:
            args.pruned_prons_handle.close()
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    Main()