Blame view

src/nnet3/discriminative-training.h 10.2 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
  // nnet3/discriminative-training.h
  
  // Copyright      2012-2015    Johns Hopkins University (author: Daniel Povey)
  // Copyright      2014-2015    Vimal Manohar
  
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  
  #ifndef KALDI_NNET3_DISCRIMINATIVE_TRAINING_H_
  #define KALDI_NNET3_DISCRIMINATIVE_TRAINING_H_
  
  #include "base/kaldi-common.h"
  #include "util/common-utils.h"
  #include "fstext/fstext-lib.h"
  #include "tree/context-dep.h"
  #include "lat/kaldi-lattice.h"
  #include "matrix/kaldi-matrix.h"
  #include "hmm/transition-model.h"
  #include "nnet3/discriminative-supervision.h"
  #include "lat/lattice-functions.h"
  #include "cudamatrix/cu-matrix-lib.h"
  
  namespace kaldi {
  namespace discriminative {
  
  /* Options for discriminative training
   *
   * Legend:
   * mmi - Maximum Mutual Information
   * mpfe - Minimum Phone Frame Error
   * smbr - State Minimum Bayes Risk
   *
   */
  struct DiscriminativeOptions {
    // Objective to optimize: one of {"mmi", "mpfe", "smbr"}.  If it does not
    // match the supervision object, the derivatives may not be very accurate.
    std::string criterion;

    // Scale applied to the acoustic likelihoods, e.g. 0.1.
    BaseFloat acoustic_scale;

    // MMI only: if true, frames where the alignment pdf-id is not in the
    // lattice are ignored.
    bool drop_frames;

    // Affects the MPFE and SMBR objectives.
    bool one_silence_class;

    // MMI only: boosting factor (this would be Boosted MMI), e.g. 0.1.
    BaseFloat boost;

    // Colon-separated list of integer ids of silence phones, used by the MPFE
    // and SMBR objectives.  Not listed in the constructor's initializer list,
    // so it starts out as an empty string.
    std::string silence_phones_str;

    // Cross-entropy regularization constant (e.g. try 0.1).  When nonzero, the
    // network is expected to have an output named 'output-xent' whose final
    // nonlinearity is a softmax.
    BaseFloat xent_regularize;

    // l2 regularization constant on the 'chain' output.  The term actually
    // added to the objf is -0.5 times this constant times the squared l2 norm
    // (squared so that it is additive across the dimensions).  e.g. try 0.0005.
    BaseFloat l2_regularize;

    // ---- Options for debugging discriminative training ----

    // If true, gradients wrt the nnet outputs are accumulated.
    bool accumulate_gradients;

    // If true, the nnet output is accumulated.
    bool accumulate_output;

    // Number of pdfs; only relevant when accumulate_gradients or
    // accumulate_output is true.
    int32 num_pdfs;

    // Sets the defaults: "smbr" criterion, acoustic scale 0.1, and everything
    // else switched off / zero.
    DiscriminativeOptions()
        : criterion("smbr"),
          acoustic_scale(0.1),
          drop_frames(false),
          one_silence_class(false),
          boost(0.0),
          xent_regularize(0.0),
          l2_regularize(0.0),
          accumulate_gradients(false),
          accumulate_output(false),
          num_pdfs(0) { }

    // Registers each option of this struct with 'opts' under its command-line
    // name, together with a help string.
    void Register(OptionsItf *opts) {
      opts->Register("criterion", &criterion,
                     "Criterion, 'mmi'|'mpfe'|'smbr', determines the "
                     "objective function to use.  Should match option "
                     "used when we created the examples.");
      opts->Register("acoustic-scale", &acoustic_scale,
                     "Weighting factor to apply to acoustic likelihoods.");
      opts->Register("drop-frames", &drop_frames,
                     "For MMI, if true we drop frames with no overlap "
                     "of num and den pdf-ids");
      opts->Register("boost", &boost,
                     "Boosting factor for boosted MMI (e.g. 0.1)");
      opts->Register("one-silence-class", &one_silence_class,
                     "If true, newer behavior which will tend to reduce "
                     "insertions when using MPFE or SMBR objective");
      opts->Register("silence-phones", &silence_phones_str,
                     "For MPFE or SMBR objectives, colon-separated "
                     "list of integer ids of silence phones, e.g. 1:2:3");
      opts->Register("l2-regularize", &l2_regularize,
                     "l2 regularization constant for 'chain' output of the "
                     "neural net.");
      opts->Register("xent-regularize", &xent_regularize,
                     "Cross-entropy regularization constant for sequence "
                     "training.  If nonzero, the network is expected to have "
                     "an output named 'output-xent', which should have a "
                     "softmax as its final nonlinearity.");
      opts->Register("accumulate-gradients", &accumulate_gradients,
                     "Accumulate gradients wrt nnet output for debugging "
                     "discriminative training");
      opts->Register("accumulate-output", &accumulate_output,
                     "Accumulate nnet output for debugging discriminative training");
      opts->Register("num-pdfs", &num_pdfs,
                     "Number of pdfs; applicable when accumulate-output or "
                     "accumulate-gradients is true for discriminative training");
    }
  };
  
  struct DiscriminativeObjectiveInfo {
    // Total number of frames.
    double tot_t;
    // Total number of frames times weight.
    double tot_t_weighted;
    // For 'mmi', the (weighted) denominator likelihood; for every other
    // criterion, the objective function itself.
    double tot_objf;
    // Total count of numerator posterior.
    double tot_num_count;
    // Total count of denominator posterior.
    double tot_den_count;
    // For 'mmi', the (weighted) numerator likelihood; 0 for everything else.
    double tot_num_objf;

    // Accumulated l2 regularization objective.
    double tot_l2_term;

    // ---- Options for debugging discriminative training ----

    // If true, gradients wrt the nnet outputs are accumulated.
    bool accumulate_gradients;

    // If true, the nnet output is accumulated.
    bool accumulate_output;

    // Number of pdfs; only relevant when accumulate_gradients or
    // accumulate_output is true.
    int32 num_pdfs;

    // Accumulator for the gradients wrt the nnet outputs; used when
    // accumulate_gradients is true.
    CuVector<double> gradients;
    // Accumulator for the nnet output; used when accumulate_output is true.
    CuVector<double> output;

    // Prints statistics for the given criterion.  The two flags (both off by
    // default) request that the averaged accumulated gradients / output be
    // printed as well.
    void Print(const std::string &criterion,
               bool print_avg_gradients = false,
               bool print_avg_output = false) const;

    // Debugging helper: prints every accumulated statistic, i.e. Print() with
    // both optional flags switched on.
    void PrintAll(const std::string &criterion) const {
      const bool with_avg_gradients = true;
      const bool with_avg_output = true;
      Print(criterion, with_avg_gradients, with_avg_output);
    }

    // Prints the gradient wrt the nnet output accumulated for one pdf.
    void PrintAvgGradientForPdf(int32 pdf_id) const;

    // Adds the stats from another object to this one.
    void Add(const DiscriminativeObjectiveInfo &other);

    // Objective function value for the criterion: numerator minus denominator
    // likelihood for "mmi", and tot_objf unchanged for anything else.
    double TotalObjf(const std::string &criterion) const {
      const bool is_mmi = (criterion == "mmi");
      return is_mmi ? (tot_num_objf - tot_objf) : tot_objf;
    }

    // True only if gradient accumulation is switched on and the gradients
    // vector has already been sized (non-zero dimension) to hold the
    // accumulated gradients.
    bool AccumulateGradients() const {
      if (!accumulate_gradients) return false;
      return gradients.Dim() > 0;
    }

    // True only if output accumulation is switched on and the output vector
    // has already been sized (non-zero dimension) to hold the accumulated
    // nnet output.
    bool AccumulateOutput() const {
      if (!accumulate_output) return false;
      return output.Dim() > 0;
    }

    // Empty constructor.
    DiscriminativeObjectiveInfo();

    // Constructor that prepares for gradients or output to be accumulated.
    DiscriminativeObjectiveInfo(int32 num_pdfs);

    // Constructor from config options.
    DiscriminativeObjectiveInfo(const DiscriminativeOptions &opts);

    // Resets the statistics.
    void Reset();

    // Configures this object from the given options.
    void Configure(const DiscriminativeOptions &opts);
  };
  
  /**
     This function does forward-backward on the numerator and denominator
     lattices and computes derivatives w.r.t. the nnet output for the specified
     objective function.

     @param [in] opts         Struct containing options
     @param [in] tmodel       Transition model
     @param [in] log_priors   Vector of log-priors for pdfs
     @param [in] supervision  The supervision object, containing the numerator
                              and denominator paths. The denominator is
                              always a lattice. The numerator is an alignment.
     @param [in] nnet_output  The output of the neural net; dimension must equal
                              ((supervision.num_sequences *
                                supervision.frames_per_sequence) by
                               tmodel.NumPdfs()).

     @param [out] stats       Statistics accumulated during training such as
                              the objective function and the total weight.
     @param [out] nnet_output_deriv  The derivative of the objective function
                              w.r.t. the nnet output is written to here.
                              NOTE(review): this declaration does not show
                              whether NULL is accepted or what dimension is
                              required (presumably the same as nnet_output);
                              confirm against the implementation.
     @param [out] xent_output_deriv  If non-NULL, then the xent objective
                              derivative (which equals a posterior from the
                              numerator forward-backward, scaled by the
                              supervision weight) is written to here.  This
                              will be used in the cross-entropy regularization
                              code.
  */
  void ComputeDiscriminativeObjfAndDeriv(
      const DiscriminativeOptions &opts,
      const TransitionModel &tmodel,
      const CuVectorBase<BaseFloat> &log_priors,
      const DiscriminativeSupervision &supervision,
      const CuMatrixBase<BaseFloat> &nnet_output,
      DiscriminativeObjectiveInfo *stats,
      CuMatrixBase<BaseFloat> *nnet_output_deriv,
      CuMatrixBase<BaseFloat> *xent_output_deriv);
  
  }  // namespace discriminative
  }  // namespace kaldi
  
  #endif  // KALDI_NNET3_DISCRIMINATIVE_TRAINING_H_