// doc/chain.dox
// Copyright 2015 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

namespace kaldi {

/**
\page chain 'Chain' models

\section chain_intro Introduction to 'chain' models

The 'chain' models are a type of DNN-HMM model, implemented using \ref dnn3 "nnet3",
and differ from the conventional model in various ways; you can think of them as a
different design point in the space of acoustic models.

- We use a three times lower frame rate at the output of the neural net. This
  significantly reduces the amount of computation required at test time, making
  real-time decoding much easier.
- The models are trained right from the start with a sequence-level objective
  function, namely the log probability of the correct sequence. It is essentially
  MMI implemented without lattices on the GPU, by doing a full forward-backward on a
  decoding graph derived from a phone n-gram language model.
- Because of the reduced frame rate, we need to use unconventional HMM topologies
  (allowing the traversal of the HMM in one state).
- We use fixed transition probabilities in the HMM, and don't train them (we may
  decide to train them in future; but for the most part the neural-net output
  probabilities can do the same job as the transition probabilities, depending on
  the topology).
- Currently, only nnet3 DNNs are supported (see \ref dnn3), and online decoding has
  not yet been implemented (we're aiming for April to June 2016).
- Currently the results are a bit better than those of conventional DNN-HMMs (about
  5\% relative better), but the system is about 3 times faster to decode; training
  time is probably a bit faster too, but we haven't compared it exactly.

\section chain_scripts Where to find scripts for the 'chain' models

The current best scripts for the 'chain' models can be found in the Switchboard setup
in egs/swbd/s5c; the script local/chain/run_tdnn_2o.sh is the current best one. This
is currently available in the 'chain' branch of the official github repository
(https://github.com/kaldi-asr/kaldi.git) and will eventually be merged into master.
This script uses TDNNs as the neural net (we've been doing the development with TDNNs
because they are easier to tune than LSTMs), and gives a better WER than the baseline
TDNN: 11.4\%, versus 12.1\% for the best TDNN baseline (on the Switchboard-only
portion of eval2000).

\section chain_model The chain model

The chain model itself is no different from a conventional DNN-HMM, used with a
(currently) 3-fold reduced frame rate at the output of the DNN. The input features of
the DNN are at the original frame rate of 100 per second; this makes sense because all
the neural nets we are currently using (LSTMs, TDNNs) have some kind of recurrent
connections or splicing inside them, i.e. they are not purely feedforward nets.

The difference from a normal model is the objective function used to train it:
instead of a frame-level objective, we use the log-probability of the correct phone
sequence as the objective function. The training process is quite similar in
principle to MMI training, in which we compute numerator and denominator 'occupation
probabilities', and the difference between the two is used in the derivative computation.
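
To make the derivative concrete, here is a minimal C++ sketch. It assumes the
numerator and denominator occupation probabilities have already been produced by
their respective forward-backward passes and are stored as dense [frame][pdf-id]
matrices; the type name Matrix and the function ChainDerivative are hypothetical and
are not part of the Kaldi API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense layout: occ[t][pdf] for frame t and pdf-id pdf.
using Matrix = std::vector<std::vector<float>>;

// Derivative of the objective w.r.t. the (log-domain) nnet outputs:
// numerator occupation minus denominator occupation, per frame and pdf-id.
Matrix ChainDerivative(const Matrix &num_occ, const Matrix &den_occ) {
  assert(num_occ.size() == den_occ.size());
  Matrix deriv(num_occ.size());
  for (std::size_t t = 0; t < num_occ.size(); ++t) {
    assert(num_occ[t].size() == den_occ[t].size());
    deriv[t].resize(num_occ[t].size());
    for (std::size_t p = 0; p < num_occ[t].size(); ++p)
      deriv[t][p] = num_occ[t][p] - den_occ[t][p];
  }
  return deriv;
}
```

Because the numerator and denominator occupation probabilities each sum to one on
every frame, each row of the result sums to zero; this is consistent with the remark
that per-frame normalization of the outputs makes no difference.
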
There is no need to normalize the DNN outputs to sum to one on each frame any more;
such normalization makes no difference.

Because of the reduced frame rate (one frame every 30 ms), we need to use a modified
HMM topology. We would like the HMM to be traversable in one transition (as opposed
to the 3 transitions of a model at the normal frame rate). The currently favored
topology has a state that can only occur once, and then another state that can appear
zero or more times. The state-clustering is obtained using the same procedure as for
GMM-based models, although of course with a different topology (we convert the
alignments to the new topology and frame rate).

\section chain_training The training procedure for 'chain' models

The training procedure for chain models is a lattice-free version of MMI, where the
denominator state posteriors are obtained by the forward-backward algorithm over a
HMM formed from a phone-level decoding graph, and the numerator state posteriors are
obtained by a similar forward-backward algorithm but limited to sequences
corresponding to the transcript.

For each output index of the neural net (i.e. for each pdf-id), we compute a
derivative of the form (numerator occupation probability - denominator occupation
probability), and these are propagated back to the network.

\subsection chain_training_denominator The denominator FST

For the denominator part of the computation we do forward-backward over a HMM.
Actually, because we represent it as a finite state acceptor, the labels (pdf-ids) are
associated with the arcs and not the states, so it's not really a HMM in the normal
formulation, but it's easier to think of it as a HMM because we use the
forward-backward algorithm to get posteriors. In the code and scripts we refer to it
as the 'denominator FST'.

\subsubsection chain_training_denominator_phone_lm Phone language model for the denominator FST

The first stage in constructing the denominator FST is to create a phone language model.
This language model is learned from the training-data phone alignments. This is an
un-smoothed language model, meaning that we never do backoff to lower-order n-grams.
However, some language-model states are removed entirely, so transitions to those
states go instead to the lower-order n-gram's state. The reason we avoid smoothing is
to reduce the number of arcs that there will be in the compiled graph after phonetic
context expansion.

The configuration that we settled on is to estimate a 4-gram language model, and to
never prune LM states below trigram (so we always maintain at least a 2-phone
history). On top of the number of states dictated by the no-prune trigram rule, we
have a specifiable number (e.g. 2000) of 4-gram language model states which are to be
retained (all the rest are identified with the corresponding trigram state), and the
ones we choose to retain are determined in a way that maximizes the training-data
likelihood. All probabilities are estimated to maximize the training-data likelihood.
The reason not to prune the trigrams is that any sparsity in which trigrams are
allowed will tend to minimize the size of the compiled graph. Note that if our phone
LM was just a simple phone loop (i.e. a unigram), it would get expanded to triphones
anyway due to phonetic context effects, but it would have arcs for all possible
trigrams in it. So any sparsity we get from using the un-pruned trigram model is a
bonus. Empirically, an un-smoothed trigram LM is what expands to the smallest
possible FST; and pruning some of the trigrams, while it increases the size of the
compiled FST, results in little or no WER improvement (at least on 300 hours of data
expanded 3-fold with speed perturbation; on less data it might help).

On the Switchboard setups the phone-LM perplexities for the various models we tried
were in the range 5 to 7; the phone-LM perplexity with our chosen configuration
(4-gram, pruned to trigram for all but 2000 states) was about 6.
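
The state-selection rule above can be sketched in C++ as follows. This is a minimal
illustration under simplifying assumptions, not Kaldi's phone-LM estimation code: the
set of retained 4-gram states is taken as given (the likelihood-based selection is
not shown), sentence-boundary symbols are omitted, and the names History and PhoneLm
are hypothetical. It shows the two properties described in the text: a 3-phone
history that is not retained is identified with its trigram state instead of backing
off, and unseen n-grams get probability zero.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <set>
#include <utility>
#include <vector>

// A phone history, most recent phone last (hypothetical representation).
using History = std::vector<int>;

// Hypothetical un-smoothed 4-gram phone LM: a 3-phone history is its own LM
// state only if it is in the retained set; otherwise it is identified with
// its 2-phone (trigram) state. There is no back-off and no smoothing.
class PhoneLm {
 public:
  explicit PhoneLm(std::set<History> retained_4gram_states)
      : retained_(std::move(retained_4gram_states)) {}

  // LM state used to predict the phone that follows 'context'.
  History StateFor(const History &context) const {
    if (context.size() >= 3) {
      History h(context.end() - 3, context.end());
      if (retained_.count(h) > 0) return h;
    }
    std::size_t n = context.size() < 2 ? context.size() : 2;
    return History(context.end() - static_cast<std::ptrdiff_t>(n), context.end());
  }

  // Accumulate counts from one phone sequence, e.g. a training-data alignment.
  void AddSequence(const History &phones) {
    History context;
    for (int p : phones) {
      counts_[StateFor(context)][p] += 1.0;
      context.push_back(p);
    }
  }

  // Maximum-likelihood (relative-frequency) estimate of P(phone | context).
  double Prob(const History &context, int phone) const {
    auto it = counts_.find(StateFor(context));
    if (it == counts_.end()) return 0.0;
    double total = 0.0;
    for (const auto &kv : it->second) total += kv.second;
    auto jt = it->second.find(phone);
    return jt == it->second.end() ? 0.0 : jt->second / total;
  }

 private:
  std::set<History> retained_;
  std::map<History, std::map<int, double>> counts_;
};
```

For example, after accumulating the sequence 1 2 3 4 1 2 3 5 with {1 2 3} retained,
the probability of phone 4 after the context 9 2 3 is zero: that context maps to the
trigram state {2 3}, which never received any counts.
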
It was not the case that lower phone-LM perplexity always led to better WER of the
trained system; as for conventional (word-based) MMI training, an intermediate
strength of language model seemed to work best.

\subsubsection chain_training_denominator_compilation Compilation of the denominator FST

The phone language model described in the previous section is expanded into a FST
with pdf-ids as the arc labels, in a process that mirrors the process of
decoding-graph compilation in normal Kaldi decoding (see \ref graph_recipe_test),
except that there is no lexicon involved, and at the end we convert the
transition-ids to pdf-ids.

One difference lies in how we minimize the size of the graph. The normal recipe
involves determinization and minimization. We were not able to reduce the size of the
graph using this procedure, or variants of it with disambiguation symbols. Instead,
our graph-minimization process can be described compactly as follows: "Repeat 3
times: push, minimize, reverse; push, minimize, reverse." Here 'push' refers to
weight-pushing, and 'reverse' refers to reversing the directions of arcs and swapping
initial and final states.

\subsubsection chain_training_denominator_normalization Initial and final probabilities, and 'normalization FST'

The graph-creation process mentioned above naturally gives us an initial state, and
final probabilities for each state; but these are not the ones we use in the
forward-backward. The reason is that these probabilities are applicable to utterance
boundaries, but we train on split-up chunks of utterance of a fixed length (e.g. 1.5
seconds). Constraining the HMM at these arbitrarily chosen cut points to the initial
and final states is not appropriate. Instead, we use initial probabilities derived
from 'running the HMM' for a fixed number of iterations and averaging the
probabilities; and final probabilities equal to 1.0 for each state. We have a
justification for this but don't have time to explain it right now.
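
One way to read "running the HMM for a fixed number of iterations and averaging the
probabilities" is sketched below: start from the graph's initial state, propagate the
state distribution along the arcs, renormalize after each step, and average the
distributions visited. This is an assumption-laden illustration rather than Kaldi's
code: the arc type HmmArc, the function InitialProbs, the renormalization step and
the iteration count are all hypothetical.

```cpp
#include <cassert>
#include <vector>

// Hypothetical arc of the denominator graph; pdf labels are not needed here.
struct HmmArc { int from, to; double prob; };

// Average of the state distributions obtained by running the HMM for
// num_iters steps from start_state. Each step is renormalized, so mass that
// the graph would send to its (utterance-level) final probabilities is ignored.
std::vector<double> InitialProbs(int num_states, const std::vector<HmmArc> &arcs,
                                 int start_state, int num_iters) {
  std::vector<double> cur(num_states, 0.0), avg(num_states, 0.0);
  cur[start_state] = 1.0;
  for (int iter = 0; iter < num_iters; ++iter) {
    std::vector<double> next(num_states, 0.0);
    for (const HmmArc &a : arcs) next[a.to] += cur[a.from] * a.prob;
    double total = 0.0;
    for (double x : next) total += x;
    assert(total > 0.0);
    for (int s = 0; s < num_states; ++s) {
      cur[s] = next[s] / total;
      avg[s] += cur[s] / num_iters;
    }
  }
  return avg;
}
```

Under this reading the final probabilities need no computation at all: they are
simply 1.0 for every state, as stated above.
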
In the denominator forward-backward process we apply these initial and final
probabilities to the initial and final frame as part of the computation. However, we
also write out a version of the denominator FST that has these initial and final
probabilities, and we refer to this as the 'normalization FST'. (The initial
probabilities are emulated using epsilon arcs, because FSTs do not support initial
probabilities.) This 'normalization FST' will be used to add probabilities to the
numerator FSTs in a way that we'll describe later.

\subsection chain_training_numerator Numerator FSTs

As part of our preparation for the training process we produce something called a
'numerator FST' for each utterance. The numerator FST encodes the supervision
transcript, and also encodes an alignment of that transcript (i.e. it forces
similarity to a reference alignment obtained from a baseline system), but it allows a
little 'wiggle room' to vary from that reference. By default we allow a phone to
start up to 0.05 seconds earlier, and end up to 0.05 seconds later, than its position
in the lattice alignment. Incorporating the alignment information is important
because we train not on entire utterances but on split-up fixed-length pieces of
utterances (which, in turn, is important for GPU-based training): splitting up the
utterance into pieces is easy if we know where the transcript aligns.

Instead of enforcing a particular pronunciation of the training data, we use as our
reference a lattice of alternative pronunciations of the training data, generated by
a lattice-generating decoding procedure using an utterance-specific graph as the
decoding graph. This generates all alignments of pronunciations that were within a
beam of the best-scoring pronunciation.

\subsubsection chain_training_numerator_splitting Splitting the numerator FSTs

As mentioned, we train on fixed-size pieces of utterances (e.g. 1.5 seconds in length).
This requires that we split the numerator FSTs into fixed-size pieces. This isn't
hard, since the numerator FSTs (which, remember, encode time-alignment information)
naturally have a structure where we can identify any FST state with a particular
frame index. Note: at the stage where we do this splitting, there are no costs in the
numerator FST yet (it's just viewed as encoding a constraint on paths), so we do not
have to decide how to split up the costs on the paths.

\subsubsection chain_training_numerator_normalization Normalizing the numerator FSTs

Above (\ref chain_training_denominator_normalization) we mentioned how we compute
initial and final probabilities for the denominator FST, and how we encode these in a
'normalization FST'. We compose the split-up pieces of numerator FST with this
'normalization FST' to ensure that the costs from the denominator FST are reflected in
the numerator FST. This ensures that objective functions can never be positive (which
makes them easier to interpret), and also guards against the possibility that the
numerator FST could contain state sequences not allowed by the denominator FST, which
in principle could allow the objective function to increase without bound. The reason
why this could happen is that the phone LM lacks smoothing, and is estimated from
1-best alignments, so the lattices could contain phone n-gram sequences not seen in
training.

It happens occasionally (but very rarely) that this normalization process generates
an empty FST: this can occur when the lattice contains triphones that were not
present in the 1-best alignment used to train the phone language model, and does not
have any alternative paths at that point in the lattice that could make up for the
resulting 'failed' paths. This can happen because the 1-best alignment and the
lattice-producing alignment chose different pronunciations of a word. These pieces of
utterances are just discarded.
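
Returning to the splitting step of \ref chain_training_numerator_splitting: because
every arc of the time-aligned numerator FST consumes exactly one frame, a piece
covering frames [begin, end) can be cut out just by looking at the frame index of each
arc's source state. The C++ sketch below shows this; the types NumArc, NumFst and
Piece and the function ExtractPiece are hypothetical, and the FST is a plain arc list
rather than an OpenFst object. A complete implementation would still have to turn the
several start states into a single FST start state and remove states that cannot
reach the end of the piece; both steps are omitted here.

```cpp
#include <cassert>
#include <map>
#include <vector>

// Hypothetical time-aligned numerator acceptor: every arc consumes exactly
// one frame, so each state has a well-defined frame index.
struct NumArc { int from, to, label; };  // label = pdf-id + 1

struct NumFst {
  std::vector<int> frame;    // frame index of each state
  std::vector<NumArc> arcs;  // no weights yet: only a constraint on paths
};

// The piece covering frames [begin, end), with renumbered states.
struct Piece {
  std::vector<NumArc> arcs;
  std::vector<int> start_states;  // states at frame 'begin'
  std::vector<int> end_states;    // states at frame 'end'
  int num_states = 0;
};

Piece ExtractPiece(const NumFst &fst, int begin, int end) {
  std::map<int, int> renumber;  // old state id -> new state id
  auto id = [&renumber](int s) {
    auto it = renumber.find(s);
    if (it != renumber.end()) return it->second;
    int n = static_cast<int>(renumber.size());
    renumber[s] = n;
    return n;
  };
  Piece piece;
  for (const NumArc &a : fst.arcs) {
    int t = fst.frame[a.from];
    if (t < begin || t >= end) continue;  // arc consumes a frame outside the piece
    assert(fst.frame[a.to] == t + 1);
    piece.arcs.push_back({id(a.from), id(a.to), a.label});
  }
  for (const auto &kv : renumber) {
    if (fst.frame[kv.first] == begin) piece.start_states.push_back(kv.second);
    if (fst.frame[kv.first] == end) piece.end_states.push_back(kv.second);
  }
  piece.num_states = static_cast<int>(renumber.size());
  return piece;
}
```

Since the pieces carry no costs at this point, nothing needs to be divided between
them; costs only enter later, through composition with the normalization FST.
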

\subsubsection chain_training_numerator_format Format of the numerator FSTs

The numerator FSTs are weighted acceptors where the labels correspond to pdf-ids plus
one. We can't use pdf-ids, because they could be zero; and zero is treated specially
(as epsilon) by OpenFst. When we form minibatches, instead of storing an array of
separate numerator FSTs we actually append them together to form a longer FST; this
enables us to do a single forward-backward over all utterances in the minibatch,
which directly computes the total numerator log-probability. (This isn't an important
feature, it's just a software detail, which we explain here lest it generate
confusion.)

\subsection chain_training_splitting Fixed-length chunks, and minibatches

In order to train on minibatches, we split up our utterances into fixed-length chunks
of speech (of length 1.5 seconds in our current scripts). Utterances shorter than this
are discarded; those longer are split into chunks with either overlaps between the
chunks, or small gaps between the chunks. Note that our acoustic models typically
require left or right frames for acoustic context; we add that, but this is a
separate issue; the context is added after the chunks are decided on.

Our minibatch size is usually a power of 2, and it can be limited by GPU memory
considerations. Many of our example scripts use 128 chunks per minibatch. The largest
single consumer of GPU memory is the alpha probabilities in the forward-backward
computation. For instance, with 1.5 second chunks, we have 50 time steps after the
3-fold subsampling. In our Switchboard setup a typical denominator FST has 30,000
states in it. We use single-precision floating point for the alphas, so the memory
used in gigabytes is (128 * 50 * 30000 * 4) / 10^9 = 0.768G. This won't use up all the
GPU memory, but there are other consumers of memory, e.g. we keep around two copies of
the nnet outputs in memory, which takes a fair amount of memory depending on the
configuration; e.g.
replace the 30000 above with about 10000 and it will give you the amount of memory
used for one copy of the nnet outputs in a reasonable configuration.

\subsection chain_training_shifting Training on frame-shifted data

In neural net training we already have ways of generating perturbed data to
artificially increase the amount of data we train on. Our standard nnet3 neural-net
training example scripts do time-warping of the raw audio, by factors of 0.9, 1.0 and
1.1, to create 3-fold augmented data. This is orthogonal to the 'chain' models, and we
do it (or not) just as we would for the baseline. However, there is an extra way we
can augment the data for the chain models, by shifting the frames. The output frame
rate for these models is one third the regular frame rate (configurable, of course),
meaning we only evaluate the nnet output at <code>t</code> values that are multiples
of 3, so we can generate different versions of the training data by shifting the
training examples by 0, 1 and 2 frames. This is done automatically in the training
script, and it's done 'on the fly' as we read the training examples from disk: the
program <code>nnet3-chain-copy-egs</code> has a <code>--frame-shift</code> option that
is set by the script. This affects how the number of epochs is interpreted. If the
user requests, for instance, 4 epochs, then we actually train for 12 epochs; we just
do so on 3 differently-shifted versions of the data. What the option
<code>--frame-shift=t</code> actually does is to shift the input frames by
<code>t</code> and shift the output frames by the closest multiple of 3 to
<code>t</code>. (In general it might not be 3; it's a configuration variable named
<code>--frame-subsampling-factor</code>.)

\subsection chain_training_gpu GPU issues in training

The parts of the computation that are specific to the 'chain' computation are the
forward-backward over the numerator FST and over the denominator HMM. The numerator
part of this is very fast.
The denominator forward-backward takes quite a lot of time, because there can be a
lot of arcs in the denominator FST (e.g. 200,000 arcs and 30,000 states in a typical
Switchboard setup). The time taken can be almost as much as the time taken in the
neural-net parts of the computation. We were quite careful to ensure memory locality.

The next step to further speed this up is probably to implement a pruned version of
the forward-backward computation (like pruned Viterbi, but computing posteriors). In
order to get a speedup we'd have to prune away a very high percentage of states,
because we'd need to make up for the loss of memory locality that pruning would bring.
In our current implementation we are careful to ensure that a group of GPU threads all
process the same HMM state and time, just from different chunks (we call these
different 'sequences' in the code); and we make sure that the memory locations
corresponding to these different sequences are all next to each other in memory, so
the GPU can do 'coalesced memory access'. With state-level pruning, since the memory
access for the different sequences would no longer be 'in sync', we would lose this
advantage. It should still be doable to get a pruned version of the forward-backward
algorithm, though.

For speed, we don't use log values in the alpha-beta computation for the denominator
graph. In order to keep all the numerical values in a suitable range, we multiply all
the acoustic probabilities (exponentiated nnet outputs) on each frame by an
'arbitrary value' selected to ensure that our alpha scores stay in a good range. We
call this an 'arbitrary value' because the algorithm is designed so that we could
choose any value here, and it would still be mathematically correct. We designate one
HMM state as a 'special state', and the arbitrary value is chosen to be the inverse of
that special state's alpha on the previous frame. This keeps the special state's
alpha values close to one.
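
The rescaling trick can be illustrated with a small, single-sequence C++ sketch of the
forward pass in the probability domain. It is not Kaldi's GPU implementation: the arc
type DenArc and the function DenominatorLogProb are hypothetical, the graph is a plain
arc list, there is no batching over sequences, and only alphas (not betas) are
computed. On each frame the acoustic probabilities are multiplied by c_t, the inverse
of the special state's alpha on the previous frame, and the sum of log c_t is
subtracted at the end, which is why any choice of c_t gives the same log-probability.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical denominator-graph arc: transition probability and pdf-id.
struct DenArc { int from, to, pdf; double prob; };

// Probability-domain forward pass with per-frame rescaling. The scaled alphas
// equal the true alphas times the product of all c_t so far, so subtracting
// the accumulated log c_t recovers the true total log-probability. Final
// probabilities are taken to be 1.0 for every state.
double DenominatorLogProb(const std::vector<std::vector<double>> &nnet_out,
                          const std::vector<DenArc> &arcs,
                          const std::vector<double> &initial_probs,
                          int special_state) {
  std::vector<double> alpha = initial_probs;
  double log_scale_sum = 0.0;
  for (std::size_t t = 0; t < nnet_out.size(); ++t) {
    assert(alpha[special_state] > 0.0);
    double c = 1.0 / alpha[special_state];  // the 'arbitrary value' for frame t
    log_scale_sum += std::log(c);
    std::vector<double> next(alpha.size(), 0.0);
    for (const DenArc &a : arcs)
      next[a.to] += alpha[a.from] * a.prob * c * std::exp(nnet_out[t][a.pdf]);
    alpha.swap(next);
  }
  double total = 0.0;
  for (double x : alpha) total += x;  // all final probabilities are 1.0
  return std::log(total) - log_scale_sum;
}
```

Choosing a different special state changes the intermediate alphas but not the
returned value, up to rounding error.
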
As the 'special state' we choose a state that has high probability in the limiting
distribution of the HMM, and which can access the majority of states of the HMM.

\section chain_decoding Decoding with 'chain' models

The decoding process with 'chain' models is exactly the same as for regular nnet3
neural-net based models, and in fact uses the same script (steps/nnet3/decode.sh).
There are a few configuration differences:

- Firstly, the graph is built with a different and simpler topology; but this
  requires no special action by the user, as the graph-building script anyway takes
  the topology from the 'final.mdl' produced by the 'chain' training script, which
  contains the correct topology.
- By default when we compile the graph, we use a 'self-loop-scale' of 0.1. This
  affects how the transition probabilities on self-loops are treated (it generally
  works better). However, for the 'chain' models, because of how they were trained,
  we need to use exactly the same transition-probability scaling we trained with,
  which for simplicity we have set to 1.0. So we supply the option
  <code>--self-loop-scale 1.0</code> to the <code>utils/mkgraph.sh</code> script.
- There is no 'division by the prior' necessary in these models. So we simply don't
  set the vector of priors in the <code>.mdl</code> files; we made sure that the
  decoder just omits the division by the prior if the priors are not set.
- The default acoustic scale we typically use in decoding (0.1) is not suitable; for
  'chain' models the optimal acoustic scale is very close to 1. So we supply the
  option <code>--acwt 1.0</code> to the script <code>steps/nnet3/decode.sh</code>.
- The scoring scripts can only search the language-model scale in increments of 1,
  which works well in typical setups where the optimal language-model scale is
  between 10 and 15, but not when the optimal language-model scale is close to 1 as
  it is here.
  (Note: for current purposes you can treat the language-model scale as the same as
  the inverse of the acoustic scale.) In order to work around this issue without
  changing the scoring scripts (which are database-specific), we supply a new option
  <code>--post-decode-acwt 10.0</code> to the script <code>steps/nnet3/decode.sh</code>,
  which scales the acoustic probabilities by 10 before dumping the lattice. After
  this, the optimal language-model scale will be around 10, which might be a little
  confusing if you are not aware of this issue, but is convenient for the way the
  scoring scripts are set up.
- The default decoding and lattice beams are suitable without modification for the
  'chain' models, once you use the <code>--acwt 1.0</code> option. However, they
  won't show the full possible speedup, and you can get faster decoding by using
  slightly tighter beams. By tightening the beam in the Switchboard setup we were able
  to get decoding time down from around 1.5 times real time to around 0.5 times real
  time, with only around 0.2\% degradation in accuracy (this was with neural net
  evaluation on the CPU; on the GPU it would have been even faster). Note from Dan:
  this is all to the best of my recollection as I write this; actually the
  degradation may have been more than that. And bear in mind that this was on
  high-powered modern server machines (single-threaded).

You might notice in the current example scripts that we use iVectors. We do so just
because they generally help a bit, and because the baseline setup we were comparing
with uses them. There is no inherent connection with 'chain' models, and no
fundamental requirement to use them.

*/

}