  // doc/chain.dox
  
  // Copyright 2015   Johns Hopkins University (author: Daniel Povey)
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  
  //  http://www.apache.org/licenses/LICENSE-2.0
  
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABILITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  namespace kaldi {
  
  /**
    \page chain 'Chain' models
  
    \section chain_intro Introduction to 'chain' models
  
    The 'chain' models are a type of DNN-HMM model, implemented using \ref dnn3 "nnet3", and differ from the
    conventional model in various ways; you can think of them as a different
    design point in the space of acoustic models.
  
   - We use a frame rate three times smaller at the output of the neural net.
     This significantly reduces the amount of computation required at
     test time, making real-time decoding much easier.
     - The models are trained right from the start with a sequence-level
       objective function-- namely, the log probability of the correct sequence.  It is
       essentially MMI implemented without lattices on the GPU, by doing a full
       forward-backward on a decoding graph derived from a phone n-gram language
       model.
     - Because of the reduced frame rate, we need to use unconventional
     HMM topologies (allowing the HMM to be traversed in one transition).
     - We use fixed transition probabilities in the HMM, and don't train
     them (we may decide to train them in the future; but for the most part the neural-net
       output probabilities can do the same job as the transition probabilities,
       depending on the topology).
     - Currently, only nnet3 DNNs are supported (see \ref dnn3), and
       online decoding has not yet been implemented (we're aiming for April to June 2016).
   - Currently the results are a bit better than those of conventional
       DNN-HMMs (about 5\% relative better), but the system is about 3 times
       faster to decode; training time is probably a bit faster too, but
       we haven't compared it exactly.
  
    \section chain_scripts  Where to find scripts for the 'chain' models
  
    The current best scripts for the 'chain' models can be found in the
    Switchboard setup in egs/swbd/s5c; the script local/chain/run_tdnn_2o.sh is
    the current best one.  This is currently available in the 'chain' branch of
    the official github repository (https://github.com/kaldi-asr/kaldi.git) and
    eventually will be merged to the master.
  
  This script uses TDNNs as the neural net (we've been doing the development
  with TDNNs because they are easier to tune than LSTMs), and gives a better
  WER than the baseline TDNN: 11.4\%, versus 12.1\% for the best TDNN baseline
    (on the Switchboard-only portion of eval2000).
  
    \section chain_model  The chain model
  
    The chain model itself is no different from a conventional DNN-HMM, used with
    a (currently) 3-fold reduced frame rate at the output of the DNN.  The input
    features of the DNN are at the original frame rate of 100 per second; this makes
    sense because all the neural nets we are currently using (LSTMs, TDNNs) have some kind
    of recurrent connections or splicing inside them, i.e. they are not purely feedforward
    nets.
  
    The difference from a normal model is the objective function used to train it:
    instead of a frame-level objective, we use the log-probability of the correct
    phone sequence as the objective function.  The training process is quite
    similar in principle to MMI training, in which we compute numerator and
    denominator 'occupation probabilities' and the difference between the two is
  used in the derivative computation.  There is no need to normalize the DNN
  outputs to sum to one on each frame any more-- such normalization makes no difference.
  
    Because of the reduced frame rate (one frame every 30 ms), we need to use a
    modified HMM topology.  We would like the HMM to be traversable in one
  transition (as opposed to the 3 transitions of a model at the normal frame
    rate).  The currently favored topology has a state that can only occur once,
    and then another state that can appear zero or more times.  The state-clustering
    is obtained using the same procedure as for GMM-based models, although
    of course with a different topology (we convert the alignments to the new topology
    and frame-rate).
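
  To make the topology just described concrete, here is a hypothetical
  arc-level sketch of a topology of that kind (the state numbering and the
  explicit final state are purely illustrative; the exact topology used in the
  scripts may differ in its details):
  \verbatim
  // Hypothetical sketch of the kind of topology described above: state 0 is
  // visited exactly once, state 1 may repeat zero or more times, and the
  // phone can be traversed in a single transition (0 -> final).
  struct TopoArc { int src; int dst; };     // state 2 == final (non-emitting)
  static const TopoArc kChainTopoSketch[] = {
    { 0, 1 },   // go on to the repeatable state
    { 0, 2 },   // ... or leave immediately: one-transition traversal
    { 1, 1 },   // self-loop: the phone lasts additional frames
    { 1, 2 },   // leave the phone
  };
  \endverbatim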
  
    \section chain_training The training procedure for 'chain' models
  
    The training procedure for chain models is a lattice-free version of
    MMI, where the denominator state posteriors are obtained by the
    forward-backward algorithm over a HMM formed from a phone-level decoding graph,
    and the numerator state posteriors are obtained by a similar forward-backward
    algorithm but limited to sequences corresponding to the transcript.
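
  Written out, and glossing over details, the objective for a set of training
  chunks indexed by \f$u\f$ is just the difference of two forward-backward
  log-probabilities, one over the numerator FST of each chunk and one over the
  shared denominator FST (this is only a restatement of the description above,
  not a precise transcription of the code):
  \f[
     \mathcal{F} \;=\; \sum_u \Big( \log p(\mathbf{x}_u ; \mathrm{num}_u)
                                  \;-\; \log p(\mathbf{x}_u ; \mathrm{den}) \Big).
  \f]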
  
    For each output index of the neural net (i.e. for each pdf-id), we
  compute a derivative of the form (numerator occupation probability -
    denominator occupation probability), and these are propagated back to the
    network.
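
  The following is a minimal sketch of how that derivative is formed, assuming
  the numerator and denominator posteriors have already been computed by the
  two forward-backward passes; the matrix layout and the names are
  illustrative, not the actual interface used in the code:
  \verbatim
  #include <vector>

  // Sketch: per-frame, per-pdf derivative of the 'chain' objective w.r.t. the
  // nnet output, i.e. (numerator occupation prob - denominator occupation prob).
  // 'num_post' and 'den_post' are [num_frames x num_pdfs] matrices stored
  // row-major; the result has the same layout.  (Illustrative names only.)
  std::vector<float> ChainDerivativeSketch(const std::vector<float> &num_post,
                                           const std::vector<float> &den_post) {
    std::vector<float> deriv(num_post.size());
    for (size_t i = 0; i < deriv.size(); i++)
      deriv[i] = num_post[i] - den_post[i];
    return deriv;
  }
  \endverbatim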
  
  
    \subsection chain_training_denominator  The denominator FST
  
  For the denominator part of the computation we do forward-backward over an HMM.
  Actually, because we represent it as a finite state acceptor, the labels
  (pdf-ids) are associated with the arcs and not the states, so it's not really an
  HMM in the normal formulation, but it's easier to think of it as an HMM because
  we use the forward-backward algorithm to get posteriors.
   In the code and scripts we refer to it as the 'denominator FST'.
  
    \subsubsection chain_training_denominator_phone_lm Phone language model for the denominator FST
  
    The first stage in constructing the denominator FST is to create a phone
    language model.  This language model is learned from the training-data phone
  alignments.  This is an un-smoothed language model, meaning that we never
  back off to lower-order n-grams.  However, some language-model states are
  removed entirely, so transitions to those states go instead to the lower-order
  n-gram's state.   The reason we avoid smoothing is to reduce the number of
  arcs in the compiled graph after phonetic context expansion.
  
   The configuration that we settled on is to estimate a 4-gram language model,
    and to never prune LM states below trigram (so we always maintain at least a
    2-phone history).  On top of the number of states dictated by the no-prune
    trigram rule, we have a specifiable number (e.g. 2000) of 4-gram language
    model states which are to be retained (all the rest are identified with the
    corresponding trigram state), and the ones we choose to retain are determined
    in a way that maximizes the training-data likelihood.  All probabilities are
  estimated to maximize the training-data likelihood.  The reason not to prune
  the trigrams is that any sparsity in which trigrams are allowed tends to
  reduce the size of the compiled graph.  Note that if our phone LM were just a
    simple phone loop (i.e. a unigram), it would get expanded to triphones anyway
    due to phonetic context effects, but it would have arcs for all possible
    trigrams in it.  So any sparsity we get from using the un-pruned trigram model
    is a bonus.  Empirically, an un-smoothed trigram LM is what expands to the
    smallest possible FST; and pruning some of the trigrams, while it increases
    the size of the compiled FST, results in little or no WER improvement (at
    least on 300 hours of data expanded 3-fold with speed perturbation; on less
    data it might help).
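
  As an illustration of what it means to retain the 4-gram states that
  maximize the training-data likelihood: each 4-gram history state can be
  scored by the log-likelihood that would be lost if its counts were merged
  into the corresponding trigram state, and we then keep the highest-scoring
  states.  The sketch below computes that score; the data structures and names
  are hypothetical, not taken from the actual language-model code.
  \verbatim
  #include <cmath>
  #include <map>

  // Sketch: log-likelihood gain (over the training data) of keeping one 4-gram
  // history state rather than identifying it with its trigram backoff state.
  // 'counts4' maps phone -> count observed after this 4-gram history;
  // 'prob3' gives each phone's probability under the corresponding trigram
  // state.  (Hypothetical interface, for illustration only.)
  double StateRetentionGainSketch(const std::map<int, double> &counts4,
                                  const std::map<int, double> &prob3) {
    double total = 0.0;
    for (auto &kv : counts4) total += kv.second;
    double gain = 0.0;
    for (auto &kv : counts4) {
      double p4 = kv.second / total;         // un-smoothed 4-gram probability
      double p3 = prob3.at(kv.first);        // backoff (trigram) probability
      gain += kv.second * (std::log(p4) - std::log(p3));
    }
    return gain;  // keep the N states with the largest gains
  }
  \endverbatim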
  
    On the Switchboard setups the phone-LM perplexities for the various models we
    tried were in the range 5 to 7; the phone-LM perplexity with our chosen
    configuration (4-gram, pruned to trigram for all but 2000 states) was about 6.
    It was not the case that lower phone-LM perplexity always led to better WER
    of the trained system; as for conventional (word-based) MMI training, an
    intermediate strength of language model seemed to work best.
  
   \subsubsection chain_training_denominator_compilation  Compilation of the denominator FST
  
  The phone language model described in the previous section is expanded into an
  FST with 'pdf-ids' as the arc labels, in a process that mirrors the process of
  decoding-graph compilation in normal Kaldi decoding (see \ref
  graph_recipe_test), except that no lexicon is involved, and at the
  end we convert the transition-ids to pdf-ids.
  
    One difference lies in how we minimize the size of the graph.  The normal
    recipe involves determinization and minimization.  We were not able to
    reduce the size of the graph using this procedure, or variants of it with
    disambiguation symbols.  Instead, our graph-minimization process can be described
  compactly as follows: "Repeat 3 times: push, minimize, reverse; push, minimize, reverse".
    'push' refers to weight-pushing; 'reverse' refers to reversing the directions of arcs, and
    swapping initial and final states.
  
  
   \subsubsection chain_training_denominator_normalization Initial and final probabilities, and 'normalization FST'
  
    The graph-creation process mentioned above naturally gives us an initial
    state, and final probabilities for each state; but these are not the ones we
    use in the forward-backward.  The reason is that these probabilities are
    applicable to utterance boundaries, but we train on split-up chunks of
  utterances of a fixed length (e.g. 1.5 seconds).  Constraining the HMM at these
    arbitrarily chosen cut points to the initial and final states is not
    appropriate.  Instead, we use initial probabilities derived from 'running the HMM' for
    a fixed number of iterations and averaging the probabilities; and final probabilities
    equal to 1.0 for each state.  We have a justification for this but don't have time to
    explain it right now.  In the denominator forward-backward process we apply these initial and
    final probabilities to the initial and final frame as part of the computation.  However, we also
    write out a version of the denominator FST that has these initial and final probabilities, and we refer to
    this as the 'normalization FST.'  (The initial probabilities are emulated using epsilon arcs, because
    FSTs do not support initial probabilities).  This 'normalization FST' will be used to add probabilities to the
   numerator FSTs in a way that we'll describe later.
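
  The following sketch shows what we mean by 'running the HMM' to obtain the
  initial probabilities: starting from some distribution over the HMM states
  (uniform here, which is an assumption of this sketch), we repeatedly apply
  the transition matrix and average the resulting distributions.  The real
  code works on the arcs of the denominator FST rather than on an explicit
  transition matrix, so treat this only as an illustration of the idea.
  \verbatim
  #include <vector>

  // Sketch: derive initial probabilities by 'running the HMM' for a fixed
  // number of steps and averaging the state distributions.
  // 'trans[i][j]' is the transition probability from state i to state j.
  std::vector<double> InitialProbsSketch(
      const std::vector<std::vector<double> > &trans, int num_iters) {
    int num_states = static_cast<int>(trans.size());
    std::vector<double> cur(num_states, 1.0 / num_states),  // assumed uniform
        avg(num_states, 0.0);
    for (int iter = 0; iter < num_iters; iter++) {
      std::vector<double> next(num_states, 0.0);
      for (int i = 0; i < num_states; i++)
        for (int j = 0; j < num_states; j++)
          next[j] += cur[i] * trans[i][j];
      cur.swap(next);
      for (int j = 0; j < num_states; j++)
        avg[j] += cur[j] / num_iters;
    }
    return avg;
  }
  \endverbatim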
  
    \subsection chain_training_numerator  Numerator FSTs
  
   As part of our preparation for the training process we produce something
   called a 'numerator FST' for each utterance.  The numerator FST encodes the
   supervision transcript, and also encodes an alignment of that transcript
   (i.e. it forces similarity to a reference alignment obtained from a baseline
   system), but it allows a little 'wiggle room' to vary from that reference.
  By default we allow a phone to begin up to 0.05 seconds before, and end up to
  0.05 seconds after, its position in the lattice alignment.
  Incorporating the alignment information is important because of the way we
  train not on entire utterances but on split-up fixed-length pieces of
  utterances (which, in turn, is important for GPU-based training): splitting up
  the utterance into pieces is only possible if we know where the transcript aligns.
  
   Instead of enforcing a particular pronunciation of the training data, we use as
   our reference a lattice of alternative pronunciations of the training data,
   generated by a lattice-generating decoding procedure using an
   utterance-specific graph as the decoding graph.  This generates all alignments
   of pronunciations that were within a beam of the best-scoring pronunciation.
  
    \subsubsection chain_training_numerator_splitting Splitting the numerator FSTs
  
  As mentioned, we train on fixed-size pieces of utterances (e.g. 1.5 seconds in
  length).  This requires that we split the numerator FSTs up into fixed-size
  pieces.  This isn't hard, since the numerator FSTs (which, remember, encode
  time-alignment information) naturally have a structure where we can identify
  any FST state with a particular frame index.  Note: at the stage where we do this
  splitting, there are no costs in the numerator FST yet-- it's just viewed as
  encoding a constraint on paths-- so we do not have to decide how to split up the costs
  on the paths.
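
  A sketch of that splitting operation is below, under the simplifying
  assumption that we are given a map from each FST state to its frame index;
  the function name and the details (e.g. the use of epsilon arcs to enter the
  piece) are illustrative rather than a description of the actual code.
  \verbatim
  #include <fst/fstlib.h>
  #include <vector>

  // Sketch: extract the piece of a numerator FST covering frames
  // [frame_begin, frame_end].  We copy the states whose frame index lies in
  // that range, add a new start state with epsilon arcs into the states at
  // frame_begin, and make the states at frame_end final.
  fst::StdVectorFst SplitNumeratorPieceSketch(
      const fst::StdVectorFst &num_fst, const std::vector<int> &state_to_frame,
      int frame_begin, int frame_end) {
    using fst::StdArc;
    int num_states = num_fst.NumStates();
    std::vector<int> new_id(num_states, -1);
    fst::StdVectorFst piece;
    int start = piece.AddState();
    piece.SetStart(start);
    for (int s = 0; s < num_states; s++)
      if (state_to_frame[s] >= frame_begin && state_to_frame[s] <= frame_end)
        new_id[s] = piece.AddState();
    for (int s = 0; s < num_states; s++) {
      if (new_id[s] < 0) continue;
      if (state_to_frame[s] == frame_begin)   // entry point of this piece
        piece.AddArc(start, StdArc(0, 0, StdArc::Weight::One(), new_id[s]));
      if (state_to_frame[s] == frame_end)     // exit point of this piece
        piece.SetFinal(new_id[s], StdArc::Weight::One());
      for (fst::ArcIterator<fst::StdVectorFst> aiter(num_fst, s);
           !aiter.Done(); aiter.Next()) {
        const StdArc &arc = aiter.Value();
        if (new_id[arc.nextstate] >= 0)
          piece.AddArc(new_id[s], StdArc(arc.ilabel, arc.olabel, arc.weight,
                                         new_id[arc.nextstate]));
      }
    }
    return piece;
  }
  \endverbatim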
  
    \subsubsection chain_training_numerator_normalization Normalizing the numerator FSTs
  
   Above (\ref chain_training_denominator_compilation) we mentioned how we compute
   initial and final probabilities for the denominator FST, and how we encode
   these in a 'normalization FST'.  We compose the split-up pieces of numerator
  FST with this 'normalization FST' to ensure that the costs from the
   denominator FST are reflected in the numerator FST.  This ensures that
   objective functions can never be positive (which makes them easier to
   interpret), and also guards against the possibility that the numerator FST
   could contain state sequences not allowed by the denominator FST, which in
   principle could allow the objective function to increase without bound.  The
   reason why this could happen is that the phone LM lacks smoothing, and is
  estimated from 1-best alignments, so the lattices could contain phone n-gram
  sequences not seen in training.
  
   It happens occasionally (but very rarely) that this normalization process
   generates an empty FST: this can occur when the lattice contains triphones that
  were not present in the 1-best alignment used to train the phone language
   model, and does not have any alternative paths at that point in the lattice
   that could make up for the resulting 'failed' paths.  This can happen because
   the 1-best alignment and the lattice-producing alignment chose different
   pronunciations of a word.  These pieces of utterances are just discarded.
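
  In OpenFst terms this normalization amounts to a composition, after which an
  empty result signals a piece that should be discarded.  A minimal sketch
  (assuming the inputs are arc-sorted as composition requires; the real code
  may organize this differently):
  \verbatim
  #include <fst/fstlib.h>

  // Sketch: compose a split-up piece of numerator FST with the
  // 'normalization FST'; return false if the result is empty, in which case
  // the piece is discarded.
  bool NormalizeNumeratorSketch(const fst::StdVectorFst &normalization_fst,
                                const fst::StdVectorFst &numerator_piece,
                                fst::StdVectorFst *normalized_piece) {
    fst::Compose(normalization_fst, numerator_piece, normalized_piece);
    return normalized_piece->Start() != fst::kNoStateId;  // non-empty?
  }
  \endverbatim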
  
    \subsubsection chain_training_numerator_format Format of the numerator FSTs
  
    The numerator FSTs are weighted acceptors where the labels correspond to
    pdf-ids plus one.  We can't use pdf-ids, because they could be zero; and zero
    is treated specially (as epsilon) by OpenFst.  When we form minibatches, instead
    of storing an array of separate numerator FSTs we actually append them together to form a longer FST;
    this enables us to do a single forward-backward over all utterances in the minibatch,
    which directly computes the total numerator log-probability.  (This isn't an important
    feature, it's just a software detail, which we explain here lest it generate confusion).
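
  A sketch of the label convention and of the appending is below; in OpenFst
  terms the appending can be expressed with fst::Concat, though the real code
  may do it differently:
  \verbatim
  #include <fst/fstlib.h>
  #include <vector>

  // Sketch: numerator-FST labels are 'pdf-id plus one', because label 0 is
  // reserved for epsilon in OpenFst.
  inline int PdfIdToLabel(int pdf_id) { return pdf_id + 1; }
  inline int LabelToPdfId(int label)  { return label - 1; }

  // Sketch: append the per-utterance numerator FSTs of a minibatch into one
  // longer FST, so a single forward-backward covers the whole minibatch.
  fst::StdVectorFst AppendNumeratorFstsSketch(
      const std::vector<fst::StdVectorFst> &numerator_fsts) {
    fst::StdVectorFst appended;
    appended.SetStart(appended.AddState());
    appended.SetFinal(appended.Start(), fst::TropicalWeight::One());
    for (const fst::StdVectorFst &f : numerator_fsts)
      fst::Concat(&appended, f);   // destructively appends f onto 'appended'
    return appended;
  }
  \endverbatim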
  
    \subsection chain_training_splitting  Fixed-length chunks, and minibatches
  
    In order to train on minibatches, we split up our utterances into fixed-length
    chunks of speech (of length 1.5 seconds in our current scripts).  Utterances
  shorter than this are discarded; those longer are split into chunks, with
  either overlaps or small gaps between the chunks.  Note that
  our acoustic models typically require left or right frames for acoustic
  context; we add that, but this is a separate issue; the context is added after
  the chunks are decided on.
  
    Our minibatch size is usually a power of 2, and it can be limited by GPU
    memory considerations.  Many of our example scripts use 128 chunks per
    minibatch.  The largest single consumer of GPU memory is the alpha
  probabilities in the forward-backward computation.  For instance, with a
  1.5-second chunk, we have 50 time steps after the 3-fold subsampling.  In our
  Switchboard setup a typical denominator FST has 30,000 states in it.  We use
  single-precision floating point for the alphas, so the memory used in
  gigabytes is (128 * 50 * 30000 * 4) / 10^9 = 0.768 GB.
  
  This won't use up all the GPU memory, but there are other sources of memory use,
    e.g. we keep around two copies of the nnet outputs in memory, which takes a
    fair amount of memory depending on the configuration-- e.g. replace the 30000
    above with about 10000 and it will give you the amount of memory used for one
    copy of the nnet outputs in a reasonable configuration.
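
  The arithmetic above can be reproduced in a few lines; the particular
  numbers (30,000 denominator-FST states, roughly 10,000 nnet outputs) are
  just the illustrative values from the text:
  \verbatim
  #include <cstdio>

  // Sketch: GPU memory, in gigabytes, for the alpha probabilities and for one
  // copy of the nnet outputs, using the illustrative numbers from the text.
  int main() {
    const double bytes_per_float = 4.0;
    const int chunks_per_minibatch = 128, frames_per_chunk = 50;
    const int den_fst_states = 30000, num_pdfs = 10000;
    double alpha_gb = chunks_per_minibatch * frames_per_chunk *
        den_fst_states * bytes_per_float / 1.0e9;                  // 0.768 GB
    double nnet_out_gb = chunks_per_minibatch * frames_per_chunk *
        num_pdfs * bytes_per_float / 1.0e9;                        // 0.256 GB
    std::printf("alphas: %.3f GB, one copy of nnet outputs: %.3f GB\n",
                alpha_gb, nnet_out_gb);
    return 0;
  }
  \endverbatim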
  
  
    \subsection chain_training_shifting  Training on frame-shifted data
  
    In neural net training we already have ways of generating perturbed data to
    artificially increase the amount of data we train on.  Our standard nnet3
    neural-net training example scripts do time-warping of the raw audio, by
  factors of 0.9, 1.0 and 1.1, to create 3-fold augmented data.  This is
    orthogonal to the 'chain' models, and we do it (or not) just as we would for
    the baseline.  However, there is an extra way we can augment the data for the
    chain models, by shifting the frames.  The output frame rate for these models
    is one third the regular frame rate (configurable, of course), meaning we only
    evaluate nnet output at <code>t</code> values that are multiples of 3, so we
    can generate different versions of the training data by shifting the training
    examples by 0, 1 and 2 frames.  This is done automatically in the training
    script, and it's done 'on the fly' as we read the training examples from
    disk-- the program <code>nnet3-chain-copy-egs</code> has a
    <code>--frame-shift</code> option that is set by the script.  This affects how
    the number of epochs is interpreted.  If the user requests, for instance, 4
    epochs, then we actually train for 12 epochs; we just do so on 3
  differently-shifted versions of the data.  What the
  <code>--frame-shift=t</code> option actually does is to shift the input frames
  by <code>t</code> and shift the output frames by the closest multiple of 3 to
  <code>t</code>.  (In general it might not be 3; it's a configuration variable
    named <code>--frame-subsampling-factor</code>).
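
  A sketch of the shifting rule described above, i.e. shifting the output
  frames by the closest multiple of the frame-subsampling factor to the
  requested shift <code>t</code> (the function name is illustrative):
  \verbatim
  #include <cmath>

  // Sketch: the output frames get shifted by the closest multiple of the
  // frame-subsampling factor (3 in the text above) to the requested shift t,
  // while the input frames get shifted by t itself.
  int ClosestMultipleSketch(int t, int frame_subsampling_factor) {
    return frame_subsampling_factor *
        static_cast<int>(std::lround(t / double(frame_subsampling_factor)));
  }
  // e.g. with a factor of 3:  t = 0, 1, 2, 3, 4  ->  output shifts 0, 0, 3, 3, 3.
  \endverbatim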
  
    \subsection chain_training_gpu GPU issues in training
  
   The parts of the computation that are specific to the 'chain' computation are
   the forward-backward over the numerator FST and over the denominator HMM.  The
   numerator part of this is very fast.  The denominator forward-backward takes
   quite a lot of time, because there can be a lot of arcs in the
   denominator FST (e.g. 200,000 arcs and 30,000 states in a typical Switchboard setup).
   The time taken can be almost as much as the time taken in the neural-net
   parts of the computation.  We were quite careful to ensure memory locality.
  
   The next step to further speed this up is probably to implement a pruned
   version of the forward-backward computation (like pruned Viterbi, but
   computing posteriors).  In order to get a speedup we'd have to prune away a
   very high percentage of states, because we'd need to make up for the loss of
   memory locality that pruning would bring.  In our current implementation we are
   careful to ensure that a group of GPU threads are all processing the same
   HMM-state and time, just from different chunks (we call these different
  'sequences' in the code); and we make sure that the memory locations
  corresponding to these different sequences are all next to each other in
  memory, so the GPU can do coalesced memory access.  With state-level
   pruning, since the memory access for the different sequences would no longer be
   'in sync', we would lose this advantage.  It should still be doable to get a
   pruned version of the forward-backward algorithm, though.
  
   For speed, we don't use log values in the alpha-beta computation for the
   denominator graph.  In order to keep all the numerical values in a suitable
   range, we multiply all the acoustic probabilities (exponentiated nnet outputs)
  on each frame by an 'arbitrary value' selected to ensure that our alpha scores
   stay in a good range.  We call this an 'arbitrary value' because the algorithm
   is designed so that we could choose any value here, and it would still be
   mathematically correct.  We designate one HMM state as a 'special state', and
  the 'arbitrary constant' is chosen to be the inverse of that special state's alpha
   on the previous frame.  This keeps the special state's alpha values close to
   one.  As the 'special state' we choose a state that has high probability in the
   limiting distribution of the HMM, and which can access the majority of states
   of the HMM.
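
  In equation form, writing \f$\hat\alpha\f$ for the scaled alphas, the idea is
  roughly the following (a simplified statement of the scheme described above,
  not a transcription of the code):
  \f[
     \hat\alpha_t(j) \;=\; \kappa_t \sum_{(i \rightarrow j)}
        \hat\alpha_{t-1}(i) \; p_{ij} \; x_t(a_{ij}),
     \qquad
     \kappa_t \;=\; \frac{1}{\hat\alpha_{t-1}(s_{\mathrm{special}})},
  \f]
  where \f$p_{ij}\f$ is the fixed transition probability on the arc from state
  \f$i\f$ to state \f$j\f$, \f$a_{ij}\f$ is the pdf-id on that arc, and
  \f$x_t(\cdot)\f$ is the exponentiated nnet output on frame \f$t\f$.  Since
  every term on frame \f$t\f$ is scaled by the same \f$\kappa_t\f$, the
  resulting posteriors are unchanged by the choice of \f$\kappa_t\f$; choosing
  the inverse of the special state's previous alpha just keeps the values in a
  good floating-point range.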
  
   \section chain_decoding  Decoding with 'chain' models
  
   The decoding process with 'chain' models is exactly the same as for regular nnet3
   neural-net based models, and in fact uses the same script (steps/nnet3/decode.sh).
   There are a few configuration differences:
  
      - Firstly, the graph is built with a different and simpler topology; but this requires
        no special action by the user, as the graph-building script anyway takes the
        topology from the 'final.mdl' produced by the 'chain' training script, which
        contains the correct topology.
  
      - By default when we compile the graph, we use a 'self-loop-scale' of 0.1.
        This affects how the transition probabilities on self-loops are treated
        (it generally works better).  However, for the 'chain' models, because of
        how they were trained, we need to use exactly the same
        transition-probability scaling we trained with, which for simplicity we
        have set to 1.0.  So we supply the option <code>--self-loop-scale
        1.0</code> to the <code>utils/mkgraph.sh</code> script.
  
      - There is no 'division by the prior' necessary in these models.  So we simply
        don't set the vector of priors in the <code>.mdl</code> files; we made sure
        that the decoder just omits the division by the prior if the priors are not set.
  
      - The default acoustic scale we typically use in decoding (0.1) is not
        suitable-- for 'chain' models the optimal acoustic scale is very close to 1.
        So we supply the option <code>--acwt 1.0</code> to the script
        <code>steps/nnet3/decode.sh</code>.
  
      - The scoring scripts can only search the language-model scale in increments
        of 1, which works well in typical setups where the optimal language model scale
        is between 10 and 15, but not when the optimal language-model scale is close
        to 1 as it is here.  (Note: for current purposes you can treat the language-model
        scale as the same as the inverse of the acoustic scale).  In order to
        work around this issue without changing the scoring scripts (which are
        database-specific), we supply a new option <code>--post-decode-acwt 10.0</code>
        to the script <code>steps/nnet3/decode.sh</code>,
        which scales the acoustic probabilities by 10 before dumping the lattice.
        After this, the optimal language-model scale will be around 10, which might
        be a little confusing if you are not aware of this issue, but is convenient
        for the way the scoring scripts are set up.
  
     - The default decoding and lattice beams are suitable without modification
       for the 'chain' models, once you use the <code>--acwt 1.0</code> option.
       However, they won't show the full possible speedup and you can get faster
       decoding by using slightly tighter beams.  By tightening the beam in the
       Switchboard setup we were able to get decoding time down from around 1.5
       times real time to around 0.5 times real time, with only around 0.2\%
       degradation in accuracy (this was with neural net evaluation on the CPU; on
       the GPU it would have been even faster).  Note from Dan: this is all to the best
       of my recollection as I write this; actually the degradation may have been more than
       that.  And bear in mind that this was on high-powered modern server machines
      (single-threaded).
  
   You might notice in the current example scripts that we use iVectors.  We do so
   just because they generally help a bit, and because the baseline setup we were
  comparing with uses them.  There is no inherent connection with 'chain'
   models, and no fundamental requirement to use them.
  
  
  */
  
  }