reverse_arpa.py
5.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# -*- coding: utf-8 -*-
# Copyright 2012 Mirko Hannemann BUT, mirko.hannemann@gmail.com
import sys
import codecs # for UTF-8/unicode
if len(sys.argv) != 2:
print 'usage: reverse_arpa arpa.in'
sys.exit()
arpaname = sys.argv[1]
#\data\
#ngram 1=4
#ngram 2=2
#ngram 3=2
#
#\1-grams:
#-5.234679 a -3.3
#-3.456783 b
#0.0000000 <s> -2.5
#-4.333333 </s>
#
#\2-grams:
#-1.45678 a b -3.23
#-1.30490 <s> a -4.2
#
#\3-grams:
#-0.34958 <s> a b
#-0.23940 a b </s>
#\end\
# read language model in ARPA format
try:
file = codecs.open(arpaname, "r", "utf-8")
except IOError:
print 'file not found: ' + arpaname
sys.exit()
text=file.readline()
while (text and text[:6] != "\\data\\"): text=file.readline()
if not text:
print "invalid ARPA file"
sys.exit()
#print text,
while (text and text[:5] != "ngram"): text=file.readline()
# get ngram counts
cngrams=[]
n=0
while (text and text[:5] == "ngram"):
ind = text.split("=")
counts = int(ind[1].strip())
r = ind[0].split()
read_n = int(r[1].strip())
if read_n != n+1:
print "invalid ARPA file:", text
sys.exit()
n = read_n
cngrams.append(counts)
#print text,
text=file.readline()
# read all n-grams order by order
sentprob = 0.0 # sentence begin unigram
ngrams=[]
inf=float("inf")
for n in range(1,len(cngrams)+1): # unigrams, bigrams, trigrams
while (text and "-grams:" not in text): text=file.readline()
if n != int(text[1]):
print "invalid ARPA file:", text
sys.exit()
#print text,cngrams[n-1]
this_ngrams={} # stores all read ngrams
for ng in range(cngrams[n-1]):
while (text and len(text.split())<2):
text=file.readline()
if (not text) or ((len(text.split())==1) and (("-grams:" in text) or (text[:5] == "\\end\\"))): break
if (not text) or ((len(text.split())==1) and (("-grams:" in text) or (text[:5] == "\\end\\"))):
break # to deal with incorrect ARPA files
entry = text.split()
prob = float(entry[0])
if len(entry)>n+1:
back = float(entry[-1])
words = entry[1:n+1]
else:
back = 0.0
words = entry[1:]
ngram = " ".join(words)
if (n==1) and words[0]=="<s>":
sentprob = prob
prob = 0.0
this_ngrams[ngram] = (prob,back)
#print prob,ngram.encode("utf-8"),back
for x in range(n-1,0,-1):
# add all missing backoff ngrams for reversed lm
l_ngram = " ".join(words[:x]) # shortened ngram
r_ngram = " ".join(words[1:1+x]) # shortened ngram with offset one
if l_ngram not in ngrams[x-1]: # create missing ngram
ngrams[x-1][l_ngram] = (0.0,inf)
#print ngram, "create 0.0", l_ngram, "inf"
if r_ngram not in ngrams[x-1]: # create missing ngram
ngrams[x-1][r_ngram] = (0.0,inf)
#print ngram, "create 0.0", r_ngram, "inf",x,n,h_ngram
# add all missing backoff ngrams for forward lm
h_ngram = " ".join(words[n-x:]) # shortened history
if h_ngram not in ngrams[x-1]: # create missing ngram
ngrams[x-1][h_ngram] = (0.0,inf)
#print "create inf", h_ngram, "0.0"
text=file.readline()
if (not text) or ((len(text.split())==1) and (("-grams:" in text) or (text[:5] == "\\end\\"))): break
ngrams.append(this_ngrams)
while (text and text[:5] != "\\end\\"): text=file.readline()
if not text:
print "invalid ARPA file"
sys.exit()
file.close()
#print text,
#fourgram "maxent" model (b(ABCD)=0):
#p(A)+b(A) A 0
#p(AB)+b(AB)-b(A)-p(B) AB 0
#p(ABC)+b(ABC)-b(AB)-p(BC) ABC 0
#p(ABCD)+b(ABCD)-b(ABC)-p(BCD) ABCD 0
#fourgram reverse ARPA model (b(ABCD)=0):
#p(A)+b(A) A 0
#p(AB)+b(AB)-p(B)+p(A) BA 0
#p(ABC)+b(ABC)-p(BC)+p(AB)-p(B)+p(A) CBA 0
#p(ABCD)+b(ABCD)-p(BCD)+p(ABC)-p(BC)+p(AB)-p(B)+p(A) DCBA 0
# compute new reversed ARPA model
print "\\data\\"
for n in range(1,len(cngrams)+1): # unigrams, bigrams, trigrams
print "ngram "+str(n)+"="+str(len(ngrams[n-1].keys()))
offset = 0.0
for n in range(1,len(cngrams)+1): # unigrams, bigrams, trigrams
print "\\"+str(n)+"-grams:"
keys = ngrams[n-1].keys()
keys.sort()
for ngram in keys:
prob = ngrams[n-1][ngram]
# reverse word order
words = ngram.split()
rstr = " ".join(reversed(words))
# swap <s> and </s>
rev_ngram = rstr.replace("<s>","<temp>").replace("</s>","<s>").replace("<temp>","</s>")
revprob = prob[0]
if (prob[1] != inf): # only backoff weights from not newly created ngrams
revprob = revprob + prob[1]
#print prob[0],prob[1]
# sum all missing terms in decreasing ngram order
for x in range(n-1,0,-1):
l_ngram = " ".join(words[:x]) # shortened ngram
if l_ngram not in ngrams[x-1]:
sys.stderr.write(rev_ngram+": not found "+l_ngram+"\n")
p_l = ngrams[x-1][l_ngram][0]
#print p_l,l_ngram
revprob = revprob + p_l
r_ngram = " ".join(words[1:1+x]) # shortened ngram with offset one
if r_ngram not in ngrams[x-1]:
sys.stderr.write(rev_ngram+": not found "+r_ngram+"\n")
p_r = ngrams[x-1][r_ngram][0]
#print -p_r,r_ngram
revprob = revprob - p_r
if n != len(cngrams): #not highest order
back = 0.0
if rev_ngram[:3] == "<s>": # special handling since arpa2fst ignores <s> weight
if n == 1:
offset = revprob # remember <s> weight
revprob = sentprob # apply <s> weight from forward model
back = offset
elif n == 2:
revprob = revprob + offset # add <s> weight to bigrams starting with <s>
if (prob[1] != inf): # only backoff weights from not newly created ngrams
print revprob,rev_ngram.encode("utf-8"),back
else:
print revprob,rev_ngram.encode("utf-8"),"-100000.0"
else: # highest order - no backoff weights
if (n==2) and (rev_ngram[:3] == "<s>"): revprob = revprob + offset
print revprob,rev_ngram.encode("utf-8")
print "\\end\\"