sausages.h
11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
// lat/sausages.h
// Copyright 2012 Johns Hopkins University (Author: Daniel Povey)
// 2015 Guoguo Chen
// 2019 Dogan Can
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_SAUSAGES_H_
#define KALDI_LAT_SAUSAGES_H_
#include <vector>
#include <map>
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"
namespace kaldi {
/// The implementation of the Minimum Bayes Risk decoding method described in
/// "Minimum Bayes Risk decoding and system combination based on a recursion for
/// edit distance", Haihua Xu, Daniel Povey, Lidia Mangu and Jie Zhu, Computer
/// Speech and Language, 2011
/// This is a slightly more principled way to do Minimum Bayes Risk (MBR) decoding
/// than the standard "Confusion Network" method. Note: MBR decoding aims to
/// minimize the expected word error rate, assuming the lattice encodes the
/// true uncertainty about what was spoken; standard Viterbi decoding gives the
/// most likely utterance, which corresponds to minimizing the expected sentence
/// error rate.
///
/// In addition to giving the MBR output, we also provide a way to get a
/// "Confusion Network" or informally "sausage"-like structure. This is a
/// linear sequence of bins, and in each bin, there is a distribution over
/// words (or epsilon, meaning no word). This is useful for estimating
/// confidence. Note: due to the way these sausages are made, typically there
/// will be, between each bin representing a high-confidence word, a bin
/// in which epsilon (no word) is the most likely word. Inside these bins
/// is where we put possible insertions.
struct MinimumBayesRiskOptions {
  /// If true, iteratively update the hypothesis so as to minimize the expected
  /// word error rate (MBR decoding).  If false, keep the MAP-decoded output,
  /// but still compute the sausage stats (i.e. the confidences).
  bool decode_mbr;
  /// If true, the 1-best path will 'keep' the <eps> bins.
  bool print_silence;

  MinimumBayesRiskOptions(): decode_mbr(true), print_silence(false) { }

  /// Registers the configuration parameters with the options interface.
  void Register(OptionsItf *opts) {
    opts->Register("decode-mbr", &decode_mbr,
                   "If true, do Minimum Bayes Risk decoding (else, Maximum a "
                   "Posteriori)");
    opts->Register("print-silence", &print_silence,
                   "Keep the inter-word '<eps>' bins in the 1-best output "
                   "(ctm, <eps> can be a 'silence' or a 'deleted' word)");
  }
};
/// This class does the word-level Minimum Bayes Risk computation, and gives you
/// either the 1-best MBR output together with the expected Bayes Risk,
/// or a sausage-like structure.
class MinimumBayesRisk {
public:
/// Initialize with compact lattice-- any acoustic scaling etc., is assumed
/// to have been done already.
/// This does the whole computation. You get the output with
/// GetOneBest(), GetBayesRisk(), and GetSausageStats().
MinimumBayesRisk(const CompactLattice &clat,
MinimumBayesRiskOptions opts = MinimumBayesRiskOptions());
// Uses the provided <words> as <R_> instead of using the lattice best path.
// Note that the default value of opts.decode_mbr is true. If you provide 1-best
// hypothesis from MAP decoding, the output ctm from MBR decoding may be
// mismatched with the provided <words> (<words> would be used as the starting
// point of optimization).
MinimumBayesRisk(const CompactLattice &clat,
const std::vector<int32> &words,
MinimumBayesRiskOptions opts = MinimumBayesRiskOptions());
// Uses the provided <words> as <R_> and <times> of bins instead of using the lattice best path.
// Note that the default value of opts.decode_mbr is true. If you provide 1-best
// hypothesis from MAP decoding, the output ctm from MBR decoding may be
// mismatched with the provided <words> (<words> would be used as the starting
// point of optimization).
MinimumBayesRisk(const CompactLattice &clat,
const std::vector<int32> &words,
const std::vector<std::pair<BaseFloat,BaseFloat> > ×,
MinimumBayesRiskOptions opts = MinimumBayesRiskOptions());
const std::vector<int32> &GetOneBest() const { // gets one-best (with no epsilons)
return R_;
}
const std::vector<std::vector<std::pair<BaseFloat, BaseFloat> > > GetTimes() const {
return times_; // returns average (start,end) times for each word in each
// bin. These are raw averages without any processing, i.e. time intervals
// from different bins can overlap.
}
const std::vector<std::pair<BaseFloat, BaseFloat> > GetSausageTimes() const {
return sausage_times_; // returns average (start,end) times for each bin.
// This is typically the weighted average of the times in GetTimes() but can
// be slightly different if the times for the bins overlap, in which case
// the times returned by this method do not overlap unlike the times
// returned by GetTimes().
}
const std::vector<std::pair<BaseFloat, BaseFloat> > &GetOneBestTimes() const {
return one_best_times_; // returns average (start,end) times for each word
// corresponding to an entry in the one-best output. This is typically the
// appropriate subset of the times in GetTimes() but can be slightly
// different if the times for the one-best words overlap, in which case
// the times returned by this method do not overlap unlike the times
// returned by GetTimes().
}
/// Outputs the confidences for the one-best transcript.
const std::vector<BaseFloat> &GetOneBestConfidences() const {
return one_best_confidences_;
}
/// Returns the expected WER over this sentence (assuming model correctness).
BaseFloat GetBayesRisk() const { return L_; }
const std::vector<std::vector<std::pair<int32, BaseFloat> > > &GetSausageStats() const {
return gamma_;
}
private:
void PrepareLatticeAndInitStats(CompactLattice *clat);
/// Minimum-Bayes-Risk Decode. Top-level algorithm. Figure 6 of the paper.
void MbrDecode();
/// Without the 'penalize' argument this gives us the basic edit-distance
/// function l(a,b), as in the paper.
/// With the 'penalize' argument it can be interpreted as the edit distance
/// plus the 'delta' from the paper, except that we make a kind of conceptual
/// bug-fix and only apply the delta if the edit-distance was not already
/// zero. This bug-fix was necessary in order to force all the stats to show
/// up, that should show up, and applying the bug-fix makes the sausage stats
/// significantly less sparse.
inline double l(int32 a, int32 b, bool penalize = false) {
if (a == b) return 0.0;
else return (penalize ? 1.0 + delta() : 1.0);
}
/// returns r_q, in one-based indexing, as in the paper.
inline int32 r(int32 q) { return R_[q-1]; }
/// Figure 4 of the paper; called from AccStats (Fig. 5)
double EditDistance(int32 N, int32 Q,
Vector<double> &alpha,
Matrix<double> &alpha_dash,
Vector<double> &alpha_dash_arc);
/// Figure 5 of the paper. Outputs to gamma_ and L_.
void AccStats();
/// Removes epsilons (symbol 0) from a vector
static void RemoveEps(std::vector<int32> *vec);
// Ensures that between each word in "vec" and at the beginning and end, is
// epsilon (0). (But if no words in vec, just one epsilon)
static void NormalizeEps(std::vector<int32> *vec);
// delta() is a constant used in the algorithm, which penalizes
// the use of certain epsilon transitions in the edit-distance which would cause
// words not to show up in the accumulated edit-distance statistics.
// There has been a conceptual bug-fix versus the way it was presented in
// the paper: we now add delta only if the edit-distance was not already
// zero.
static inline BaseFloat delta() { return 1.0e-05; }
/// Function used to increment map.
static inline void AddToMap(int32 i, double d, std::map<int32, double> *gamma) {
if (d == 0) return;
std::pair<const int32, double> pr(i, d);
std::pair<std::map<int32, double>::iterator, bool> ret = gamma->insert(pr);
if (!ret.second) // not inserted, so add to contents.
ret.first->second += d;
}
struct Arc {
int32 word;
int32 start_node;
int32 end_node;
BaseFloat loglike;
};
MinimumBayesRiskOptions opts_;
/// Arcs in the topologically sorted acceptor form of the word-level lattice,
/// with one final-state. Contains (word-symbol, log-likelihood on arc ==
/// negated cost). Indexed from zero.
std::vector<Arc> arcs_;
/// For each node in the lattice, a list of arcs entering that node. Indexed
/// from 1 (first node == 1).
std::vector<std::vector<int32> > pre_;
std::vector<int32> state_times_; // time of each state in the word lattice,
// indexed from 1 (same index as into pre_)
std::vector<int32> R_; // current 1-best word sequence, normalized to have
// epsilons between each word and at the beginning and end. R in paper...
// caution: indexed from zero, not from 1 as in paper.
double L_; // current averaged edit-distance between lattice and R_.
// \hat{L} in paper.
std::vector<std::vector<std::pair<int32, BaseFloat> > > gamma_;
// The stats we accumulate; these are pairs of (posterior, word-id), and note
// that word-id may be epsilon. Caution: indexed from zero, not from 1 as in
// paper. We sort in reverse order on the second member (posterior), so more
// likely word is first.
std::vector<std::vector<std::pair<BaseFloat, BaseFloat> > > times_;
// The average start and end times for words in each confusion-network bin.
// This is like an average over arcs, of the tau_b and tau_e quantities in
// Appendix C of the paper. Indexed from zero, like gamma_ and R_.
std::vector<std::pair<BaseFloat, BaseFloat> > sausage_times_;
// The average start and end times for each confusion-network bin. This
// is like an average over words, of the tau_b and tau_e quantities in
// Appendix C of the paper. Indexed from zero, like gamma_ and R_.
std::vector<std::pair<BaseFloat, BaseFloat> > one_best_times_;
// The average start and end times for words in the one best output. This
// is like an average over the arcs, of the tau_b and tau_e quantities in
// Appendix C of the paper. Indexed from zero, like gamma_ and R_.
std::vector<BaseFloat> one_best_confidences_;
// vector of confidences for the 1-best output (which could be
// the MAP output if opts_.decode_mbr == false, or the MBR output otherwise).
// Indexed by the same index as one_best_times_.
struct GammaCompare{
// should be like operator <. But we want reverse order
// on the 2nd element (posterior), so it'll be like operator
// > that looks first at the posterior.
bool operator () (const std::pair<int32, BaseFloat> &a,
const std::pair<int32, BaseFloat> &b) const {
if (a.second > b.second) return true;
else if (a.second < b.second) return false;
else return a.first > b.first;
}
};
};
} // namespace kaldi
#endif // KALDI_LAT_SAUSAGES_H_