fmpe.h
12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
// transform/fmpe.h
// Copyright 2011-2012 Yanmin Qian Johns Hopkins University (Author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_TRANSFORM_FMPE_H_
#define KALDI_TRANSFORM_FMPE_H_ 1
#include <vector>
#include "gmm/am-diag-gmm.h"
#include "gmm/mle-am-diag-gmm.h"
#include "hmm/transition-model.h"
#include "hmm/posterior.h"
namespace kaldi {
struct FmpeOptions {
// Probably the easiest place to start, to understand fMPE, is the
// paper "Improvements to fMPE for discriminative training of features".
// We are simplifying a few things here. We are getting rid of the
// "indirect differential"; we are adding a linear transform after the
// high->low dimension projection whose function is to "un-whiten" the
// transformed features (i.e. project from a nominally Gaussian-distributed
// space into our actual feature space), in order to make it unnecessary to
// take into account the per-dim variance during the update phase of fMPE;
// and the update equations are rather simpler than described in
// the paper; we take away some stuff, but add in the capability to
// do l2 regularization during the update phase.
std::string context_expansion; // This string describes the various contexts...
// the easiest way to think of it is, we first generate the high-dimensional
// features without context expansion, and we then append the left and right
// frames, and also weighted averages of further-out frames, as specified by
// this string. Suppose there are 1024 Gaussians and the feature dimension is
// 40. In the simple way to describe it, supposing there are 9 contexts (the
// central frame, the left and right frames, and 6 averages of more distant
// frames), we generate the "offset features" of dimension (1024 * 41), then
// add left and right temporal context to the high-dim features so the
// dimension is (1024 * 41 * 9), and then project down to 40, so we train a
// matrix of 40 x (1024 * 41 * 9). As described in the paper, though, we
// reorganize the computation for efficiency (it has to do with preserving
// sparsity), and we train a matrix of dimension (40 * 9) x (1024 * 41). The
// (40 x 9) -> 40 transformation, which involves time as well, is dictated by
// these contexts.
// You probably won't want to mess with this "context_expansion" string.
// The most important parameter to tune is the number of Gaussians in
// the UBM. Typically this will be in the range 300 to 1000.
BaseFloat post_scale; // Scale on the posterior component of the high-dim
// features (1 of these for every [feat-dim] of the offset features).
// Typically 5.0-- this just gives a bit more emphasis to these posteriors
// during training, like a faster learning rate.
FmpeOptions(): context_expansion("0,1.0:-1,1.0:1,1.0:-2,0.5;-3,0.5:2,0.5;3,0.5:-4,0.5;-5,0.5:4,0.5;5,0.5:-6,0.333;-7,0.333;-8,0.333:6,0.333;7,0.333;8,0.333"),
post_scale(5.0) { }
void Register(OptionsItf *opts) {
opts->Register("post-scale", &post_scale, "Scaling constant on posterior "
"element of offset features, to give it a faster learning "
"rate.");
opts->Register("context-expansion", &context_expansion, "Specifies the "
"temporal context-splicing of high-dimensional features.");
}
// We include write and read functions, since this
// object is included as a member of the fMPE object.
void Write(std::ostream &os, bool binary) const;
void Read(std::istream &is, bool binary);
};
struct FmpeUpdateOptions {
BaseFloat learning_rate; // Learning rate constant. Like inverse of E
// in the papers.
BaseFloat l2_weight; // Weight on l2 regularization term
FmpeUpdateOptions(): learning_rate(0.1), l2_weight(100.0) { }
void Register(OptionsItf *opts) {
opts->Register("learning-rate", &learning_rate,
"Learning rate constant (like inverse of E in fMPE papers)");
opts->Register("l2-weight", &l2_weight,
"Weight on l2 regularization term in objective function.");
}
};
class Fmpe;
struct FmpeStats {
FmpeStats() { };
void Init(const Fmpe &fmpe);
FmpeStats(const Fmpe &fmpe) { Init(fmpe); }
void Write(std::ostream &os, bool binary) const;
void Read(std::istream &is, bool binary, bool add = false);
SubMatrix<BaseFloat> DerivPlus() const;
SubMatrix<BaseFloat> DerivMinus() const;
/// If we're using the indirect differential, accumulates certain quantities
/// that will be used in the update phase to verify that the computation
/// of the indirect differential was done correctly
void AccumulateChecks(const MatrixBase<BaseFloat> &feats,
const MatrixBase<BaseFloat> &direct_deriv,
const MatrixBase<BaseFloat> &indirect_deriv);
void DoChecks(); // Will check that stuff cancels. Just prints
// messages for now.
private:
Matrix<BaseFloat> deriv; // contains positive and negative parts of derivatives
// separately as sub-parts of the matrix, to ensure memory locality.
// checks() is an 8 x fmpe.FeatDim() matrix that stores:
// (0-1) summed-deriv from direct, +ve and -ve part.
// (2-3) summed-deriv from indirect, +ve and -ve part.
// (4-5) (summed-deriv from direct * features), +ve and -ve part.
// (6-7) (summed-deriv from indirect * features), +ve and -ve part.
Matrix<double> checks; // contains quantities we use to check the
// indirect and direct derivatives are canceling as they should.
};
class Fmpe {
public:
Fmpe() {}
Fmpe(const DiagGmm &gmm, const FmpeOptions &config);
int32 FeatDim() const { return gmm_.Dim(); }
int32 NumGauss() const { return gmm_.NumGauss(); }
int32 NumContexts() const { return static_cast<int32>(contexts_.size()); }
// Note: this returns the number of rows and columns in projT_,
// which is the transpose of the high->intermediate dimensional
// projection matrix. This is the dimension we want for the
// stats.
int32 ProjectionTNumRows() const { return (FeatDim()+1) * NumGauss(); }
int32 ProjectionTNumCols() const { return FeatDim() * NumContexts(); }
// Computes the fMPE feature offsets and outputs them.
// You can add feat_in to this afterwards, if you want.
// Requires the Gaussian-selection info, which would normally
// be computed by a separate program-- this consists of
// lists of the top-scoring Gaussians for these features.
void ComputeFeatures(const MatrixBase<BaseFloat> &feat_in,
const std::vector<std::vector<int32> > &gselect,
Matrix<BaseFloat> *feat_out) const;
// For training-- compute the derivative w.r.t the projection matrix
// (we keep the positive and negative parts separately to help
// set the learning rates).
void AccStats(const MatrixBase<BaseFloat> &feat_in,
const std::vector<std::vector<int32> > &gselect,
const MatrixBase<BaseFloat> &direct_feat_deriv,
const MatrixBase<BaseFloat> *indirect_feat_deriv, // may be NULL
FmpeStats *stats) const;
// Note: the form on disk starts with the GMM; that way,
// the gselect program can treat the fMPE object as if it
// is a GMM.
void Write(std::ostream &os, bool binary) const;
void Read(std::istream &is, bool binary);
// Returns total objf improvement, based on linear assumption.
BaseFloat Update(const FmpeUpdateOptions &config,
const FmpeStats &stats);
private:
void SetContexts(std::string context_str);
void ComputeC(); // Computes the Cholesky factor C, from the GMM.
void ComputeStddevs();
// Constructs the high-dim features and applies the main projection matrix proj_.
void ApplyProjection(const MatrixBase<BaseFloat> &feat_in,
const std::vector<std::vector<int32> > &gselect,
MatrixBase<BaseFloat> *intermed_feat) const;
// The same in reverse, for computing derivatives.
void ApplyProjectionReverse(const MatrixBase<BaseFloat> &feat_in,
const std::vector<std::vector<int32> > &gselect,
const MatrixBase<BaseFloat> &intermed_feat_deriv,
MatrixBase<BaseFloat> *proj_deriv_plus,
MatrixBase<BaseFloat> *proj_deriv_minus) const;
// Applies the temporal context splicing from the intermediate
// features-- adds the result to feat_out which at this point
// will typically be zero.
void ApplyContext(const MatrixBase<BaseFloat> &intermed_feat,
MatrixBase<BaseFloat> *feat_out) const;
// This is as ApplyContext but for back-propagating the derivative.
// Result is added to intermediate_feat_deriv which at this point will
// typically be zero.
void ApplyContextReverse(const MatrixBase<BaseFloat> &feat_deriv,
MatrixBase<BaseFloat> *intermed_feat_deriv) const;
// Multiplies the feature offsets by the Cholesky matrix C.
void ApplyC(MatrixBase<BaseFloat> *feat_out, bool reverse = false) const;
// For computing derivatives-- multiply the derivatives by C^T,
// which is the "reverse" of the forward pass of multiplying
// by C (this is how derivatives behave...)
void ApplyCReverse(MatrixBase<BaseFloat> *deriv) const { ApplyC(deriv, true); }
DiagGmm gmm_; // The GMM used to get posteriors.
FmpeOptions config_;
Matrix<BaseFloat> stddevs_; // The standard deviations of the
// variances of the GMM -- computed to avoid taking a square root
// in the fMPE computation. Derived variable-- not stored on
// disk.
Matrix<BaseFloat> projT_; // The transpose of the projection matrix;
// this is of dimension
// (NumGauss() * (FeatDim()+1)) * (FeatDim() * NumContexts()).
TpMatrix<BaseFloat> C_; // Cholesky factor of the variance Sigma of
// features around their mean (as estimated from GMM)... applied
// to fMPE offset just before we add it to the features. This allows
// us to simplify the fMPE update and not have to worry about
// the features having non-unit variance, and what effect this should
// have on the learning rate..
// The following variable dictates how we use temporal context.
// e.g. contexts = { { (0, 1.0) }, { (-1, 1.0) }, { (1, 1.0) },
// { (-2, 0.5 ), (-3, 0.5) }, ... }
std::vector<std::vector<std::pair<int32, BaseFloat> > > contexts_;
};
/// Computes derivatives of the likelihood of these states (weighted),
/// w.r.t. the feature values. Used in fMPE training. Note, the
/// weights "posterior" may be positive or negative-- for MMI, MPE,
/// etc., they will typically be of both signs. Will resize "deriv".
/// Returns the sum of (GMM likelihood * weight), which may be used
/// as an approximation to the objective function.
/// Last two parameters are optional. See GetStatsDerivative() for
/// or fMPE paper (ICASSP, 2005) more info on indirect derivative.
/// Caution: if you supply the last two parameters, this function only
/// works in the MMI case as it assumes the stats with positive weight
/// are numerator == ml stats-- this is only the same thing in the MMI
/// case, not fMPE.
BaseFloat ComputeAmGmmFeatureDeriv(const AmDiagGmm &am_gmm,
const TransitionModel &trans_model,
const Posterior &posterior,
const MatrixBase<BaseFloat> &features,
Matrix<BaseFloat> *direct_deriv,
const AccumAmDiagGmm *model_diff = NULL,
Matrix<BaseFloat> *indirect_deriv = NULL);
} // End namespace kaldi
#endif