Blame view
src/chain/chain-denominator.h
13.9 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 |
// chain/chain-denominator.h // Copyright 2015 Johns Hopkins University (Author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_CHAIN_CHAIN_DENOMINATOR_H_ #define KALDI_CHAIN_CHAIN_DENOMINATOR_H_ #include <vector> #include <map> #include "base/kaldi-common.h" #include "util/common-utils.h" #include "fstext/fstext-lib.h" #include "tree/context-dep.h" #include "lat/kaldi-lattice.h" #include "matrix/kaldi-matrix.h" #include "hmm/transition-model.h" #include "cudamatrix/cu-matrix.h" #include "cudamatrix/cu-array.h" #include "chain/chain-den-graph.h" #include "chain/chain-training.h" namespace kaldi { namespace chain { /* This extended comment describes how we implement forward-backward without log and without overflow, and also the leaky-HMM idea. We'll start by establishing the notation for conventional forward-backward, then add the 'arbitrary-scale' concept that prevents overflow, and then add the 'leaky-hmm' concept. All this is done in parallel over multiple sequences, but the computations are independent over the separate sequences, so we won't introduce any notation or index for the sequence; we'll just explain it for one sequence. Suppose we have I hmm-states, numbered i = 0 ... I-1 (we'll use i and j for hmm-state indexes). 
Let foll(i) give a list of arcs leaving state i, and pred(i) give a list of arcs entering state i, and we'll use notation like: for (j, p, n) in foll(i): for iterating over those arcs, where in this case j is the destination-state, p is the transition-probability of the arc and n is the pdf-id index. We can then look up the emission probability as x(t, n) for some frame 0 <= t < T. ** Version 1 of the computation (naive version) ** * Forward computation (version 1) In the forward computation we're computing alpha(t, i) for 0 <= t <= T: - For the first frame, set alpha(0, i) = init(i), where init(i) is the initial-probability of state i. # in our framework these are obtained # by running the HMM for a while and getting an averaged occupation # probability, and using this as an initial-prob, since the boundaries of # chunks don't really correspond to utterance boundaries in general. - For t = 1 ... T: for i = 0 ... I-1: alpha(t, i) = 0 for (j, p, n) in pred(i): # note: j is preceding-state. alpha(t, i) += x(t-1, n) * alpha(t-1, j) * p. - total-prob = \sum_i alpha(T, i). # note, we take the final-probs of all states # to be 1.0. * Backward computation (version 1) And now for the backward computation. Contrary to tradition, we include the inverse of the total-prob as a factor in the betas. This is both more convenient (it simplifies the way we obtain posteriors) and makes the algorithm more generalizable, as all the beta quantities can be interpreted as the partial derivative of the overall logprob with respect to their corresponding alpha. In forward-backward notation, gamma is normally used for state-level occupation probabilities, but what we care about here is pdf-id-level occupation probabilities (i.e. the partial derivative of the overall logprob w.r.t. the logs of the x(t, n) quantities), so we use gamma for that. - for the final frame: for each i, beta(T, i) = 1 / total-prob. - for t = T-1 ... 0: for i = 0 ...
I-1: beta(t, i) = 0 for (j, p, n) in foll(i): # note: j is following-state. beta(t, i) += x(t, n) * beta(t+1, j) * p. gamma(t, n) += alpha(t, i) * x(t, n) * beta(t+1, j) * p. ** Version 2 of the computation (renormalized version) ** Version 1 of the algorithm is susceptible to numeric underflow and overflow, due to the limited range of IEEE floating-point exponents. Define tot-alpha(t) = \sum_i alpha(t, i). Then the renormalized version of the computation is as above, except that whenever the quantity x(t, n) appears, we replace it with x(t, n) / tot-alpha(t). In the algorithm we refer to 1.0 / tot-alpha(t) as 'arbitrary_scale', because mathematically we can use any value here as long as we are consistent and the value only varies with t and not with n; we'll always get the same posteriors (gamma). Where the algorithm would output log(total-prob) as the total log-probability of the HMM, we have to instead return the expression log(total-prob) + \sum_{t=0}^{T-1} \log tot-alpha(t), to correct for the scaling of the x values. The algorithm is still vulnerable to overflow in the beta computation because it's possible that the dominant path could have a very tiny alpha. However, once we introduce the leaky-HMM idea (below), this problem will disappear. ** Version 3 of the computation (leaky-HMM version) ** The leaky-HMM idea is intended to improve generalization by allowing paths other than those explicitly allowed by the FST we compiled. Another way to look at it is as a way of hedging our bets about where we split the utterance, so it's as if we're marginalizing over different splits of the utterance.
You could also think of it as a modification of the FST so that there is an epsilon transition from each state to a newly added state, with probability one, and then an epsilon transition from the newly added state to each state with probability leaky-hmm-prob * init(i) [except that we need a mechanism so that no more than two epsilon transitions can be taken per frame; this would involve creating two copies of the states]. Recall that we mentioned that init(i) is the initial-probability of HMM-state i, but these are obtained in such a way that they can be treated as priors, or average occupation-probabilities. Anyway, the way we formulate leaky-hmm is as follows: * Forward computation (version 3) Let leaky-hmm-prob be a constant defined by the user, with 0.1 being a typical value. It defines how much probability we give to the 'leaky' transitions. - For frame 0, set alpha(0, i) = init(i). - For 0 <= t <= T, define tot-alpha(t) = \sum_i alpha(t, i). - For 0 <= t <= T, define alpha'(t, i) = alpha(t, i) + tot-alpha(t) * leaky-hmm-prob * init(i). - For 1 <= t <= T, the computation of alpha(t, i) is as before except we use the previous frame's alpha' instead of alpha. That is: alpha(t, i) = 0 for (j, p, n) in pred(i): # note: j is preceding-state. alpha(t, i) += alpha'(t-1, j) * p * x(t-1, n) / tot-alpha(t-1) - total-prob = \sum_i alpha'(T, i). The corrected log-prob that we return from the algorithm will be log(total-prob) + \sum_{t=0}^{T-1} \log tot-alpha(t). * Backward computation (version 3) The backward computation is as follows. It is fairly straightforward to derive if you think of it as an instance of backprop where beta, tot-beta and beta' are the partial derivatives of the output log-prob w.r.t. the corresponding alpha, tot-alpha and alpha' quantities. (The dependence through the 1/tot-alpha(t) scaling factors can be ignored: as explained for version 2, that scale is arbitrary and its effect is exactly cancelled by the correction term.) Note, tot-beta is not really the sum of the betas as its name might suggest; it's just the derivative w.r.t. tot-alpha. - beta'(T, i) = 1 / total-prob.
- for 0 <= t <= T, define tot-beta(t) = leaky-hmm-prob * \sum_i init(i) * beta'(t, i).
   - for 0 <= t <= T, define beta(t, i) = beta'(t, i) + tot-beta(t).
   - for 0 <= t < T, processing t in decreasing order (T-1 down to 0, since
     beta'(t, i) needs beta(t+1, j)), we compute beta'(t, i) and update
     gamma(t, n) as follows:
       for 0 <= i < I:
         beta'(t, i) = 0
         for (j, p, n) in foll(i):   # note: j is following-state.
           beta'(t, i) += beta(t+1, j) * p * x(t, n) / tot-alpha(t)
           gamma(t, n) += alpha'(t, i) * beta(t+1, j) * p * x(t, n) / tot-alpha(t)

  Note: in the code, the tot-alpha and tot-beta quantities go in the same
  memory location that the corresponding alpha and beta for state I would go
  (see the comments on alpha_ and beta_ below).
*/


// This does forward-backward in parallel on a number of sequences, using a
// single HMM (the denominator graph).  Presumed usage, judging from the
// interface below: construct the object, call Forward(), then Backward().
class DenominatorComputation {
 public:

  /*
    Constructor.  'nnet_output' is the raw nnet output (which we'll treat as
    pseudo-log-likelihoods).

      @param [in] opts  The options.  This is held by reference (see opts_),
                        so it must outlive this object.
      @param [in] den_graph  The HMM that we use for the denominator (like a
                        decoding graph, with pdf-ids on the transitions).
                        This is held by reference (see den_graph_), so it must
                        outlive this object.
      @param [in] num_sequences  The number of separate time sequences (all of
                        the same length) that we are working with.  Must
                        divide nnet_output.NumRows().
      @param [in] nnet_output  The output of the neural network for this
                        minibatch.  The rows must be ordered as (first frame
                        of all sequences) (second frame of all sequences), etc.
  */
  DenominatorComputation(const ChainTrainingOptions &opts,
                         const DenominatorGraph &den_graph,
                         int32 num_sequences,
                         const CuMatrixBase<BaseFloat> &nnet_output);

  // Does the forward computation, and returns the total log-like summed over
  // all sequences.  You will have to scale this by any supervision weighting
  // factor, manually.  Note: this log-like will be negated before it
  // is added into the objective function, since this is the denominator
  // computation.
  // The alpha-dash values kept in alpha_ are needed by the gamma update of
  // the backward pass (see the algorithm above), so call this before
  // Backward().
  BaseFloat Forward();

  // This adds deriv_weight times (the derivative of the log-prob w.r.t. the
  // nnet output) to 'nnet_output_deriv'.  Note: normally, deriv_weight
  // will be -1, or some other negative number if we are doing data weighting.
  // Returns true if everything seemed OK, false if a failure was detected.
  bool Backward(BaseFloat deriv_weight,
                CuMatrixBase<BaseFloat> *nnet_output_deriv);

 private:
  // Defining this constant as an enum is easier.  It controls a memory/speed
  // tradeoff, determining how many frames' worth of the transposed derivative
  // we store at a time (presumably this bounds the size of
  // nnet_output_deriv_transposed_).  It's not very critical; the only
  // disadvantage from setting it small is that we have to invoke an AddMat
  // kernel more times.
  enum { kMaxDerivTimeSteps = 8 };

  // Sets up the alpha for frame t = 0 (alpha(0, i) = init(i) in the notation
  // of the comment above).
  void AlphaFirstFrame();
  // The alpha computation for some 0 < t <= frames_per_sequence_ (T in the
  // notation above); it uses the alpha-dash values of frame t - 1.
  void AlphaGeneralFrame(int32 t);
  // Does the 'alpha-dash' computation for time t, i.e.
  // alpha'(t, i) = alpha(t, i) + tot-alpha(t) * leaky-hmm-prob * init(i)
  // in the notation above; this relates to 'leaky hmm'.
  void AlphaDash(int32 t);

  // Done after all the alphas, this function computes and returns the total
  // log-likelihood summed over all the sequences, and sets tot_prob_ and (if
  // we're doing correction) log_correction_term_.  Note, this won't be scaled
  // by 'deriv_weight' (which of course we haven't seen by the time this is
  // called, from the Forward() computation).
  BaseFloat ComputeTotLogLike();

  // Sets up the beta-dash for the last frame, t = frames_per_sequence_
  // (beta'(T, i) = 1 / total-prob in the notation above).
  void BetaDashLastFrame();
  // The beta-dash computation for 0 <= t < frames_per_sequence_.  Presumably
  // this also accumulates gamma(t, n), since the algorithm above computes
  // both in the same loop -- confirm against the .cc file.
  void BetaDashGeneralFrame(int32 t);
  // Computes the beta quantity from the beta-dash quantity for frame t, i.e.
  // beta(t, i) = beta'(t, i) + tot-beta(t) (relates to leaky hmm).
  void Beta(int32 t);

  // Some checking that we can do if debug mode is activated, or on frame
  // zero.  Sets ok_ to false if a bad problem is detected.
  void BetaGeneralFrameDebug(int32 t);

  // The options passed to the constructor (a reference, not a copy).
  const ChainTrainingOptions &opts_;
  // The denominator graph passed to the constructor (a reference, not a copy).
  const DenominatorGraph &den_graph_;

  // number of separate frame sequences
  int32 num_sequences_;
  // number of frames per sequence (T in the notation above).  The
  // constructor's nnet_output.NumRows() equals
  // num_sequences_ * frames_per_sequence_.
  int32 frames_per_sequence_;

  // The transpose of the exp() of the nnet output (the transpose is more
  // convenient for memory locality, and the exp() avoids us having to
  // exponentiate in the forward-backward).
  //
  // The row-index is the pdf-id; and the column index equals (frame_index *
  // num_sequences + sequence_index).
  CuMatrix<BaseFloat> exp_nnet_output_transposed_;

  // the derivs w.r.t. the nnet outputs (transposed)
  CuMatrix<BaseFloat> nnet_output_deriv_transposed_;

  // the (temporarily) alpha and (more permanently) alpha-dash probabilities;
  // dimension is (frames_per_sequence + 1) by (num-hmm-states * num-sequences
  // + num_sequences).  Note, they are not logs.  The last 'num_sequences'
  // columns, where the alpha for the state indexed 'num_hmm_states' would
  // live, are for the alpha-sums (tot-alpha), which relates to leaky HMM.
  CuMatrix<BaseFloat> alpha_;

  // the beta (also beta-dash) probabilities (rolling buffer); dimension is 2
  // by (num-hmm-states * num-sequences + num_sequences).  [The last
  // 'num_sequences' columns are for the beta-sums (tot-beta), which relates
  // to leaky HMM.]  Note: for efficiency and to simplify the equations, these
  // are actually the beta / tot_prob_.
  CuMatrix<BaseFloat> beta_;

  // the total probability for each sequence, excluding the product of
  // correction terms.  [The correction terms refer to the fact that on each
  // frame t we divide the emission probabilities by tot-alpha(t), the sum of
  // that frame's alphas; see the renormalized version of the algorithm.]
  // After the correction terms the total probability is fairly close to 1,
  // which is why we can store it as non-log.
  CuVector<BaseFloat> tot_prob_;

  // the log of tot_prob_.
  CuVector<BaseFloat> tot_log_prob_;

  // the log of the total correction term for each sequence, which is the
  // product of the alpha-sums [used in the leaky-hmm computation] over all the
  // frames.  The 'correction terms' are terms that we divide the alphas and
  // betas by in order to keep them in a good dynamic range.  The product of
  // them must be included in the total likelihood.
  CuVector<BaseFloat> log_correction_term_;

  // Set to false (e.g. by BetaGeneralFrameDebug()) if a bad problem is
  // detected; presumably this is what Backward() returns -- confirm.
  bool ok_;
};



}  // namespace chain
}  // namespace kaldi

#endif  // KALDI_CHAIN_CHAIN_DENOMINATOR_H_