// online2/online-gmm-decoding.h
// Copyright 2014 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS
// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY
// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_ONLINE2_ONLINE_GMM_DECODING_H_
#define KALDI_ONLINE2_ONLINE_GMM_DECODING_H_
#include <string>
#include <vector>
#include <deque>
#include "matrix/matrix-lib.h"
#include "util/common-utils.h"
#include "base/kaldi-error.h"
#include "transform/basis-fmllr-diag-gmm.h"
#include "transform/fmllr-diag-gmm.h"
#include "online2/online-feature-pipeline.h"
#include "online2/online-gmm-decodable.h"
#include "online2/online-endpoint.h"
#include "decoder/lattice-faster-online-decoder.h"
#include "hmm/transition-model.h"
#include "gmm/am-diag-gmm.h"
#include "hmm/posterior.h"
namespace kaldi {
/// @addtogroup onlinedecoding OnlineDecoding
/// @{
/// This configuration class controls when to re-estimate the basis-fMLLR during
/// online decoding. The basic model is to re-estimate it at a certain time t
/// (e.g. after 1 second) and then at a set of times forming a geometric series,
/// e.g. t * 1.5, t * 1.5^2, etc. We specify different configurations for the first
/// utterance of a speaker (which requires more frequent adaptation), and for
/// subsequent utterances. We also re-estimate fMLLR at the end of every
/// utterance, but this is done directly from the calling code, not by the class
/// SingleUtteranceGmmDecoder.
struct OnlineGmmDecodingAdaptationPolicyConfig {
BaseFloat adaptation_first_utt_delay;
BaseFloat adaptation_first_utt_ratio;
BaseFloat adaptation_delay;
BaseFloat adaptation_ratio;
OnlineGmmDecodingAdaptationPolicyConfig():
adaptation_first_utt_delay(2.0),
adaptation_first_utt_ratio(1.5),
adaptation_delay(5.0),
adaptation_ratio(2.0) { }
void Register(OptionsItf *opts) {
opts->Register("adaptation-first-utt-delay", &adaptation_first_utt_delay,
"Delay before first basis-fMLLR adaptation for first utterance "
"of each speaker");
opts->Register("adaptation-first-utt-ratio", &adaptation_first_utt_ratio,
"Ratio that controls frequency of fMLLR adaptation for first "
"utterance of each speaker");
opts->Register("adaptation-delay", &adaptation_delay,
"Delay before first basis-fMLLR adaptation for not-first "
"utterances of each speaker");
opts->Register("adaptation-ratio", &adaptation_ratio,
"Ratio that controls frequency of fMLLR adaptation for "
"not-first utterances of each speaker");
}
/// Check that configuration values make sense.
void Check() const;
/// This function returns true if we are scheduled
/// to re-estimate fMLLR somewhere in the interval
/// [ chunk_begin_secs, chunk_end_secs ).
bool DoAdapt(BaseFloat chunk_begin_secs,
BaseFloat chunk_end_secs,
bool is_first_utterance) const;
};
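// Example (a hedged sketch, not part of this header's API): DoAdapt() encodes
// a geometric schedule. With the defaults above, a non-first utterance is
// adapted at t = 5.0, 10.0, 20.0, ... seconds, and a first utterance at
// t = 2.0, 3.0, 4.5, 6.75, ... seconds. A caller typically checks each
// decoded chunk of audio against the schedule:
//
//   OnlineGmmDecodingAdaptationPolicyConfig adapt_opts;  // defaults as above.
//   BaseFloat chunk_begin = 4.9, chunk_end = 5.1;  // seconds into the audio.
//   if (adapt_opts.DoAdapt(chunk_begin, chunk_end, false))
//     ;  // ... re-estimate the fMLLR transform here ...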
struct OnlineGmmDecodingConfig {
BaseFloat fmllr_lattice_beam;
BasisFmllrOptions basis_opts; // options for basis-fMLLR adaptation.
LatticeFasterDecoderConfig faster_decoder_opts;
OnlineGmmDecodingAdaptationPolicyConfig adaptation_policy_opts;
// rxfilename for model trained with online-CMN features
// (only needed if different from model_rxfilename)
std::string online_alimdl_rxfilename;
// rxfilename for model used for estimating fMLLR transforms
std::string model_rxfilename;
// rxfilename for possible discriminatively trained model
// (only needed if different from model_rxfilename)
std::string rescore_model_rxfilename;
// rxfilename for the BasisFmllrEstimate object containing the basis
// used for basis-fMLLR.
std::string fmllr_basis_rxfilename;
BaseFloat acoustic_scale;
std::string silence_phones;
BaseFloat silence_weight;
OnlineGmmDecodingConfig(): fmllr_lattice_beam(3.0), acoustic_scale(0.1),
silence_weight(0.1) { }
void Register(OptionsItf *opts) {
{ // Register basis_opts under the prefix "basis"; there are getting to be
// too many options at the top level.
ParseOptions basis_po("basis", opts);
basis_opts.Register(&basis_po);
}
adaptation_policy_opts.Register(opts);
faster_decoder_opts.Register(opts);
opts->Register("acoustic-scale", &acoustic_scale,
"Scaling factor for acoustic likelihoods");
opts->Register("silence-phones", &silence_phones,
"Colon-separated list of integer ids of silence phones, e.g. "
"1:2:3 (affects adaptation).");
opts->Register("silence-weight", &silence_weight,
"Weight applied to silence frames for fMLLR estimation (if "
"--silence-phones option is supplied)");
opts->Register("fmllr-lattice-beam", &fmllr_lattice_beam, "Beam used in "
"pruning lattices for fMLLR estimation");
opts->Register("online-alignment-model", &online_alimdl_rxfilename,
"(Extended) filename for model trained with online CMN "
"features, e.g. from apply-cmvn-online.");
opts->Register("model", &model_rxfilename, "(Extended) filename for model, "
"typically the one used for fMLLR computation. Required option.");
opts->Register("rescore-model", &rescore_model_rxfilename, "(Extended) filename "
"for model to rescore lattices with, e.g. discriminatively trained"
"model, if it differs from that supplied to --model option. Must"
"have the same tree.");
opts->Register("fmllr-basis", &fmllr_basis_rxfilename, "(Extended) filename "
"of fMLLR basis object, as output by gmm-basis-fmllr-training");
}
};
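// Example (a minimal sketch of typical usage; the usage string, binary name
// and option values are hypothetical): the config registers itself on a
// ParseOptions object, so a decoding program can expose all of the above on
// its command line:
//
//   OnlineGmmDecodingConfig decode_config;
//   ParseOptions po("Usage: my-online-gmm-decoder [options] ...");
//   decode_config.Register(&po);
//   po.Read(argc, argv);  // e.g. --model=final.mdl --fmllr-basis=fmllr.basis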
/**
This class is used to read, store and give access to the models used for the
three phases of decoding (the first pass with online-CMN features; the ML
models used for estimating transforms; and the discriminatively trained
models). It takes care of the defaulting logic whereby if, say, the last
model isn't given we fall back to the second model, and so on, and it
interprets the filenames from the config object. It is passed as a const
reference to other objects in this header. */
class OnlineGmmDecodingModels {
public:
OnlineGmmDecodingModels(const OnlineGmmDecodingConfig &config);
const TransitionModel &GetTransitionModel() const;
const AmDiagGmm &GetOnlineAlignmentModel() const;
const AmDiagGmm &GetModel() const;
const AmDiagGmm &GetFinalModel() const;
const BasisFmllrEstimate &GetFmllrBasis() const;
private:
// The transition-model is only needed for its integer ids, and these need to
// be identical for all 3 models, so we only store one (it doesn't matter
// which one).
TransitionModel tmodel_;
// The model trained with online-CMVN features
// (if supplied, otherwise use model_)
AmDiagGmm online_alignment_model_;
// The ML-trained model used to get transforms (required)
AmDiagGmm model_;
// The discriminatively trained model
// (if supplied, otherwise use model_)
AmDiagGmm rescore_model_;
// The following object contains the basis elements for
// "Basis fMLLR".
BasisFmllrEstimate fmllr_basis_;
};
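// Example (a hedged sketch): the constructor reads every model named in the
// config, and the accessors apply the defaulting logic described above; e.g.
// GetFinalModel() returns the rescoring model if --rescore-model was
// supplied, and otherwise the same model as GetModel():
//
//   OnlineGmmDecodingModels models(decode_config);  // reads models from disk.
//   const AmDiagGmm &first_pass_model = models.GetOnlineAlignmentModel();
//   const AmDiagGmm &rescoring_model = models.GetFinalModel();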
struct OnlineGmmAdaptationState {
OnlineCmvnState cmvn_state;
FmllrDiagGmmAccs spk_stats;
Matrix<BaseFloat> transform;
// Writing and reading of the state of the object
void Write(std::ostream &out_stream, bool binary) const;
void Read(std::istream &in_stream, bool binary);
};
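// Example (a minimal sketch; the filename is hypothetical): Write()/Read()
// let an application persist per-speaker adaptation state between
// utterances, using the Input/Output classes from kaldi-io:
//
//   OnlineGmmAdaptationState state;
//   {
//     Output ko("speaker1.adapt", true);  // true: write in binary mode.
//     state.Write(ko.Stream(), true);
//   }
//   bool binary_in;
//   Input ki("speaker1.adapt", &binary_in);
//   state.Read(ki.Stream(), binary_in);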
/**
You will instantiate this class when you want to decode a single
utterance using the online-decoding setup. It is an alternative
to putting the components together yourself.
*/
class SingleUtteranceGmmDecoder {
public:
SingleUtteranceGmmDecoder(const OnlineGmmDecodingConfig &config,
const OnlineGmmDecodingModels &models,
const OnlineFeaturePipeline &feature_prototype,
const fst::Fst<fst::StdArc> &fst,
const OnlineGmmAdaptationState &adaptation_state);
OnlineFeaturePipeline &FeaturePipeline() { return *feature_pipeline_; }
/// Advance the decoding as far as we can. May also estimate fMLLR after
/// advancing the decoding, depending on the configuration values in
/// config_.adaptation_policy_opts. [Note: we expect the user will also call
/// EstimateFmllr() at utterance end, which should generally improve the
/// quality of the estimated transforms, although we don't rely on this].
void AdvanceDecoding();
/// Finalize the decoding: cleans up and prunes remaining tokens, so that the
/// final result is faster to obtain.
void FinalizeDecoding();
/// Returns true if we already have an fMLLR transform. The user will
/// already know this; the call is for convenience.
bool HaveTransform() const;
/// Estimate the [basis-]fMLLR transform and apply it to the features.
/// This will get used if you call RescoreLattice() or if you just
/// continue decoding; however to get it applied retroactively
/// you'd have to call RescoreLattice().
/// "end_of_utterance" just affects how we interpret the final-probs in the
/// lattice. This should generally be true if you think you've reached
/// the end of the grammar, and false otherwise.
void EstimateFmllr(bool end_of_utterance);
void GetAdaptationState(OnlineGmmAdaptationState *adaptation_state) const;
/// Gets the lattice. If rescore_if_needed is true, and if there is any point
/// in rescoring the state-level lattice (see RescoringIsNeeded()), it will
/// rescore the lattice. The output lattice has any acoustic scaling in it
/// (which will typically be desirable in an online-decoding context); if you
/// want an un-scaled lattice, scale it using ScaleLattice() with the inverse
/// of the acoustic weight. "end_of_utterance" will be true if you want the
/// final-probs to be included.
void GetLattice(bool rescore_if_needed,
bool end_of_utterance,
CompactLattice *clat) const;
/// Outputs an FST corresponding to the single best path through the current
/// lattice. If "end_of_utterance" is true AND we reached the final-state of
/// the graph then it will include the final-probs, else it will treat
/// all final-probs as one.
void GetBestPath(bool end_of_utterance,
Lattice *best_path) const;
/// This function returns a number >= 0 that will be close to zero if the
/// final-probs were close to the best probs active on the final frame (the
/// value is based on the first-pass decoding). If it's close to zero
/// (e.g. < 5, as a guess), it means you reached the end of the grammar with
/// good probability, which can be taken as a good sign that the input was
/// OK.
BaseFloat FinalRelativeCost() { return decoder_.FinalRelativeCost(); }
/// This function calls EndpointDetected from online-endpoint.h,
/// with the required arguments.
bool EndpointDetected(const OnlineEndpointConfig &config);
~SingleUtteranceGmmDecoder();
private:
bool GetGaussianPosteriors(bool end_of_utterance, GaussPost *gpost);
/// Returns true if doing a lattice rescoring pass would have any point, i.e.
/// if we have estimated fMLLR during this utterance, or if we have a
/// discriminative model that differs from the fMLLR model *and* we currently
/// have fMLLR features.
bool RescoringIsNeeded() const;
OnlineGmmDecodingConfig config_;
std::vector<int32> silence_phones_; // sorted, unique list of silence phones,
// derived from config_
const OnlineGmmDecodingModels &models_;
OnlineFeaturePipeline *feature_pipeline_; // owned here.
const OnlineGmmAdaptationState &orig_adaptation_state_;
// adaptation_state_ generally reflects the "current" state of the
// adaptation. Note: adaptation_state_.cmvn_state is just copied from
// orig_adaptation_state_; the up-to-date CMVN state is obtained from the
// feature pipeline when GetAdaptationState() is called.
OnlineGmmAdaptationState adaptation_state_;
LatticeFasterOnlineDecoder decoder_;
};
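// Example (a hedged sketch of the intended calling sequence; the chunked
// audio loop and variable names are illustrative, not prescribed by this
// header):
//
//   OnlineGmmAdaptationState adaptation_state;  // persists across utterances.
//   SingleUtteranceGmmDecoder decoder(decode_config, models,
//                                     pipeline_prototype, decode_fst,
//                                     adaptation_state);
//   while (/* more audio chunks arrive */) {
//     decoder.FeaturePipeline().AcceptWaveform(samp_freq, wave_chunk);
//     decoder.AdvanceDecoding();  // may also trigger scheduled fMLLR.
//   }
//   decoder.FeaturePipeline().InputFinished();  // flush remaining features.
//   decoder.AdvanceDecoding();
//   decoder.FinalizeDecoding();
//   decoder.EstimateFmllr(true);  // end_of_utterance == true.
//   CompactLattice clat;
//   decoder.GetLattice(true /* rescore_if_needed */, true, &clat);
//   decoder.GetAdaptationState(&adaptation_state);  // carry to next utterance.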
/// @} End of "addtogroup onlinedecoding"
} // namespace kaldi
#endif // KALDI_ONLINE2_ONLINE_GMM_DECODING_H_