  // chain/chain-training.h
  
  // Copyright       2015  Johns Hopkins University (Author: Daniel Povey)
  
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS
  // OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY
  // IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABILITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  
  #ifndef KALDI_CHAIN_CHAIN_TRAINING_H_
  #define KALDI_CHAIN_CHAIN_TRAINING_H_
  
  #include <vector>
  #include <map>
  
  #include "base/kaldi-common.h"
  #include "util/common-utils.h"
  #include "fstext/fstext-lib.h"
  #include "tree/context-dep.h"
  #include "lat/kaldi-lattice.h"
  #include "matrix/kaldi-matrix.h"
  #include "hmm/transition-model.h"
  #include "chain/chain-den-graph.h"
  #include "chain/chain-supervision.h"
  
  namespace kaldi {
  namespace chain {
  
  
  struct ChainTrainingOptions {
    // l2 regularization constant on the 'chain' output; the actual term added to
    // the objf will be -0.5 times this constant times the squared l2 norm.
    // (squared so it's additive across the dimensions).  e.g. try 0.0005.
    BaseFloat l2_regularize;
  
  
    // This is similar to an l2 regularization constant (like l2-regularize), but
    // applied only to the part of the nnet output matrix that exceeds the range
    // [-30,30].  It is needed to keep the outputs from regularly going outside
    // the range over which we can safely take exp(): the denominator computation
    // is not done in log space, so to avoid NaNs we limit the outputs to [-30,30].
    BaseFloat out_of_range_regularize;
  
    // Coefficient for 'leaky hmm'.  This means we have an epsilon-transition from
    // each state to a special state with probability one, and then another
    // epsilon-transition from that special state to each state, with probability
    // leaky_hmm_coefficient times [initial-prob of destination state].  Imagine
    // we make two copies of each state prior to doing this, version A and version
    // B, with a transition from A to B, so that we don't have to consider
    // epsilon loops (or just imagine the coefficient is small enough that the
    // epsilon loops can be ignored).
    // Note: we generally set leaky_hmm_coefficient to 0.1.
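    // (Roughly speaking, in the forward recursion this adds to each state's
    // alpha, at every frame, an extra term of leaky_hmm_coefficient times the
    // initial-prob of that state times the sum of all alphas at that frame, and
    // correspondingly in the backward recursion; see chain-denominator.h for
    // the exact formulation used.)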
    BaseFloat leaky_hmm_coefficient;
  
  
    // Cross-entropy regularization constant.  (e.g. try 0.1).  If nonzero,
    // the network is expected to have an output named 'output-xent', which
    // should have a softmax as its final nonlinearity.
    BaseFloat xent_regularize;
  
    ChainTrainingOptions(): l2_regularize(0.0), out_of_range_regularize(0.01),
                            leaky_hmm_coefficient(1.0e-05),
                            xent_regularize(0.0) { }
  
    void Register(OptionsItf *opts) {
      opts->Register("l2-regularize", &l2_regularize, "l2 regularization "
                     "constant for 'chain' training, applied to the output "
                     "of the neural net.");
      opts->Register("out-of-range-regularize", &out_of_range_regularize,
                     "Constant that controls how much we penalize the nnet output "
                     "being outside the range [-30,30].  This is needed because we "
                     "limit it to that range in the denominator computation (which "
                     "is to avoid NaNs because it is not done in log space.");
      opts->Register("leaky-hmm-coefficient", &leaky_hmm_coefficient, "Coefficient "
                     "that allows transitions from each HMM state to each other "
                     "HMM state, to ensure gradual forgetting of context (can "
                     "improve generalization).  For numerical reasons, may not be "
                     "exactly zero.");
      opts->Register("xent-regularize", &xent_regularize, "Cross-entropy "
                     "regularization constant for 'chain' training.  If "
                     "nonzero, the network is expected to have an output "
                     "named 'output-xent', which should have a softmax as "
                     "its final nonlinearity.");
    }
  };
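
  /*
     Example of how these options are typically registered and parsed in a
     command-line training binary (a minimal sketch for illustration only; the
     usage string and the binary itself are not part of this file):

       #include "chain/chain-training.h"
       #include "util/parse-options.h"

       int main(int argc, char *argv[]) {
         using namespace kaldi;
         const char *usage = "Example binary showing chain-training options.\n";
         ParseOptions po(usage);
         chain::ChainTrainingOptions chain_opts;
         chain_opts.Register(&po);  // adds --l2-regularize, --leaky-hmm-coefficient, etc.
         po.Read(argc, argv);
         // chain_opts now reflects any flags given on the command line,
         // e.g. --leaky-hmm-coefficient=0.1 --xent-regularize=0.1.
         return 0;
       }
  */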
  
  
  /**
     This function does both the numerator and denominator parts of the 'chain'
     computation in one call.
  
     @param [in] opts        Struct containing options
     @param [in] den_graph   The denominator graph, derived from denominator fst.
     @param [in] supervision  The supervision object, containing the supervision
                              paths and constraints on the alignment as an FST
     @param [in] nnet_output  The output of the neural net; dimension must equal
                            ((supervision.num_sequences * supervision.frames_per_sequence) by
                              den_graph.NumPdfs()).  The rows are ordered as: all sequences
                              for frame 0; all sequences for frame 1; etc.
     @param [out] objf       The [num - den] objective function computed for this
                             example; you'll want to divide it by 'tot_weight' before
                             displaying it.
     @param [out] l2_term  The l2 regularization term in the objective function, if
                            the --l2-regularize option is used.  To be added to 'objf'.
     @param [out] weight     The weight to normalize the objective function by;
                             equals supervision.weight * supervision.num_sequences *
                             supervision.frames_per_sequence.
     @param [out] nnet_output_deriv  The derivative of the objective function w.r.t.
                             the neural-net output.  Only written to if non-NULL.
                            You don't have to zero this before passing it to this
                            function; we zero it internally.
     @param [out] xent_output_deriv  If non-NULL, then the numerator part of the derivative
                             (which equals a posterior from the numerator
                             forward-backward, scaled by the supervision weight)
                             is written to here (this function will set it to the
                             correct size first; doing it this way reduces the
                             peak memory use).  xent_output_deriv will be used in
                             the cross-entropy regularization code; it is also
                             used in computing the cross-entropy objective value.
  */
  void ComputeChainObjfAndDeriv(const ChainTrainingOptions &opts,
                                const DenominatorGraph &den_graph,
                                const Supervision &supervision,
                                const CuMatrixBase<BaseFloat> &nnet_output,
                                BaseFloat *objf,
                                BaseFloat *l2_term,
                                BaseFloat *weight,
                                CuMatrixBase<BaseFloat> *nnet_output_deriv,
                                CuMatrix<BaseFloat> *xent_output_deriv = NULL);
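
  /*
     Example call (a hedged sketch; 'den_graph', 'supervision' and 'nnet_output'
     are assumed to have been constructed elsewhere, e.g. via chain-den-graph.h
     and chain-supervision.h, and the variable names below are illustrative):

       ChainTrainingOptions opts;
       opts.l2_regularize = 5.0e-05;
       opts.leaky_hmm_coefficient = 0.1;

       BaseFloat objf, l2_term, weight;
       CuMatrix<BaseFloat> nnet_output_deriv(nnet_output.NumRows(),
                                             nnet_output.NumCols(),
                                             kUndefined);
       CuMatrix<BaseFloat> xent_deriv;  // only needed if opts.xent_regularize != 0.

       ComputeChainObjfAndDeriv(opts, den_graph, supervision, nnet_output,
                                &objf, &l2_term, &weight,
                                &nnet_output_deriv,
                                opts.xent_regularize != 0.0 ? &xent_deriv : NULL);

       // For diagnostics you would typically report the objective per frame:
       BaseFloat objf_per_frame = (objf + l2_term) / weight;
  */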
  
  
  
  }  // namespace chain
  }  // namespace kaldi
  
  #endif  // KALDI_CHAIN_CHAIN_TRAINING_H_