// doc/chain.dox

// Copyright 2015   Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

//  http://www.apache.org/licenses/LICENSE-2.0

// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

namespace kaldi {

/**
  \page chain 'Chain' models

  \section chain_intro Introduction to 'chain' models

  The 'chain' models are a type of DNN-HMM model, implemented using \ref dnn3 "nnet3", and differ from the
  conventional model in various ways; you can think of them as a different
  design point in the space of acoustic models.

   - We use a 3 times smaller frame rate at the output of the neural net.
     This significantly reduces the amount of computation required at
     test time, making real-time decoding much easier.
   - The models are trained right from the start with a sequence-level
     objective function-- namely, the log probability of the correct sequence.  It is
     essentially MMI implemented without lattices on the GPU, by doing a full
     forward-backward on a decoding graph derived from a phone n-gram language
     model.
   - Because of the reduced frame rate, we need to use unconventional
     HMM topologies (allowing the traversal of the HMM in one state).
   - We use fixed transition probabilities in the HMM, and don't train
     them (we may decide to train them in future; but for the most part the neural-net
     output probabilities can do the same job as the transition probabilities,
     depending on the topology).
   - Currently, only nnet3 DNNs are supported (see \ref dnn3), and
     online decoding has not yet been implemented (we're aiming for April to June 2016).
   - Currently the results are a bit better than those of conventional
     DNN-HMMs (about 5\% relative better), but the system is about 3 times
     faster to decode; training time is probably a bit faster too, but
     we haven't compared it exactly.

  \section chain_scripts  Where to find scripts for the 'chain' models

  The current best scripts for the 'chain' models can be found in the
  Switchboard setup in egs/swbd/s5c; the script local/chain/run_tdnn_2o.sh is
  the current best one.  This is currently available in the 'chain' branch of
  the official github repository (https://github.com/kaldi-asr/kaldi.git) and
  eventually will be merged to the master.

  This script uses TDNNs as the neural net (we've been doing the development
  with TDNNs because they are easier to tune than LSTMs), and gives a better
  WER than the baseline TDNN: 11.4\%, versus 12.1\% for the best TDNN baseline
  (on the Switchboard-only portion of eval2000).

  \section chain_model  The chain model

  The chain model itself is no different from a conventional DNN-HMM, used with
  a (currently) 3-fold reduced frame rate at the output of the DNN.  The input
  features of the DNN are at the original frame rate of 100 per second; this makes
  sense because all the neural nets we are currently using (LSTMs, TDNNs) have some kind
  of recurrent connections or splicing inside them, i.e. they are not purely feedforward
  nets.

  The difference from a normal model is the objective function used to train it:
  instead of a frame-level objective, we use the log-probability of the correct
  phone sequence as the objective function.  The training process is quite
  similar in principle to MMI training, in which we compute numerator and
  denominator 'occupation probabilities' and the difference between the two is
  used in the derivative computation.  There is no need to normalize the DNN
  outputs to sum to one on each frame any more; such normalization makes no difference.

  Because of the reduced frame rate (one frame every 30 ms), we need to use a
  modified HMM topology.  We would like the HMM to be traversable in one
  transition (as opposed to the 3 transitions of a model at the normal frame
  rate).  The currently favored topology has a state that can only occur once,
  and then another state that can appear zero or more times.  The state-clustering
  is obtained using the same procedure as for GMM-based models, although
  of course with a different topology (we convert the alignments to the new topology
  and frame-rate).
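
  As a rough illustration of the topology described above, the following
  Python sketch lists the HMM states visited by a phone of a given length, and
  compares the shortest phone each kind of model can represent.  It is not
  part of Kaldi; the state names 'first' and 'repeat' and both function names
  are made up for this example.

```python
def chain_state_sequence(num_frames):
    """HMM states visited by one phone that spans `num_frames` output frames.

    One state occurs exactly once, then a second state occurs zero or more
    times.  The names 'first' and 'repeat' are illustrative only.
    """
    if num_frames < 1:
        raise ValueError("a phone must span at least one output frame")
    return ["first"] + ["repeat"] * (num_frames - 1)


def min_phone_duration_ms(min_frames, frame_shift_ms):
    """Shortest phone the topology can represent, in milliseconds."""
    return min_frames * frame_shift_ms


# Conventional model: at least 3 transitions at a 10 ms frame shift.
conventional_ms = min_phone_duration_ms(3, 10)
# 'chain' model: at least 1 transition at a 30 ms frame shift.
chain_ms = min_phone_duration_ms(1, 30)
```

  Under these assumptions both models have the same minimum phone duration,
  which is why the one-transition topology is a natural fit for the reduced
  frame rate.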

  \section chain_training The training procedure for 'chain' models

  The training procedure for chain models is a lattice-free version of
  MMI, where the denominator state posteriors are obtained by the
  forward-backward algorithm over a HMM formed from a phone-level decoding graph,
  and the numerator state posteriors are obtained by a similar forward-backward
  algorithm but limited to sequences corresponding to the transcript.

  For each output index of the neural net (i.e. for each pdf-id), we
  compute a derivative of the form (numerator occupation probability -
  denominator occupation probability), and these are propagated back to the
  network.
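
  The derivative above can be illustrated with a toy example in plain Python.
  This is only a sketch under simplifying assumptions: the two tiny graphs,
  the function names and the nnet outputs are invented for illustration, and
  the real implementation runs on the GPU over much larger graphs.

```python
import math


def forward_backward(arcs, init, final, nnet_out):
    """Forward-backward over an acceptor whose arcs carry pdf-ids.

    arcs: list of (src_state, dst_state, pdf_id, prob).
    init, final: per-state initial and final probabilities.
    nnet_out[t][pdf_id]: unnormalized log-scale network output.
    Returns (total_log_prob, occ), where occ[t][pdf_id] is the occupation
    probability of pdf_id on frame t.
    """
    num_frames, num_states = len(nnet_out), len(init)
    alpha = [list(init)] + [[0.0] * num_states for _ in range(num_frames)]
    beta = [[0.0] * num_states for _ in range(num_frames)] + [list(final)]
    for t in range(num_frames):
        for src, dst, pdf, p in arcs:
            alpha[t + 1][dst] += alpha[t][src] * p * math.exp(nnet_out[t][pdf])
    for t in reversed(range(num_frames)):
        for src, dst, pdf, p in arcs:
            beta[t][src] += p * math.exp(nnet_out[t][pdf]) * beta[t + 1][dst]
    total = sum(a * f for a, f in zip(alpha[num_frames], final))
    occ = [[0.0] * len(frame) for frame in nnet_out]
    for t in range(num_frames):
        for src, dst, pdf, p in arcs:
            occ[t][pdf] += (alpha[t][src] * p * math.exp(nnet_out[t][pdf])
                            * beta[t + 1][dst] / total)
    return math.log(total), occ


# Toy denominator graph: two states, every pdf sequence allowed.
DEN_ARCS = [(0, 0, 0, 0.5), (0, 1, 1, 0.5), (1, 0, 0, 0.5), (1, 1, 1, 0.5)]
DEN_INIT, DEN_FINAL = [0.5, 0.5], [1.0, 1.0]
# Toy numerator graph: only the pdf sequence 0, 1, 1, carrying the same total
# weight that the denominator graph assigns to that sequence.
NUM_ARCS = [(0, 1, 0, 0.5), (1, 2, 1, 0.5), (2, 3, 1, 0.5)]
NUM_INIT, NUM_FINAL = [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]


def chain_objf_and_deriv(nnet_out):
    """Objective (numerator minus denominator log-prob) and its derivative."""
    num_logprob, num_occ = forward_backward(NUM_ARCS, NUM_INIT, NUM_FINAL, nnet_out)
    den_logprob, den_occ = forward_backward(DEN_ARCS, DEN_INIT, DEN_FINAL, nnet_out)
    deriv = [[n - d for n, d in zip(num_row, den_row)]
             for num_row, den_row in zip(num_occ, den_occ)]
    return num_logprob - den_logprob, deriv


nnet_out = [[0.3, -0.2], [0.1, 0.4], [-0.5, 0.2]]
objf, deriv = chain_objf_and_deriv(nnet_out)
```

  In this toy setup the derivatives on each frame sum to zero, because both
  sets of occupation probabilities sum to one on each frame.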


  \subsection chain_training_denominator  The denominator FST

  For the denominator part of the computation we do forward-backward over a HMM.
  Actually, because we represent it as a finite state acceptor, the labels
  (pdf-ids) are associated with the arcs and not the states, so it's not really a
  HMM in the normal formulation, but it's easier to think of it as a HMM because
  we use the forward-backward algorithm to get posteriors.
  In the code and scripts we refer to it as the 'denominator FST'.

  \subsubsection chain_training_denominator_phone_lm Phone language model for the denominator FST

  The first stage in constructing the denominator FST is to create a phone
  language model.  This language model is learned from the training-data phone
  alignments.  This is an un-smoothed language model, meaning that we never do
  backoff to lower order n-grams.  However, some language-model states are
  removed entirely, so transitions to those states go instead to the lower-order
  n-gram's state.   The reason we avoid smoothing is to reduce the number of
  arcs that there will be in the compiled graph after phonetic context expansion.

 The configuration that we settled on is to estimate a 4-gram language model,
  and to never prune LM states below trigram (so we always maintain at least a
  2-phone history).  On top of the number of states dictated by the no-prune
  trigram rule, we have a specifiable number (e.g. 2000) of 4-gram language
  model states which are to be retained (all the rest are identified with the
  corresponding trigram state), and the ones we choose to retain are determined
  in a way that maximizes the training-data likelihood.  All probabilities are
  estimated to maximize the training-data likelihood.  The reason not to prune
  the trigrams is that any sparsity of which trigrams are allowed, will tend to
  minimize the size of the compiled graph.  Note that if our phone LM was just a
  simple phone loop (i.e. a unigram), it would get expanded to triphones anyway
  due to phonetic context effects, but it would have arcs for all possible
  trigrams in it.  So any sparsity we get from using the un-pruned trigram model
  is a bonus.  Empirically, an un-smoothed trigram LM is what expands to the
  smallest possible FST; and pruning some of the trigrams, while it increases
  the size of the compiled FST, results in little or no WER improvement (at
  least on 300 hours of data expanded 3-fold with speed perturbation; on less
  data it might help).

  On the Switchboard setups the phone-LM perplexities for the various models we
  tried were in the range 5 to 7; the phone-LM perplexity with our chosen
  configuration (4-gram, pruned to trigram for all but 2000 states) was about 6.
  It was not the case that lower phone-LM perplexity always led to better WER
  of the trained system; as for conventional (word-based) MMI training, an
  intermediate strength of language model seemed to work best.
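
  To make the notions of an un-smoothed phone LM and its perplexity concrete,
  here is a minimal Python sketch.  It estimates a plain trigram model by
  maximum likelihood and does not implement the 4-gram state selection
  described above; the function names and the 'BOS' and 'EOS' padding symbols
  are assumptions made for this example, not Kaldi conventions.

```python
import math
from collections import Counter, defaultdict


def _ngrams(seq, order):
    """Yield (history, phone) pairs, padding with start and end symbols."""
    padded = ["BOS"] * (order - 1) + list(seq) + ["EOS"]
    for i in range(order - 1, len(padded)):
        yield tuple(padded[i - order + 1:i]), padded[i]


def train_unsmoothed_lm(sequences, order=3):
    """Maximum-likelihood n-gram estimates with no smoothing or backoff."""
    counts = defaultdict(Counter)
    for seq in sequences:
        for history, phone in _ngrams(seq, order):
            counts[history][phone] += 1
    return {h: {p: c / sum(ctr.values()) for p, c in ctr.items()}
            for h, ctr in counts.items()}


def perplexity(lm, sequences, order=3):
    """Per-token perplexity; fails on n-grams the model has never seen."""
    log_prob, num_tokens = 0.0, 0
    for seq in sequences:
        for history, phone in _ngrams(seq, order):
            log_prob += math.log(lm[history][phone])
            num_tokens += 1
    return math.exp(-log_prob / num_tokens)


train = [["a", "b"], ["a", "c"]]
lm = train_unsmoothed_lm(train)
train_ppl = perplexity(lm, train)
```

  Histories that never occur in the training alignments, such as ('b', 'a')
  here, simply have no entry in the model; that sparsity is what keeps the
  compiled graph small.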

 \subsubsection chain_training_denominator_compilation  Compilation of the denominator FST

  The phone language model described in the previous section is expanded into a
  FST with 'pdf-ids' as the arcs, in a process that mirrors the process of
  decoding-graph compilation in normal Kaldi decoding (see \ref
  graph_recipe_test), except that no lexicon is involved, and at the
  end we convert the transition-ids to pdf-ids.

  One difference lies in how we minimize the size of the graph.  The normal
  recipe involves determinization and minimization.  We were not able to
  reduce the size of the graph using this procedure, or variants of it with
  disambiguation symbols.  Instead, our graph-minimization process can be described
  compactly as follows: "Repeat 3 times: push, minimize, reverse; push, minimize, reverse."
  Here 'push' refers to weight-pushing, and 'reverse' refers to reversing the directions of arcs and
  swapping initial and final states.
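
  The order of operations in that recipe can be written as a short loop.  The
  Python sketch below is library-agnostic: the three operations are passed in
  as functions, and the stand-in operations used here only record their names.
  None of these names correspond to actual Kaldi or OpenFst calls.

```python
def minimize_graph(fst, push, minimize, reverse, num_rounds=3):
    """Apply 'push, minimize, reverse' twice per round, for num_rounds rounds.

    `push`, `minimize` and `reverse` are caller-supplied functions that each
    take an FST and return the transformed FST; nothing here depends on a
    particular FST library.
    """
    for _ in range(num_rounds):
        for _ in range(2):
            fst = reverse(minimize(push(fst)))
    return fst


# Stand-in operations that only record their name, to show the call order.
log = []


def _record(name):
    def op(fst):
        log.append(name)
        return fst
    return op


result = minimize_graph("fst", _record("push"), _record("minimize"),
                        _record("reverse"))
```

  Because each round reverses the graph twice, the result has the same
  orientation as the input.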


 \subsubsection chain_training_denominator_normalization Initial and final probabilities, and 'normalization FST'

  The graph-creation process mentioned above naturally gives us an initial
  state, and final probabilities for each state; but these are not the ones we
  use in the forward-backward.  The reason is that these probabilities are
  applicable to utterance boundaries, but we train on split-up chunks of
  utterance of a fixed length (e.g. 1.5 seconds).  Constraining the HMM at these
  arbitrarily chosen cut points to the initial and final states is not
  appropriate.  Instead, we use initial probabilities derived from 'running the HMM' for
  a fixed number of iterations and averaging the probabilities; and final probabilities
  equal to 1.0 for each state.  We have a justification for this but don't have time to
  explain it right now.  In the denominator forward-backward process we apply these initial and
  final probabilities to the initial and final frame as part of the computation.  However, we also
  write out a version of the denominator FST that has these initial and final probabilities, and we refer to
  this as the 'normalization FST.'  (The initial probabilities are emulated using epsilon arcs, because
  FSTs do not support initial probabilities).  This 'normalization FST' will be used to add probabilities to the
 numerator FSTs in a way that we'll describe later.
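
  One possible reading of 'running the HMM' and averaging is sketched below in
  Python.  The dense transition matrix, the choice of start state, the
  iteration count of 100 and the function name are all assumptions made for
  this example; the actual code works on the arcs of the denominator FST.

```python
def average_state_probs(trans, start_state, num_iters=100):
    """Average the state distribution over num_iters steps of the HMM.

    trans[i][j] is the total transition probability from state i to state j
    (rows sum to one).  The dense matrix and num_iters=100 are assumptions
    made for this sketch.
    """
    num_states = len(trans)
    cur = [0.0] * num_states
    cur[start_state] = 1.0
    avg = [0.0] * num_states
    for _ in range(num_iters):
        for s in range(num_states):
            avg[s] += cur[s] / num_iters
        cur = [sum(cur[i] * trans[i][j] for i in range(num_states))
               for j in range(num_states)]
    return avg


trans = [[0.9, 0.1],
         [0.5, 0.5]]
init_probs = average_state_probs(trans, start_state=0)
final_probs = [1.0] * len(trans)  # final probability 1.0 for every state
```

  For this two-state example the averaged probabilities end up close to the
  limiting distribution of the chain, (5/6, 1/6).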

  \subsection chain_training_numerator  Numerator FSTs

 As part of our preparation for the training process we produce something
 called a 'numerator FST' for each utterance.  The numerator FST encodes the
 supervision transcript, and also encodes an alignment of that transcript
 (i.e. it forces similarity to a reference alignment obtained from a baseline
 system), but it allows a little 'wiggle room' to vary from that reference.
 By default we allow a phone to occur 0.05 seconds before or after its
 begin and end position respectively, in the lattice alignment.
 Incorporating the alignment information is important because of the way we
 train not on entire utterances but on split-up fixed-length pieces of
 utterances (which, in turn, is important for GPU-based training): we can only
 split up the utterance into pieces if we know where the transcript aligns.
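
  The effect of the tolerance can be pictured as a per-frame set of allowed
  phones.  The Python sketch below assumes a single reference alignment given
  as (phone, begin_frame, end_frame) triples at 100 frames per second, so that
  0.05 seconds is 5 frames; the function name and data layout are hypothetical,
  and the real numerator FST is built from a lattice rather than from one
  alignment.

```python
def allowed_phones_per_frame(alignment, num_frames, tolerance=5):
    """For each frame, the set of phones permitted by the alignment.

    alignment: list of (phone, begin_frame, end_frame), end exclusive.
    tolerance: frames a phone may start early or end late (5 frames is
    0.05 seconds at 100 frames per second).
    """
    allowed = [set() for _ in range(num_frames)]
    for phone, begin, end in alignment:
        lo = max(0, begin - tolerance)
        hi = min(num_frames, end + tolerance)
        for t in range(lo, hi):
            allowed[t].add(phone)
    return allowed


alignment = [("a", 0, 10), ("b", 10, 20)]
allowed = allowed_phones_per_frame(alignment, num_frames=20)
```

  Around the boundary at frame 10, both phones are allowed for 5 frames on
  either side; further away only the reference phone is permitted.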

 Instead of enforcing a particular pronunciation of the training data, we use as
 our reference a lattice of alternative pronunciations of the training data,
 generated by a lattice-generating decoding procedure using an
 utterance-specific graph as the decoding graph.  This generates all alignments
 of pronunciations that were within a beam of the best-scoring pronunciation.

  \subsubsection chain_training_numerator_splitting Splitting the numerator FSTs

 As mentioned, we train on fixed-size pieces of utterances (e.g. 1.5 seconds in
 length).  This requires that we split up the numerator FSTs into fixed-size
 pieces.  This isn't hard, since the numerator FSTs (which, remember, encode
 time-alignment information) naturally have a structure where we can identify
 any FST state with a particular frame index.  Note: at the stage where we do this
 splitting, there are no costs in the numerator FST yet-- it's just viewed as
 encoding a constraint on paths-- so we do not have to make a decision how to split up the costs
 on the paths.

  \subsubsection chain_training_numerator_normalization Normalizing the numerator FSTs

 Above (\ref chain_training_denominator_normalization) we mentioned how we compute
 initial and final probabilities for the denominator FST, and how we encode
 these in a 'normalization FST'.  We compose the split-up pieces of numerator
 FST with this 'normalization FST' to ensure that the costs from the
 denominator FST are reflected in the numerator FST.  This ensures that
 objective functions can never be positive (which makes them easier to
 interpret), and also guards against the possibility that the numerator FST
 could contain state sequences not allowed by the denominator FST, which in
 principle could allow the objective function to increase without bound.  The
 reason why this could happen is that the phone LM lacks smoothing, and is
 estimated from 1-best alignments, so the lattices could contain phone n-gram
 sequences not seen in training.

 It happens occasionally (but very rarely) that this normalization process
 generates an empty FST: this can occur when the lattice contains triphones that
 were not present in the 1-best alignment used to train the phone language
 model, and does not have any alternative paths at that point in the lattice
 that could make up for the resulting 'failed' paths.  This can happen because
 the 1-best alignment and the lattice-producing alignment chose different
 pronunciations of a word.  These pieces of utterances are just discarded.

  \subsubsection chain_training_numerator_format Format of the numerator FSTs

  The numerator FSTs are weighted acceptors where the labels correspond to
  pdf-ids plus one.  We can't use pdf-ids, because they could be zero; and zero
  is treated specially (as epsilon) by OpenFst.  When we form minibatches, instead
  of storing an array of separate numerator FSTs we actually append them together to form a longer FST;
  this enables us to do a single forward-backward over all utterances in the minibatch,
  which directly computes the total numerator log-probability.  (This isn't an important
  feature, it's just a software detail, which we explain here lest it generate confusion).

  \subsection chain_training_splitting  Fixed-length chunks, and minibatches

  In order to train on minibatches, we split up our utterances into fixed-length
  chunks of speech (of length 1.5 seconds in our current scripts).  Utterances
  shorter than this are discarded; those that are longer are split into chunks with
  either overlaps between the chunks, or small gaps between the chunks.  Note that
  our acoustic models typically require left or right frames for acoustic
  context; we add that, but this is a separate issue; the context is added after
  the chunks are decided on.
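
  A simple way to place such chunks is sketched below in Python.  The rule
  used to choose between overlaps and gaps, and the function name, are
  assumptions made for this example and are not necessarily what the Kaldi
  scripts do.

```python
def chunk_starts(num_frames, chunk_len=150):
    """Start frames of fixed-length chunks for one utterance.

    chunk_len=150 corresponds to 1.5 seconds at 100 input frames per second.
    The rule used to choose between overlaps and gaps is an assumption made
    for this sketch.
    """
    if num_frames < chunk_len:
        return []  # too short: discard the utterance
    num_chunks = num_frames // chunk_len
    if num_frames % chunk_len > chunk_len // 2:
        num_chunks += 1  # cover the leftover frames, at the cost of overlaps
    last_start = num_frames - chunk_len
    if num_chunks == 1:
        return [last_start // 2]  # center the single chunk
    # Spread the chunks evenly between the first and last possible start.
    return [(i * last_start) // (num_chunks - 1) for i in range(num_chunks)]


overlapping = chunk_starts(400)  # three chunks that overlap by 25 frames
gapped = chunk_starts(340)       # two chunks with a 40-frame gap between them
```

  Acoustic context frames are not considered here; as noted above, they are
  added after the chunk positions have been decided.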

  Our minibatch size is usually a power of 2, and it can be limited by GPU
  memory considerations.  Many of our example scripts use 128 chunks per
  minibatch.  The largest single consumer of GPU memory is the alpha
  probabilities in the forward-backward computation.  For instance, with a 1.5
  second chunk, we have 50 time steps after the 3-fold subsampling.  In our
  Switchboard setup a typical denominator FST has 30,000 states in it.  We use
  single-precision floating point for the alphas, so the memory used in
  gigabytes is (128 * 50 * 30000 * 4) / 10^9 = 0.768G.

  This won't use up all the GPU memory, but there are other consumers of memory,
  e.g. we keep around two copies of the nnet outputs in memory, which takes a
  fair amount of memory depending on the configuration-- e.g. replace the 30000
  above with about 10000 and it will give you the amount of memory used for one
  copy of the nnet outputs in a reasonable configuration.
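
  The arithmetic above is easy to reproduce for other configurations.  The
  helper below is just a sketch with a made-up function name; it assumes one
  single-precision float per (chunk, time step, column) entry.

```python
def float_matrix_memory_gb(chunks_per_minibatch, frames_per_chunk,
                           num_columns, bytes_per_float=4):
    """Memory in gigabytes (10^9 bytes) for one array of floats."""
    return (chunks_per_minibatch * frames_per_chunk * num_columns
            * bytes_per_float) / 1e9


# Alpha probabilities: one column per denominator-FST state.
alpha_gb = float_matrix_memory_gb(128, 50, 30000)
# One copy of the nnet outputs, taking about 10000 outputs as in the text.
nnet_output_gb = float_matrix_memory_gb(128, 50, 10000)
```

  With these numbers the alphas take 0.768 GB and each copy of the nnet
  outputs takes 0.256 GB.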


  \subsection chain_training_shifting  Training on frame-shifted data

  In neural net training we already have ways of generating perturbed data to
  artificially increase the amount of data we train on.  Our standard nnet3
  neural-net training example scripts do time-warping of the raw audio, by
  factors of 0.9, 1.0 and 1.1, to create 3-fold augmented data.  This is
  orthogonal to the 'chain' models, and we do it (or not) just as we would for
  the baseline.  However, there is an extra way we can augment the data for the
  chain models, by shifting the frames.  The output frame rate for these models
  is one third the regular frame rate (configurable, of course), meaning we only
  evaluate nnet output at <code>t</code> values that are multiples of 3, so we
  can generate different versions of the training data by shifting the training
  examples by 0, 1 and 2 frames.  This is done automatically in the training
  script, and it's done 'on the fly' as we read the training examples from
  disk-- the program <code>nnet3-chain-copy-egs</code> has a
  <code>--frame-shift</code> option that is set by the script.  This affects how
  the number of epochs is interpreted.  If the user requests, for instance, 4
  epochs, then we actually train for 12 epochs; we just do so on 3
  differently-shifted versions of the data.  What the
  <code>--frame-shift=t</code> option actually does is to shift the input frames
  by <code>t</code> and shift the output frames by the closest multiple of 3 to
  <code>t</code>.  (In general it might not be 3, it's a configuration variable
  named <code>--frame-subsampling-factor</code>).
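
  The shifting rule and the epoch accounting can be sketched in Python as
  follows.  The function names are invented for this example, ties in 'closest
  multiple' are rounded up here as an arbitrary choice, and the sketch assumes
  the number of shifted versions equals the frame-subsampling factor.

```python
def shifted_offsets(t, frame_subsampling_factor=3):
    """Return (input_shift, output_shift) for a requested frame shift t.

    The output shift is the multiple of the subsampling factor closest to t.
    """
    f = frame_subsampling_factor
    output_shift = ((t + f // 2) // f) * f
    return t, output_shift


def effective_epochs(requested_epochs, frame_subsampling_factor=3):
    """Epochs actually run when each shifted version counts as its own data."""
    return requested_epochs * frame_subsampling_factor


shifts = [shifted_offsets(t) for t in range(3)]
epochs_run = effective_epochs(4)
```

  For shifts of 0, 1 and 2 frames the output shifts are 0, 0 and 3, and a
  request for 4 epochs becomes 12.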

  \subsection chain_training_gpu GPU issues in training

 The parts of the computation that are specific to the 'chain' computation are
 the forward-backward over the numerator FST and over the denominator HMM.  The
 numerator part of this is very fast.  The denominator forward-backward takes
 quite a lot of time, because there can be a lot of arcs in the
 denominator FST (e.g. 200,000 arcs and 30,000 states in a typical Switchboard setup).
 The time taken can be almost as much as the time taken in the neural-net
 parts of the computation.  We were quite careful to ensure memory locality.

 The next step to further speed this up is probably to implement a pruned
 version of the forward-backward computation (like pruned Viterbi, but
 computing posteriors).  In order to get a speedup we'd have to prune away a
 very high percentage of states, because we'd need to make up for the loss of
 memory locality that pruning would bring.  In our current implementation we are
 careful to ensure that a group of GPU threads are all processing the same
 HMM-state and time, just from different chunks (we call these different
 'sequences' in the code); and we make sure that the memory locations
 corresponding to these different sequences are all next to each other in
 memory, so the GPU can do 'coalesced memory access'.  With state-level
 pruning, since the memory access for the different sequences would no longer be
 'in sync', we would lose this advantage.  It should still be doable to get a
 pruned version of the forward-backward algorithm, though.

 For speed, we don't use log values in the alpha-beta computation for the
 denominator graph.  In order to keep all the numerical values in a suitable
 range, we multiply all the acoustic probabilities (exponentiated nnet outputs)
 on each frame, by an 'arbitrary value' selected to ensure that our alpha scores
 stay in a good range.  We call this an 'arbitrary value' because the algorithm
 is designed so that we could choose any value here, and it would still be
 mathematically correct.  We designate one HMM state as a 'special state', and
 the 'arbitrary value' is chosen to be the inverse of that special state's alpha
 on the previous frame.  This keeps the special state's alpha values close to
 one.  As the 'special state' we choose a state that has high probability in the
 limiting distribution of the HMM, and which can access the majority of states
 of the HMM.
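
 The scaling trick can be demonstrated on a small HMM in plain Python.  This is
 a sketch rather than the GPU implementation: it uses a dense transition
 matrix with per-state acoustic probabilities, takes the final probabilities to
 be 1.0, runs in double precision, and all names and numbers are made up.

```python
import math


def forward_log_prob(trans, init, acoustic, special=None):
    """Total log-probability from a forward pass without log arithmetic.

    trans[i][j]: transition probability from state i to state j.
    init[i]: initial probability of state i (final probabilities are 1.0).
    acoustic[t][j]: exponentiated nnet output for state j on frame t.
    special: index of the 'special state', or None to disable scaling.
    """
    num_states = len(init)
    alpha = list(init)
    log_scale_total = 0.0
    for frame in acoustic:
        # Any positive scale is mathematically valid; we use the inverse of
        # the special state's alpha on the previous frame.
        scale = 1.0 if special is None else 1.0 / alpha[special]
        log_scale_total += math.log(scale)
        alpha = [sum(alpha[i] * trans[i][j] for i in range(num_states))
                 * frame[j] * scale
                 for j in range(num_states)]
    # Undo the scaling in the log domain to recover the true log-probability.
    return math.log(sum(alpha)) - log_scale_total


trans = [[0.9, 0.1], [0.5, 0.5]]
init = [0.8, 0.2]
short = [[1e-3, 2e-3]] * 20
plain = forward_log_prob(trans, init, short)
scaled = forward_log_prob(trans, init, short, special=0)
# Over 200 frames the unscaled alphas would underflow to zero even in double
# precision, but the scaled recursion stays in range.
long_scaled = forward_log_prob(trans, init, [[1e-3, 1e-3]] * 200, special=0)
```

 On the short input both versions give the same log-probability, which is the
 sense in which the scaling value is 'arbitrary'.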

 \section chain_decoding  Decoding with 'chain' models

 The decoding process with 'chain' models is exactly the same as for regular nnet3
 neural-net based models, and in fact uses the same script (steps/nnet3/decode.sh).
 There are a few configuration differences:

    - Firstly, the graph is built with a different and simpler topology; but this requires
      no special action by the user, as the graph-building script anyway takes the
      topology from the 'final.mdl' produced by the 'chain' training script, which
      contains the correct topology.

    - By default when we compile the graph, we use a 'self-loop-scale' of 0.1.
      This affects how the transition probabilities on self-loops are treated
      (it generally works better).  However, for the 'chain' models, because of
      how they were trained, we need to use exactly the same
      transition-probability scaling we trained with, which for simplicity we
      have set to 1.0.  So we supply the option <code>--self-loop-scale
      1.0</code> to the <code>utils/mkgraph.sh</code> script.

    - There is no 'division by the prior' necessary in these models.  So we simply
      don't set the vector of priors in the <code>.mdl</code> files; we made sure
      that the decoder just omits the division by the prior if the priors are not set.

    - The default acoustic scale we typically use in decoding (0.1) is not
      suitable-- for 'chain' models the optimal acoustic scale is very close to 1.
      So we supply the option <code>--acwt 1.0</code> to the script
      <code>steps/nnet3/decode.sh</code>.

    - The scoring scripts can only search the language-model scale in increments
      of 1, which works well in typical setups where the optimal language model scale
       is between 10 and 15, but not when the optimal language-model scale is close
       to 1 as it is here.  (Note: for current purposes you can treat the language-model
      scale as the same as the inverse of the acoustic scale).  In order to
      work around this issue without changing the scoring scripts (which are
      database-specific), we supply a new option <code>--post-decode-acwt 10.0</code>
      to the script <code>steps/nnet3/decode.sh</code>,
      which scales the acoustic probabilities by 10 before dumping the lattice.
      After this, the optimal language-model scale will be around 10, which might
      be a little confusing if you are not aware of this issue, but is convenient
      for the way the scoring scripts are set up.

   - The default decoding and lattice beams are suitable without modification
     for the 'chain' models, once you use the <code>--acwt 1.0</code> option.
     However, they won't show the full possible speedup and you can get faster
     decoding by using slightly tighter beams.  By tightening the beam in the
     Switchboard setup we were able to get decoding time down from around 1.5
     times real time to around 0.5 times real time, with only around 0.2\%
     degradation in accuracy (this was with neural net evaluation on the CPU; on
     the GPU it would have been even faster).  Note from Dan: this is all to the best
     of my recollection as I write this; actually the degradation may have been more than
     that.  And bear in mind that this was on high-powered modern server machines
    (single-threaded).
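
 The reason <code>--post-decode-acwt 10.0</code> moves the optimal
 language-model scale from around 1 to around 10 can be checked with a toy
 example.  In the Python sketch below the two paths and their costs are made
 up, and the scoring rule (acoustic cost plus scaled language-model cost) is a
 simplification of what the scoring scripts do.

```python
def best_path(paths, lm_scale):
    """Pick the lowest-cost path; each path is (name, acoustic_cost, lm_cost)."""
    return min(paths, key=lambda p: p[1] + lm_scale * p[2])[0]


def apply_post_decode_acwt(paths, post_decode_acwt=10.0):
    """Scale the acoustic costs, as done before the lattice is written out."""
    return [(name, post_decode_acwt * ac, lm) for name, ac, lm in paths]


# Made-up costs for two competing paths in a lattice decoded with --acwt 1.0.
paths = [("A", 10.0, 5.0), ("B", 12.0, 3.0)]
scaled_paths = apply_post_decode_acwt(paths)

# An integer LM scale of 12 on the scaled lattice ranks the paths the same way
# as a fractional LM scale of 1.2 on the original lattice.
choice_scaled = best_path(scaled_paths, lm_scale=12)
choice_original = best_path(paths, lm_scale=1.2)
```

 An integer search over language-model scales on the scaled lattice therefore
 behaves like a search in steps of 0.1 on the original one.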

 You might notice in the current example scripts that we use iVectors.  We do so
 just because they generally help a bit, and because the baseline setup we were
 comparing with, uses them.  There is no inherent connection with 'chain'
 models, and no fundamental requirement to use them.


*/

}