// chain/chain-generic-numerator.h

// Copyright      2017  Hossein Hadian
//                2018  Johns Hopkins University (Jan "Yenda" Trmal)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
// ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_CHAIN_CHAIN_GENERIC_NUMERATOR_H_
#define KALDI_CHAIN_CHAIN_GENERIC_NUMERATOR_H_

#include <vector>
#include <map>

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fstext/fstext-lib.h"
#include "tree/context-dep.h"
#include "lat/kaldi-lattice.h"
#include "matrix/kaldi-matrix.h"
#include "hmm/transition-model.h"
#include "chain/chain-supervision.h"
#include "cudamatrix/cu-matrix.h"
#include "cudamatrix/cu-array.h"
#include "chain/chain-datastruct.h"
namespace kaldi {
namespace chain {
/* This extended comment explains how end-to-end (i.e. flat-start) chain
   training is done and how it is mainly different from regular chain training.

   The key difference from regular chain training is that the end-to-end
   supervision FST (i.e. the numerator graph) can have loops and more than one
   final state (we call it the 'Generic' numerator in the code). This is
   because we do not have any alignments, so we cannot split the utterances
   and we cannot remove the self-loops.

   Of course, the end-to-end FST still has to be epsilon-free and have pdf_id+1
   on its input and output labels, just like the regular supervision FST.

   The end-to-end supervision (which contains the generic numerator FSTs) is
   created using TrainingGraphToSupervision from a training FST (i.e. an FST
   created using compile-train-graphs). It is stored in the same struct as
   regular supervision (i.e. chain::Supervision), but this function sets the
   'e2e' flag to true. Also, the generic numerator FSTs are stored in
   'e2e_fsts' instead of 'fst'.

   The TrainingGraphToSupervision function is called in the
   nnet3-chain-e2e-get-egs binary to create end-to-end chain egs. The only
   difference between a regular and an end-to-end chain example is the
   supervision, as explained above.

   Class GenericNumeratorComputation is responsible for doing Forward-Backward
   on a generic FST (i.e. the kind of FST we use in end-to-end chain
   training). It is the same as DenominatorComputation, with two differences:
   [1] it runs on the CPU
   [2] it does not use leakyHMM
   The F-B computation is done in log-domain.
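
   As a rough sketch (this is just the standard log-domain forward recursion
   in informal notation, not copied from the code): for a single sequence,

     alpha(0, start_state) = 0,   alpha(0, j) = -infinity for other states j
     alpha(t, j) = log-sum-exp over arcs (i -> j) with pdf label p of
                   [ alpha(t-1, i) + log transition-prob of the arc
                     + nnet_output(t-1, p) ],

   where nnet_output(t-1, p) is the network's log score for pdf p on frame
   t-1 of that sequence; the betas are computed analogously in the backward
   direction, with the final-probs of the (possibly multiple) final states
   entering at t = frames_per_sequence.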

   When the 'e2e' flag of a supervision is set, the ComputeChainObjfAndDeriv
   function in chain-training.cc uses GenericNumeratorComputation (instead
   of NumeratorComputation) to compute the numerator derivatives.

   The implementation tries to optimize the memory transfers. The optimization
   uses the observation that for each supervision graph, only a very limited
   number of pdfs is needed to evaluate the possible transitions from state
   to state. That means that for the F-B, we don't have to transfer the whole
   neural network output; we can copy only the limited set of pdf activation
   values that will be needed for F-B on the given graph.

   To streamline things, in the constructor of this class we remap the pdf
   indices to a new space and store the bookkeeping info in the index_to_pdf_
   structure. This can be seen as if, for each FST, we create a subspace that
   has only the pdfs that are needed for the given FST (possibly ordered
   differently).

   Moreover, we optimize the memory transfers. The matrix of nnet outputs can
   be reshaped (viewed) as a matrix of dimensions
   (frames_per_sequence) x (num_sequences * pdf_stride), where pdf_stride
   is the stride of the original matrix and pdf_stride >= num_pdfs.
   When the matrix is viewed this way, it becomes obvious that the pdfs of the
   k-th supervision sequence have column index k * pdf_stride + original_pdf_index.
   Once this is understood, the way to copy all the needed pdfs in one shot
   should become obvious.
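
   As a small worked example (the numbers here are made up for illustration):
   suppose pdf_stride = 100 and the graph of sequence k = 1 uses only pdfs
   {3, 17, 42}.  Those pdfs might get remapped to local indices {0, 1, 2}
   (with index_to_pdf_ remembering the original ids), and in the reshaped
   view of the nnet output their activations for any frame t live at row t,
   columns 1*100 + 3 = 103, 1*100 + 17 = 117 and 1*100 + 42 = 142, so a
   single indexed copy can gather everything needed for that graph.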

   The complete F-B is then done in this remapped space; only when copying the
   activation values from GPU memory, or copying the computed derivatives back
   to GPU memory, do we use the bookkeeping info to map the values correctly.
*/

// This class is responsible for the forward-backward of the
// end-to-end 'supervision' (numerator) FST. This kind of FST can
// have self-loops.
// Note: An end-to-end supervision is the same as a regular supervision
// (class chain::Supervision), except that the 'e2e' flag is set to true
// and the numerator FSTs are stored in 'e2e_fsts' instead of 'fst'.
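//
// A minimal usage sketch (the variables 'supervision', 'nnet_output' and
// 'nnet_output_deriv' are assumed to be set up by the caller, e.g. in
// ComputeChainObjfAndDeriv in chain-training.cc):
//
//   GenericNumeratorComputation numerator(supervision, nnet_output);
//   BaseFloat num_logprob_weighted;
//   bool ok = numerator.ForwardBackward(&num_logprob_weighted,
//                                       nnet_output_deriv);
//   // On success, num_logprob_weighted holds the total log-prob times
//   // supervision.weight, and the numerator derivatives have been added
//   // to nnet_output_deriv.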
class GenericNumeratorComputation {
 public:
  /// Initializes the object.
  GenericNumeratorComputation(const Supervision &supervision,
                              const CuMatrixBase<BaseFloat> &nnet_output);

  // Does the forward-backward computation. Returns the total log-prob
  // multiplied by supervision_.weight.
  // In the backward computation, it (efficiently) adds the derivative of
  // (log-prob times supervision_.weight times deriv_weight) w.r.t. the
  // nnet output to 'nnet_output_deriv'.
  bool ForwardBackward(BaseFloat *total_loglike,
                       CuMatrixBase<BaseFloat> *nnet_output_deriv);

  BaseFloat ComputeObjf();

 private:
  // For the remapped FSTs, copy the appropriate activations to CPU memory.
  // For an explanation of what a remapped FST is, see the large comment at
  // the beginning of the file.
  void CopySpecificPdfsIndirect(
      const CuMatrixBase<BaseFloat> &nnet_output,
      const std::vector<MatrixIndexT> &indices,
      Matrix<BaseFloat> *output);

  // For the remapped FSTs, copy the computed values back to the GPU,
  // expand them to the original shape and add them to the output matrix.
  // For an explanation of what a remapped FST is, see the large comment at
  // the beginning of the file.
  void AddSpecificPdfsIndirect(
      Matrix<BaseFloat> *logprobs,
      const std::vector<MatrixIndexT> &indices,
      CuMatrixBase<BaseFloat> *output);

  // Sets up the alpha for frame t = 0.
  void AlphaFirstFrame(int seq, Matrix<BaseFloat> *alpha);

  // The alpha computation for 0 < t <= supervision_.frames_per_sequence,
  // for some 0 <= seq < supervision_.num_sequences.
  BaseFloat AlphaRemainingFrames(int seq,
                                 const Matrix<BaseFloat> &probs,
                                 Matrix<BaseFloat> *alpha);

  // The beta computation for 0 <= t < supervision_.frames_per_sequence,
  // for some 0 <= seq < supervision_.num_sequences.
  void BetaRemainingFrames(int32 seq,
                           const Matrix<BaseFloat> &probs,
                           const Matrix<BaseFloat> &alpha,
                           Matrix<BaseFloat> *beta,
                           Matrix<BaseFloat> *derivs);

  // The beta computation for t = supervision_.frames_per_sequence.
  void BetaLastFrame(int seq,
                     const Matrix<BaseFloat> &alpha,
                     Matrix<BaseFloat> *beta);

  // Returns the total prob for the given matrix alpha (assumes the alpha
  // matrix was computed using AlphaFirstFrame() and AlphaRemainingFrames());
  // it's exactly like 'tot_prob_' in DenominatorComputation.
  BaseFloat GetTotalProb(const Matrix<BaseFloat> &alpha);
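  // (Roughly speaking, and in the log-domain notation of the comment at the
  // top of this file, this is
  //    tot_log_prob = log-sum-exp over final states j of
  //                   [ alpha(frames_per_sequence, j) + log(final-prob of j) ];
  // see the .cc file for the exact bookkeeping.)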

  // Some checking that we can do if debug mode is activated, or on frame zero.
  // Returns false if a bad problem is detected.
  bool CheckValues(int32 seq,
                   const Matrix<BaseFloat> &probs,
                   const Matrix<BaseFloat> &alpha,
                   const Matrix<BaseFloat> &beta,
                   const Matrix<BaseFloat> &derivs) const;

  const Supervision &supervision_;

  // A reference to the nnet output.
  const CuMatrixBase<BaseFloat> &nnet_output_;

  // We keep a copy of the original stride, as the matrix can change
  // before ForwardBackward() is called.
  int32 nnet_output_stride_;

  // in_transitions_ lists all the incoming transitions for each state
  // of each numerator graph; out_transitions_ does the same but for the
  // outgoing transitions.
  typedef std::vector<std::vector<DenominatorGraphTransition> > TransitionMap;
  std::vector<TransitionMap> in_transitions_, out_transitions_;

  // Bookkeeping info for the pdf remapping (see the comment at the top of
  // this file).
  std::vector<MatrixIndexT> index_to_pdf_;

  // Final probs for each state of each numerator graph.
  Matrix<BaseFloat> final_probs_;  // indexed by seq, state

  // An offset subtracted from the logprobs of transitions out of the first
  // state of each graph, to help reduce numerical problems.
  Vector<BaseFloat> offsets_;
};

}  // namespace chain
}  // namespace kaldi

#endif  // KALDI_CHAIN_CHAIN_GENERIC_NUMERATOR_H_