arpa-lm-compiler.cc
14.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
// lm/arpa-lm-compiler.cc
// Copyright 2009-2011 Gilles Boulianne
// Copyright 2016 Smart Action LLC (kkm)
// Copyright 2017 Xiaohui Zhang
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <limits>
#include <sstream>
#include <utility>
#include "base/kaldi-math.h"
#include "lm/arpa-lm-compiler.h"
#include "util/stl-utils.h"
#include "util/text-utils.h"
#include "fstext/remove-eps-local.h"
namespace kaldi {
// Abstract interface implemented by ArpaLmCompilerImpl<HistKey>. It lets
// ArpaLmCompiler forward n-grams to an implementation without knowing which
// history key class that implementation was instantiated with.
class ArpaLmCompilerImplInterface {
 public:
  // Processes a single n-gram. is_highest is true when the n-gram is of the
  // highest order present in the model.
  virtual void ConsumeNGram(const NGram& ngram, bool is_highest) = 0;

  // Virtual so that implementations can be destroyed through this interface.
  virtual ~ArpaLmCompilerImplInterface() = default;
};
namespace {
// File-local integer aliases: StateId names FST states, Symbol names the
// word ids that make up an n-gram.
using StateId = int32;
using Symbol = int32;
// GeneralHistKey can represent state history in an arbitrarily large n
// n-gram model with symbol ids fitting int32.
class GeneralHistKey {
public:
// Construct key from being and end iterators.
template<class InputIt>
GeneralHistKey(InputIt begin, InputIt end) : vector_(begin, end) { }
// Construct empty history key.
GeneralHistKey() : vector_() { }
// Return tails of the key as a GeneralHistKey. The tails of an n-gram
// w[1..n] is the sequence w[2..n] (and the heads is w[1..n-1], but the
// key class does not need this operartion).
GeneralHistKey Tails() const {
return GeneralHistKey(vector_.begin() + 1, vector_.end());
}
// Keys are equal if represent same state.
friend bool operator==(const GeneralHistKey& a, const GeneralHistKey& b) {
return a.vector_ == b.vector_;
}
// Public typename HashType for hashing.
struct HashType : public std::unary_function<GeneralHistKey, size_t> {
size_t operator()(const GeneralHistKey& key) const {
return VectorHasher<Symbol>().operator()(key.vector_);
}
};
private:
std::vector<Symbol> vector_;
};
// OptimizedHistKey combines 3 21-bit symbol ID values into one 64-bit
// machine word, allowing significant memory reduction and some runtime
// benefit over GeneralHistKey. Since 3 symbols are enough to track history
// in a 4-gram model, this optimized key is used for smaller models with up
// to 4-gram and symbol values up to 2^21-1.
//
// See GeneralHistKey for interface requirements of a key class.
class OptimizedHistKey {
 public:
  enum {
    kShift = 21,  // 21 * 3 = 63 bits for data.
    kMaxData = (1 << kShift) - 1
  };

  // Packs the symbols in [begin, end) into the word: the first symbol lands
  // in the lowest kShift bits, each following one kShift bits higher. The
  // caller is responsible for passing at most 3 symbols, each no larger than
  // kMaxData; nothing is checked here.
  template<class InputIt>
  OptimizedHistKey(InputIt begin, InputIt end) : data_(0) {
    for (uint32 shift = 0; begin != end; ++begin, shift += kShift) {
      data_ |= static_cast<uint64>(*begin) << shift;
    }
  }

  // Construct empty history key.
  OptimizedHistKey() : data_(0) { }

  // Drops the first symbol by shifting it out of the low bits.
  OptimizedHistKey Tails() const {
    return OptimizedHistKey(data_ >> kShift);
  }

  // Keys are equal if their packed words are equal.
  friend bool operator==(const OptimizedHistKey& a, const OptimizedHistKey& b) {
    return a.data_ == b.data_;
  }

  // Hash functor: the packed word itself serves as the hash value. The nested
  // typedefs are spelled out by hand instead of being inherited from
  // std::unary_function, which was deprecated in C++11 and removed in C++17.
  struct HashType {
    typedef OptimizedHistKey argument_type;
    typedef size_t result_type;
    size_t operator()(const OptimizedHistKey& key) const { return key.data_; }
  };

 private:
  // Wraps an already packed word; used by Tails().
  explicit OptimizedHistKey(uint64 data) : data_(data) { }

  uint64 data_;  // Packed symbols, first symbol in the lowest bits.
};
} // namespace
// Builds the FST from n-grams, tracking one FST state per n-gram history.
// HistKey is the class used to key the history-to-state map; it must provide
// construction from an iterator range, a default (empty) key, Tails(),
// operator== and a nested HashType.
template <class HistKey>
class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface {
 public:
  ArpaLmCompilerImpl(ArpaLmCompiler* parent, fst::StdVectorFst* fst,
                     Symbol sub_eps);

  void ConsumeNGram(const NGram &ngram, bool is_highest) override;

 private:
  // Maps an n-gram history to the FST state that represents it.
  using HistoryMap =
      unordered_map<HistKey, StateId, typename HistKey::HashType>;

  StateId AddStateWithBackoff(HistKey key, float backoff);
  void CreateBackoff(HistKey key, StateId state, float weight);

  ArpaLmCompiler *parent_;  // Not owned.
  fst::StdVectorFst* fst_;  // Not owned.
  Symbol bos_symbol_;       // Copied from the parent's options.
  Symbol eos_symbol_;       // Copied from the parent's options.
  Symbol sub_eps_;          // Input label of backoff arcs; 0 means <eps>.
  StateId eos_state_;       // Shared final state, set up when sub_eps_ == 0.
  HistoryMap history_;
};
// Sets up the compiler state: copies the BOS/EOS symbols from the parent's
// options, creates the 0-gram state, and, when </s> is kept as a real symbol
// (sub_eps == 0), creates the shared final state for </s>.
template <class HistKey>
ArpaLmCompilerImpl<HistKey>::ArpaLmCompilerImpl(
    ArpaLmCompiler* parent, fst::StdVectorFst* fst, Symbol sub_eps)
    : parent_(parent), fst_(fst), bos_symbol_(parent->Options().bos_symbol),
      eos_symbol_(parent->Options().eos_symbol), sub_eps_(sub_eps),
      // Never left indeterminate: stays kNoStateId unless allocated below.
      eos_state_(fst::kNoStateId) {
  // The algorithm maintains state per history. The 0-gram is a special state
  // for empty history. All unigrams (including BOS) backoff into this state.
  StateId zerogram = fst_->AddState();
  history_[HistKey()] = zerogram;

  // Also, if </s> is not treated as epsilon, create a common end state for
  // all transitions accepting the </s>, since they do not back off. This small
  // optimization saves about 2% states in an average grammar.
  if (sub_eps_ == 0) {
    eos_state_ = fst_->AddState();
    fst_->SetFinal(eos_state_, 0);
  }
}
// Adds one n-gram to the FST. ngram.words must not be empty; is_highest tells
// whether the n-gram is of the highest order of the model.
template <class HistKey>
void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(const NGram &ngram,
                                               bool is_highest) {
  // Generally, we do the following. Suppose we are adding an n-gram "A B
  // C". Then find the node for "A B", add a new node for "A B C", and connect
  // them with the arc accepting "C" with the specified weight. Also, add a
  // backoff arc from the new "A B C" node to its backoff state "B C".
  //
  // Two notable exceptions are the highest order n-grams, and final n-grams.
  //
  // When adding a highest order n-gram (e. g., our "A B C" is in a 3-gram LM),
  // the following optimization is performed. There is no point adding a node
  // for "A B C" with a "C" arc from "A B", since there will be no other
  // arcs ingoing to this node, and an epsilon backoff arc into the backoff
  // model "B C", with the weight of \bar{1}. To save a node, create an arc
  // accepting "C" directly from "A B" to "B C". This saves as many nodes
  // as there are the highest order n-grams, which is typically about half
  // the size of a large 3-gram model.
  //
  // Indeed, this does not apply to n-grams ending in EOS, since they do not
  // back off. These are special, as they do not have a back-off state, and
  // the node for "(..anything..) </s>" is always final. These are handled
  // in one of the two possible ways. If symbols <s> and </s> are being
  // replaced by epsilons, neither node nor arc is created, and the logprob
  // of the n-gram is applied to its source node as final weight. If <s> and
  // </s> are preserved, then a special final node for </s> is allocated and
  // used as the destination of the "</s>" acceptor arc.

  // The heads computation below forms words.end() - 1, which is only valid
  // for a non-empty word sequence.
  KALDI_ASSERT(!ngram.words.empty());

  HistKey heads(ngram.words.begin(), ngram.words.end() - 1);
  typename HistoryMap::iterator source_it = history_.find(heads);
  if (source_it == history_.end()) {
    // There was no "A B", therefore the probability of "A B C" is zero.
    // Print a warning and discard current n-gram.
    if (parent_->ShouldWarn())
      KALDI_WARN << parent_->LineReference()
                 << " skipped: no parent (n-1)-gram exists";
    return;
  }

  StateId source = source_it->second;
  StateId dest;
  Symbol sym = ngram.words.back();
  float weight = -ngram.logprob;
  if (sym == sub_eps_ || sym == 0) {
    KALDI_ERR << " <eps> or disambiguation symbol " << sym
              << " found in the ARPA file. ";
  }
  if (sym == eos_symbol_) {
    if (sub_eps_ == 0) {
      // Keep </s> as a real symbol when not substituting.
      dest = eos_state_;
    } else {
      // Treat </s> as if it was epsilon: mark source final, with the weight
      // of the n-gram.
      fst_->SetFinal(source, weight);
      return;
    }
  } else {
    // For the highest order n-gram, this may find an existing state, for
    // non-highest, will create one (unless there are duplicate n-grams
    // in the grammar, which cannot be reliably detected if highest order,
    // so we better do not do that at all).
    dest = AddStateWithBackoff(
        HistKey(ngram.words.begin() + (is_highest ? 1 : 0),
                ngram.words.end()),
        -ngram.backoff);
  }

  if (sym == bos_symbol_) {
    weight = 0;  // Accepting <s> is always free.
    if (sub_eps_ == 0) {
      // <s> is a real symbol, only accepted in the start state.
      source = fst_->AddState();
      fst_->SetStart(source);
    } else {
      // The new state for <s> unigram history *is* the start state.
      fst_->SetStart(dest);
      return;
    }
  }

  // Add arc from source to dest, whichever way it was found.
  fst_->AddArc(source, fst::StdArc(sym, sym, weight, dest));
}
// Returns the state registered for the n-gram identified by key, creating it
// (together with its backoff arc) when the key is not in the history map yet.
// The key is the full n-gram for all orders but the highest; for the highest
// order it is the tails of the n-gram, a consequence of the chain-collapsing
// optimization described in ConsumeNGram.
template <class HistKey>
StateId ArpaLmCompilerImpl<HistKey>::AddStateWithBackoff(HistKey key,
                                                         float backoff) {
  typename HistoryMap::iterator known = history_.find(key);
  if (known == history_.end()) {
    // Unknown history: allocate its state, register it, and give it a backoff
    // arc right away so that every state in the map has one.
    const StateId fresh = fst_->AddState();
    history_[key] = fresh;
    CreateBackoff(key.Tails(), fresh, backoff);
    return fresh;
  }
  // Already registered, hence its backoff arc is already in the FST.
  return known->second;
}
// Adds the backoff arc leaving state. key names the preferred backoff
// destination; if no state is registered for it, progressively shorter
// histories are tried until a registered one is found. The 0-gram history is
// always registered, so the search terminates.
template <class HistKey>
inline void ArpaLmCompilerImpl<HistKey>::CreateBackoff(
    HistKey key, StateId state, float weight) {
  typename HistoryMap::iterator target;
  for (target = history_.find(key); target == history_.end();
       target = history_.find(key)) {
    key = key.Tails();
  }

  // The input label is sub_eps_ and the output label is always <eps> (0), so
  // this is the only kind of arc whose input and output labels may differ.
  const fst::StdArc backoff_arc(sub_eps_, 0, weight, target->second);
  fst_->AddArc(state, backoff_arc);
}
// Releases the implementation object allocated in HeaderAvailable(), if any.
ArpaLmCompiler::~ArpaLmCompiler() {
  delete impl_;  // Deleting a null pointer is a no-op.
}
// Picks and allocates the implementation: the one keyed by OptimizedHistKey
// when the model is a 4-gram or less and the largest symbol id that can occur
// fits the optimized range, the one keyed by GeneralHistKey otherwise.
void ArpaLmCompiler::HeaderAvailable() {
  KALDI_ASSERT(impl_ == NULL);

  // Largest symbol id currently in the symbol table (0 without a table).
  int64 max_symbol = (Symbols() != NULL) ? Symbols()->AvailableKey() - 1 : 0;

  // If augmenting the symbol table, assume the worst case when all words in
  // the model being read are novel.
  if (Options().oov_handling == ArpaParseOptions::kAddToSymbols)
    max_symbol += NgramCounts()[0];

  const bool fits_optimized = NgramCounts().size() <= 4 &&
      max_symbol < OptimizedHistKey::kMaxData;
  if (!fits_optimized) {
    impl_ = new ArpaLmCompilerImpl<GeneralHistKey>(this, &fst_, sub_eps_);
    KALDI_LOG << "Reverting to slower state tracking because model is large: "
              << NgramCounts().size() << "-gram with symbols up to "
              << max_symbol;
    return;
  }
  impl_ = new ArpaLmCompilerImpl<OptimizedHistKey>(this, &fst_, sub_eps_);
}
// Validates BOS/EOS placement in the n-gram and, if it is valid, forwards the
// n-gram to the implementation together with a flag telling whether it is of
// the highest order of the model. Invalid n-grams are skipped, with a warning
// when ShouldWarn() allows it.
void ArpaLmCompiler::ConsumeNGram(const NGram &ngram) {
  // Unsigned index type, to match the type returned by size().
  const size_t num_words = ngram.words.size();

  // <s> is invalid in tails, </s> in heads of an n-gram.
  for (size_t i = 0; i < num_words; ++i) {
    const bool bos_in_tails = i > 0 && ngram.words[i] == Options().bos_symbol;
    const bool eos_in_heads =
        i + 1 < num_words && ngram.words[i] == Options().eos_symbol;
    if (bos_in_tails || eos_in_heads) {
      if (ShouldWarn())
        KALDI_WARN << LineReference()
                   << " skipped: n-gram has invalid BOS/EOS placement";
      return;
    }
  }

  bool is_highest = num_words == NgramCounts().size();
  impl_->ConsumeNGram(ngram, is_highest);
}
void ArpaLmCompiler::RemoveRedundantStates() {
fst::StdArc::Label backoff_symbol = sub_eps_;
if (backoff_symbol == 0) {
// The method of removing redundant states implemented in this function
// leads to slow determinization of L o G when people use the older style of
// usage of arpa2fst where the --disambig-symbol option was not specified.
// The issue seems to be that it creates a non-deterministic FST, while G is
// supposed to be deterministic. By 'return'ing below, we just disable this
// method if people were using an older script. This method isn't really
// that consequential anyway, and people will move to the newer-style
// scripts (see current utils/format_lm.sh), so this isn't much of a
// problem.
return;
}
fst::StdArc::StateId num_states = fst_.NumStates();
// replace the #0 symbols on the input of arcs out of redundant states (states
// that are not final and have only a backoff arc leaving them), with <eps>.
for (fst::StdArc::StateId state = 0; state < num_states; state++) {
if (fst_.NumArcs(state) == 1 && fst_.Final(state) == fst::TropicalWeight::Zero()) {
fst::MutableArcIterator<fst::StdVectorFst> iter(&fst_, state);
fst::StdArc arc = iter.Value();
if (arc.ilabel == backoff_symbol) {
arc.ilabel = 0;
iter.SetValue(arc);
}
}
}
// we could call fst::RemoveEps, and it would have the same effect in normal
// cases, where backoff_symbol != 0 and there are no epsilons in unexpected
// places, but RemoveEpsLocal is a bit safer in case something weird is going
// on; it guarantees not to blow up the FST.
fst::RemoveEpsLocal(&fst_);
KALDI_LOG << "Reduced num-states from " << num_states << " to "
<< fst_.NumStates();
}
void ArpaLmCompiler::Check() const {
if (fst_.Start() == fst::kNoStateId) {
KALDI_ERR << "Arpa file did not contain the beginning-of-sentence symbol "
<< Symbols()->Find(Options().bos_symbol) << ".";
}
}
void ArpaLmCompiler::ReadComplete() {
fst_.SetInputSymbols(Symbols());
fst_.SetOutputSymbols(Symbols());
RemoveRedundantStates();
Check();
}
} // namespace kaldi