Blame view

src/sgmm2/estimate-am-sgmm2-ebw.h 10.8 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
  // sgmm2/estimate-am-sgmm2-ebw.h
  
  // Copyright 2012  Johns Hopkins University (author: Daniel Povey)
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  #ifndef KALDI_SGMM2_ESTIMATE_AM_SGMM2_EBW_H_
  #define KALDI_SGMM2_ESTIMATE_AM_SGMM2_EBW_H_ 1
  
  #include <string>
  #include <vector>
  
  #include "gmm/model-common.h"
  #include "itf/options-itf.h"
  #include "sgmm2/estimate-am-sgmm2.h"
  
  namespace kaldi {
  
  /**
     This header implements a form of Extended Baum-Welch training for SGMMs.
     If you are confused by this comment, see Dan Povey's thesis for an explanation of
     Extended Baum-Welch.
     A note on the EBW (Extended Baum-Welch) updates for the SGMMs... In general there is
     a parameter-specific value D that is similar to the D in EBW for GMMs.  The value of
     D is generally set to:
       E * (denominator-count for that parameter)   +   tau-value for that parameter
     where the tau-values are user-specified parameters that are specific to the type of
     the parameter (e.g. phonetic vector, subspace projection, etc.).  Things are a bit
     more complex for this update than for GMMs, because it's not just a question of picking
     a tau-value for smoothing: there is sometimes a scatter-matrix of some kind (e.g.
     an outer product of vectors, or something) that defines a quadratic objective function
     that we'll add as smoothing.  We have to pick where to get this scatter-matrix from.
     We feel that it's appropriate for the "E" part of the D to get its scatter-matrix from
     the denominator stats, and for the tau part of the D to get half its scatter-matrix from
     both the numerator and denominator stats, each weighted in proportion to how many
     stats it contains.  When you see the auxiliary function written out, it's clear why this
     makes sense.
  
   */
  
  struct EbwAmSgmm2Options {
    BaseFloat tau_v; ///<  Smoothing constant for updates of sub-state vectors v_{jm}
    BaseFloat lrate_v; ///< Learning rate used in updating v-- default 0.5
    BaseFloat tau_M; ///<  Smoothing constant for the M quantities (phone-subspace projections)
    BaseFloat lrate_M; ///< Learning rate used in updating M-- default 0.5
    BaseFloat tau_N; ///<  Smoothing constant for the N quantities (speaker-subspace projections)
    BaseFloat lrate_N; ///< Learning rate used in updating N-- default 0.5
    BaseFloat tau_c;  ///< Tau value for smoothing substate weights (c)
    BaseFloat tau_w;  ///< Tau value for smoothing update of phonetic-subspace weight projectsions (w)
    BaseFloat lrate_w; ///< Learning rate used in updating w-- default 1.0
    BaseFloat tau_u;  ///< Tau value for smoothing update of speaker-subspace weight projectsions (u)
    BaseFloat lrate_u; ///< Learning rate used in updating u-- default 1.0
    BaseFloat max_impr_u; ///< Maximum improvement/frame allowed for u [0.25, carried over from ML update.]
    BaseFloat tau_Sigma; ///< Tau value for smoothing covariance-matrices Sigma.
    BaseFloat lrate_Sigma; ///< Learning rate used in updating Sigma-- default 0.5
    BaseFloat min_substate_weight; ///< Minimum allowed weight in a sub-state.
    
    BaseFloat cov_min_value; ///< E.g. 0.5-- the maximum any eigenvalue of a covariance
    /// is allowed to change.  [this is the minimum; the maximum is the inverse of this,
    /// i.e. 2.0 in this case.  For example, 0.9 would constrain the covariance quite tightly,
    /// 0.1 would be a loose setting.
    
    BaseFloat max_cond; ///< large value used in SolveQuadraticProblem.
    BaseFloat epsilon;  ///< very small value used in SolveQuadraticProblem; workaround
    /// for an issue in some implementations of SVD.
    
    EbwAmSgmm2Options() {
      tau_v = 50.0;
      lrate_v = 0.5;
      tau_M = 500.0;
      lrate_M = 0.5;
      tau_N = 500.0;
      lrate_N = 0.5;
      tau_c = 10.0;
      tau_w = 50.0;
      lrate_w = 1.0;
      tau_u = 50.0;
      lrate_u = 1.0;
      max_impr_u = 0.25;
      tau_Sigma = 500.0;
      lrate_Sigma = 0.5;
  
      min_substate_weight = 1.0e-05;
      cov_min_value = 0.5;
      
      max_cond = 1.0e+05;
      epsilon = 1.0e-40;
    }
  
    void Register(OptionsItf *opts) {
      std::string module = "EbwAmSgmm2Options: ";
      opts->Register("tau-v", &tau_v, module+
                     "Smoothing constant for phone vector estimation.");
      opts->Register("lrate-v", &lrate_v, module+
                     "Learning rate constant for phone vector estimation.");
      opts->Register("tau-m", &tau_M, module+
                     "Smoothing constant for estimation of phonetic-subspace projections (M).");
      opts->Register("lrate-m", &lrate_M, module+
                     "Learning rate constant for phonetic-subspace projections.");
      opts->Register("tau-n", &tau_N, module+
                     "Smoothing constant for estimation of speaker-subspace projections (N).");
      opts->Register("lrate-n", &lrate_N, module+
                     "Learning rate constant for speaker-subspace projections.");
      opts->Register("tau-c", &tau_c, module+
                     "Smoothing constant for estimation of substate weights (c)");
      opts->Register("tau-w", &tau_w, module+
                     "Smoothing constant for estimation of phonetic-space weight projections (w)");
      opts->Register("lrate-w", &lrate_w, module+
                     "Learning rate constant for phonetic-space weight-projections (w)");
      opts->Register("tau-u", &tau_u, module+
                     "Smoothing constant for estimation of speaker-space weight projections (u)");
      opts->Register("lrate-u", &lrate_u, module+
                     "Learning rate constant for speaker-space weight-projections (u)");
      opts->Register("tau-sigma", &tau_Sigma, module+
                     "Smoothing constant for estimation of within-class covariances (Sigma)");
      opts->Register("lrate-sigma", &lrate_Sigma, module+
                     "Constant that controls speed of learning for variances (larger->slower)");
      opts->Register("cov-min-value", &cov_min_value, module+
                     "Minimum value that an eigenvalue of the updated covariance matrix can take, "
                     "relative to its old value (maximum is inverse of this.)");
      opts->Register("min-substate-weight", &min_substate_weight, module+
                     "Floor for weights of sub-states.");
      opts->Register("max-cond", &max_cond, module+
                     "Value used in handling singular matrices during update.");
      opts->Register("epsilon", &max_cond, module+
                     "Value used in handling singular matrices during update.");
    }
  };
  
  
  /** \class EbwAmSgmm2Updater
   *  Contains the functions needed to update the SGMM parameters with
   *  Extended Baum-Welch, given numerator and denominator accumulators and
   *  an EbwAmSgmm2Options configuration.
   */
  class EbwAmSgmm2Updater {
   public:
    /// Stores a copy of "options" in options_.
    explicit EbwAmSgmm2Updater(const EbwAmSgmm2Options &options):
        options_(options) {}
    
    /// Updates *model from the numerator (num_accs) and denominator
    /// (den_accs) statistics.
    /// NOTE(review): "flags" presumably selects which parameter types are
    /// updated, and *auxf_change_out / *count_out presumably receive the
    /// total auxiliary-function change and the total count -- confirm in the
    /// .cc file, including whether NULL pointers are accepted.
    void Update(const MleAmSgmm2Accs &num_accs,
                const MleAmSgmm2Accs &den_accs,
                AmSgmm2 *model,
                SgmmUpdateFlagsType flags,
                BaseFloat *auxf_change_out,
                BaseFloat *count_out);
      
   protected:
    // The following two classes relate to multi-core parallelization of some
    // phases of the update.
    friend class EbwUpdateWClass;
    friend class EbwUpdatePhoneVectorsClass;
   private:
    /// Copy of the options passed to the constructor.
    EbwAmSgmm2Options options_;
  
    Vector<double> gamma_j_;  ///< State occupancies
  
    /// Updates the phone (sub-state) vectors v_{jm} of *model.
    /// NOTE(review): the returned double is presumably the auxiliary-function
    /// improvement; confirm in the .cc file.
    double UpdatePhoneVectors(const MleAmSgmm2Accs &num_accs,
                              const MleAmSgmm2Accs &den_accs,
                              const std::vector< SpMatrix<double> > &H,
                              AmSgmm2 *model) const;
    
    // Called from UpdatePhoneVectors; updates a subset of states
    // (relates to multi-threading).
    // NOTE(review): presumably thread "thread_id" of "num_threads" handles its
    // share of the states and writes its improvement to *auxf_impr; confirm.
    void UpdatePhoneVectorsInternal(const MleAmSgmm2Accs &num_accs,
                                    const MleAmSgmm2Accs &den_accs,
                                    const std::vector<SpMatrix<double> > &H,
                                    AmSgmm2 *model,
                                    double *auxf_impr,
                                    int32 num_threads,
                                    int32 thread_id) const;
    // Called from UpdatePhoneVectorsInternal.
    // NOTE(review): judging by the names, this fills the linear term *g_jm and
    // quadratic term *H_jm for sub-state m of state j1 from "accs"; confirm.
    static void ComputePhoneVecStats(const MleAmSgmm2Accs &accs,
                                     const AmSgmm2 &model,
                                     const std::vector<SpMatrix<double> > &H,
                                     int32 j1,
                                     int32 m,
                                     const Vector<double> &w_jm,
                                     double gamma_jm,
                                     Vector<double> *g_jm,
                                     SpMatrix<double> *H_jm);
                                      
    /// Updates the phonetic-subspace projections M.  The *_num / *_den
    /// arguments presumably come from the numerator / denominator stats
    /// respectively (TODO confirm); the return value is presumably the
    /// auxiliary-function improvement.
    double UpdateM(const MleAmSgmm2Accs &num_accs,
                   const MleAmSgmm2Accs &den_accs,
                   const std::vector< SpMatrix<double> > &Q_num,
                   const std::vector< SpMatrix<double> > &Q_den,
                   const Vector<double> &gamma_num,
                   const Vector<double> &gamma_den,
                   AmSgmm2 *model) const;
    
    /// Updates the speaker-subspace projections N (return value presumably
    /// the auxiliary-function improvement; confirm).
    double UpdateN(const MleAmSgmm2Accs &num_accs,
                   const MleAmSgmm2Accs &den_accs,
                   const Vector<double> &gamma_num,
                   const Vector<double> &gamma_den,
                   AmSgmm2 *model) const;
    
    /// Updates the covariance matrices Sigma.
    /// NOTE(review): S_means is presumably a per-Gaussian scatter term derived
    /// from the means; confirm in the .cc file.
    double UpdateVars(const MleAmSgmm2Accs &num_accs,
                      const MleAmSgmm2Accs &den_accs,
                      const Vector<double> &gamma_num,
                      const Vector<double> &gamma_den,
                      const std::vector< SpMatrix<double> > &S_means,
                      AmSgmm2 *model) const;
  
    /// Updates the phonetic-subspace weight projections w.
    /// Note: in the discriminative case we do just one iteration of
    /// updating the w quantities.
    double UpdateW(const MleAmSgmm2Accs &num_accs,
                   const MleAmSgmm2Accs &den_accs,
                   const Vector<double> &gamma_num,
                   const Vector<double> &gamma_den,
                   AmSgmm2 *model);
  
  
    /// Updates the speaker-subspace weight projections u.
    double UpdateU(const MleAmSgmm2Accs &num_accs,
                   const MleAmSgmm2Accs &den_accs,
                   const Vector<double> &gamma_num,
                   const Vector<double> &gamma_den,
                   AmSgmm2 *model);
    
    /// Updates the sub-state weights c.
    double UpdateSubstateWeights(const MleAmSgmm2Accs &num_accs,
                                 const MleAmSgmm2Accs &den_accs,
                                 AmSgmm2 *model);
  
    KALDI_DISALLOW_COPY_AND_ASSIGN(EbwAmSgmm2Updater);
    EbwAmSgmm2Updater() {}  // Prevent unconfigured updater.
  };
  
  
  }  // namespace kaldi
  
  
  #endif  // KALDI_SGMM2_ESTIMATE_AM_SGMM2_EBW_H_