Blame view

src/sgmm2/am-sgmm2.h 25.4 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
  // sgmm2/am-sgmm2.h
  
  // Copyright 2009-2011  Microsoft Corporation;  Lukas Burget;
  //                      Saarland University (Author: Arnab Ghoshal);
  //                      Ondrej Glembek;  Yanmin Qian;
  // Copyright 2012-2013  Johns Hopkins University (author: Daniel Povey)
  //                      Liang Lu;  Arnab Ghoshal
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  #ifndef KALDI_SGMM2_AM_SGMM2_H_
  #define KALDI_SGMM2_AM_SGMM2_H_
  
  #include <vector>
  
  #include "base/kaldi-common.h"
  #include "matrix/matrix-lib.h"
  #include "gmm/model-common.h"
  #include "gmm/diag-gmm.h"
  #include "gmm/full-gmm.h"
  #include "itf/options-itf.h"
  #include "util/table-types.h"
  #include "util/kaldi-thread.h"
  
  namespace kaldi {
  /*
    When reading this file, keep in mind two references: the paper
   "The Subspace Gaussian Mixture Model-- a Structured Model for Speech Recognition", by D. Povey,
    L. Burget et al. (Computer Speech and Language, 2011), and
    "The Symmetric Subspace Gaussian Mixture Model": Microsoft Research technical report MSR-TR-2010-138.
    We will refer to these as "the paper" [or "the CSL paper"] and "the techreport".
  
    (1) SSGMM
    
    We'll use the acronym SSGMM to refer to the Symmetric SGMM, and we'll mark in
    the code with "[SSGMM]" things that relate to it.  The technical report
    describes an extension to the originally described model where we have
    speaker-dependent mixture weights.  These are implemented here.  Note: we only
    implement the "more efficient" version of the update for the speaker
    projection vectors \u_i.  There is also an ICASSP paper that describes the
    stuff in the techreport (more briefly), with results, but we don't refer to
    any equation numbers in that.
  
    (2) SCTM
  
    What we implement here has another extension that was not in the CSL paper: an
    extension to the "state-clustered tied mixture" [SCTM] system-- a bit like BBN's
    style of system, except for SGMMs not Gaussians, at the sub-state not Gaussian level.
    We build a first
    tree, at which level the phonetic sub-state vectors are defined, and then a
    "more detailed" tree, at which level we share the sub-state mixture weights.
    In this class, NumPdfs() returns the real number of pdf's (i.e. the #leaves
    of the more detailed tree), and NumPdfGroups() returns the number of groups of
    pdf's that share the sub-state vectors.
    We use the index j2 for indexing 0...NumPdfs()-1 [as it's the "2nd level" of the tree],
    and j1 for indexing 0...NumPdfGroups()-1 [as it's the "1st level" of the tree].
    The weights are stored as c[j2][m].  There is a mapping Pdf2Group(j2) which returns
    the corresponding j1 for a given j2, and Group2PdfList(j1) which returns a vector<int32>
    consisting of the list of j2 indices for that j1. 
    
    The count quantities we store during the accumulation phase could most simply
    be stored as gamma[j2][m][i] (where m is the sub-state index), but this is
    inefficient.  Instead we store them separately as gamma1[j1][m][i] and gamma2[j2][m],
    so each count gets stored in two separate places; this makes the stats more compact.
  
    In this implementation, the normalizers n_{jmi} are now stored as n[j1][m][i],
    without including the log-weight term log c[j2][m].  In the computation of
    state likelihoods, we first compute the log-prob of the data given each of the
    sub-state vectors; and we compute the log-sum of this and the posteriors over
    each of the vectors [treating the weights as 1.0].  Call these
    "pseudo-posteriors".  Then to take into account the contribution of the
    weights in a state j2, we take the dot product of the weight-vector c[j2][...]
    with this vector of pseudo-posteriors.  The log of this dot-product gets added to the
    original log-sum.  
  */
  
  
  /// Options for sub-state splitting; each field is exposed on the command
  /// line through Register().
  struct Sgmm2SplitSubstatesConfig {
    int32 split_substates;     ///< Overall target number of sub-states.
    BaseFloat perturb_factor;  ///< Perturbation of state vectors when splitting.
    BaseFloat power;           ///< Exponent applied to sub-state occupancies.
    BaseFloat max_cond;        ///< Max condition number of the smoothing matrix.
    BaseFloat min_count;       ///< Minimum count used when allocating sub-states.

    Sgmm2SplitSubstatesConfig() {
      split_substates = 0;
      perturb_factor = 0.01;
      power = 0.2;
      max_cond = 100.0;
      min_count = 40.0;
    }

    /// Registers every option with the given options object.
    void Register(OptionsItf *opts) {
      opts->Register("split-substates", &split_substates,
                     "Increase number of "
                     "substates to this overall target.");
      opts->Register("max-cond-split", &max_cond,
                     "Max condition number of smoothing "
                     "matrix used in substate splitting.");
      opts->Register("perturb-factor", &perturb_factor,
                     "Perturbation factor for "
                     "state vectors while splitting substates.");
      opts->Register("power", &power,
                     "Exponent for substate occupancies used while "
                     "splitting substates.");
      opts->Register("min-count", &min_count,
                     "Minimum allowed count, used in allocating "
                     "sub-states to state in mixture splitting.");
    }
  };
  
  // Caution: this config is probably not used in most of the setups, we generally do the Gaussian
  // selection using separate programs
  struct Sgmm2GselectConfig {
    /// Number of highest-scoring full-covariance Gaussians per frame.
    int32 full_gmm_nbest;
    /// Number of highest-scoring diagonal-covariance Gaussians per frame.
    int32 diag_gmm_nbest;
  
    Sgmm2GselectConfig() {
      full_gmm_nbest = 15;
      diag_gmm_nbest = 50;
    }
  
    void Register(OptionsItf *opts) {
      opts->Register("full-gmm-nbest", &full_gmm_nbest, "Number of highest-scoring"
                     " full-covariance Gaussians selected per frame.");
      opts->Register("diag-gmm-nbest", &diag_gmm_nbest, "Number of highest-scoring"
                     " diagonal-covariance Gaussians selected per frame.");
    }
  };
  
  /** \struct Sgmm2PerFrameDerivedVars
   *  Per-frame precomputed SGMM quantities x(t), x_{i}(t), z_{i}(t) and
   *  n_{i}(t) (cf. Eq. (33)-(36)), plus the cached Gaussian-selection indices
   *  for the frame.
   */
  struct Sgmm2PerFrameDerivedVars {
    std::vector<int32> gselect;  ///< Cached Gaussian-selection indices.
    Vector<BaseFloat> xt;   ///< x'(t), FMLLR-adapted, dim = [D], eq.(33)
    Matrix<BaseFloat> xti;  ///< x_{i}(t) = x'(t) - o_i(s): dim = [I][D], eq.(34)
    Matrix<BaseFloat> zti;  ///< z_{i}(t), dim = [I][S], eq.(35)
    Vector<BaseFloat> nti;  ///< n_{i}(t), dim = [I], eq.(36) in CSL paper, but
                            ///< [SSGMM] with extra term log b_i^{(s)}, see eq. (24) of
                            ///< techreport.

    /// Gives xt, xti, zti and nti the requested sizes.  Only members whose
    /// current size differs are resized, so contents are not necessarily
    /// zeroed.  gselect is left untouched.
    void Resize(int32 ngauss, int32 feat_dim, int32 phn_dim) {
      if (xt.Dim() != feat_dim)
        xt.Resize(feat_dim);

      const bool xti_ok = (xti.NumRows() == ngauss && xti.NumCols() == feat_dim);
      if (!xti_ok)
        xti.Resize(ngauss, feat_dim);

      const bool zti_ok = (zti.NumRows() == ngauss && zti.NumCols() == phn_dim);
      if (!zti_ok)
        zti.Resize(ngauss, phn_dim);

      if (nti.Dim() != ngauss)
        nti.Resize(ngauss);
    }
  };
  
  class AmSgmm2;
  
  class Sgmm2PerSpkDerivedVars {
    // To set this up, call ComputePerSpkDerivedVars from the sgmm object.
   public:  
    void Clear() {
      v_s.Resize(0);
      o_s.Resize(0, 0);
      b_is.Resize(0);
      log_b_is.Resize(0);
      log_d_jms.resize(0);
    }
    bool Empty() { return v_s.Dim() == 0; }
    // caution: after SetSpeakerVector you typically want to
    // use the function AmSgmm::ComputePerSpkDerivedVars
    const Vector<BaseFloat> &GetSpeakerVector() { return v_s; }
    
    void SetSpeakerVector(const Vector<BaseFloat> &v_s_in) {
      v_s.Resize(v_s_in.Dim());
      v_s.CopyFromVec(v_s_in);
    }    
   protected:
    friend class AmSgmm2;
    friend class MleAmSgmm2Accs;
    Vector<BaseFloat> v_s;  ///< Speaker adaptation vector v_^{(s)}. Dim is [T]
    Matrix<BaseFloat> o_s;  ///< Per-speaker offsets o_{i}. Dimension is [I][D]
    Vector<BaseFloat> b_is; /// < [SSGMM]: Eq. (22) in techreport, b_i^{(s)} = \exp(\u_i^T \v^{(s)})
    Vector<BaseFloat> log_b_is; /// < [SSGMM] log of the above (more efficient to store both).
    std::vector<Vector<BaseFloat> > log_d_jms; ///< [SSGMM] normalizers per-speaker and per-substate;
                                               ///< indexed [j1][m].
  };
  
  /// Sgmm2LikelihoodCache caches SGMM likelihoods at two levels: the final
  /// pdf likelihoods, and the sub-state level likelihoods, which means
  /// that with the SCTM system we can avoid redundant computation.
  /// You need to call NextFrame() on the cache, between frames.
  struct Sgmm2LikelihoodCache {
   public:
    // you'll typically initialize with (sgmm.NumGroups(), sgmm.NumPdfs()).
    Sgmm2LikelihoodCache(int32 num_groups, int32 num_pdfs):
        substate_cache(num_groups), pdf_cache(num_pdfs), t(1) { }
    
    struct SubstateCacheElement { // indexed by j1.
      SubstateCacheElement(): t(0) { }
      // The "likes" and "remaining_log_like" quantities store the
      // log-like of the data given each substate vector, in a redundant
      // way, so the likelihood is likes(i) * exp(remaining_log_like).
      // This is to get around problems with numerical range.
      Vector<BaseFloat> likes; 
      BaseFloat remaining_log_like;
      int32 t; // used in detecting "freshness."
    };  
    struct PdfCacheElement { // indexed by j2.
      PdfCacheElement(): t(0) { }
      BaseFloat log_like;
      int32 t; // used in detecting "freshness."
    };
  
    void NextFrame(); // increments t.
    std::vector<SubstateCacheElement> substate_cache; // indexed by j1.
    std::vector<PdfCacheElement> pdf_cache; // indexed by j2.
    int32 t;
  };
  
  
  /** \class AmSgmm2
   *  Class for definition of the subspace Gmm acoustic model
   */
  class AmSgmm2 {
   public:
    AmSgmm2() {}
    void Read(std::istream &is, bool binary);
    void Write(std::ostream &os, bool binary,
               SgmmWriteFlagsType write_params) const;

    /// Checks the various components for correct sizes. With wrong sizes,
    /// assertion failure occurs. When the argument is set to true, dimensions of
    /// the various components are printed.
    void Check(bool show_properties = true);

    /// Initializes the SGMM parameters from a full-covariance UBM.
    /// The pdf2group vector maps from a pdf to the corresponding
    /// group of pdfs [i.e. j2 to j1].  For conventionally structured
    /// systems (no 2-level tree), this can just be [ 0 1 ... n-1 ].
    void InitializeFromFullGmm(const FullGmm &gmm,
                               const std::vector<int32> &pdf2group,
                               int32 phn_subspace_dim,
                               int32 spk_subspace_dim,
                               bool speaker_dependent_weights,
                               BaseFloat self_weight); // self_weight relates to
    // initialization of the weights.  if self_weight == 1.0 it means we
    // just have 1 sub-state per group, otherwise we have one per pdf,
    // and each pdf has "self_weight" as its "own" weight.

    /// Copies the global parameters from the supplied model, but sets
    /// the state vectors to zero.
    void CopyGlobalsInitVecs(const AmSgmm2 &other,
                             const std::vector<int32> &pdf2group,
                             BaseFloat self_weight);

    /// Used to copy models (useful in update)
    void CopyFromSgmm2(const AmSgmm2 &other,
                       bool copy_normalizers,
                       bool copy_weights);  // copy_weights is to copy w_{jmi} [which are
     // stored, in the symmetric SSGMM.]

    /// Computes the top-scoring Gaussian indices (used for pruning of later
    /// stages of computation). Returns frame log-likelihood given selected
    /// Gaussians from full UBM.
    BaseFloat GaussianSelection(const Sgmm2GselectConfig &config,
                                const VectorBase<BaseFloat> &data,
                                std::vector<int32> *gselect) const;

    /// This needs to be called with each new frame of data, prior to accumulation
    /// or likelihood evaluation: it computes various pre-computed quantities.
    void ComputePerFrameVars(const VectorBase<BaseFloat> &data,
                             const std::vector<int32> &gselect,
                             const Sgmm2PerSpkDerivedVars &spk_vars,
                             Sgmm2PerFrameDerivedVars *per_frame_vars) const;


    /// Computes the per-speaker derived vars; assumes vars->v_s is already
    /// set up.
    void ComputePerSpkDerivedVars(Sgmm2PerSpkDerivedVars *vars) const;

    /// This does a likelihood computation for a given state using the
    /// pre-selected Gaussian components (in per_frame_vars).  If the
    /// log_prune parameter is nonzero (e.g. 5.0), the LogSumExp() stage is
    /// pruned, which is a significant speedup... smaller values are faster.
    /// Note: you have to call cache->NextFrame() before calling this for
    /// a new frame of data.
    BaseFloat LogLikelihood(const Sgmm2PerFrameDerivedVars &per_frame_vars,
                            int32 j2, // pdf_id
                            Sgmm2LikelihoodCache *cache, // be careful to call NextFrame() when needed!
                            Sgmm2PerSpkDerivedVars *spk_vars,
                            BaseFloat log_prune = 0.0) const;

    /// Similar to LogLikelihood() function above, but also computes the posterior
    /// probabilities for the pre-selected Gaussian components and all substates.
    /// This one doesn't use caching to share computation for the groups of
    /// pdfs. [it's less necessary, as most of the time we're doing this from alignments,
    /// or lattices that are quite sparse, so we save little by sharing this.]
    BaseFloat ComponentPosteriors(const Sgmm2PerFrameDerivedVars &per_frame_vars,
                                  int32 j2,
                                  Sgmm2PerSpkDerivedVars *spk_vars,
                                  Matrix<BaseFloat> *post) const;

    /// Increases the total number of substates based on the state occupancies.
    void SplitSubstates(const Vector<BaseFloat> &state_occupancies, // [indexed by pdf-id j2]
                        const Sgmm2SplitSubstatesConfig &config);

    /// Functions for increasing the phonetic and speaker space dimensions.
    /// The argument norm_xform is a LDA-like feature normalizing transform,
    /// computed by the ComputeFeatureNormalizingTransform function.
    void IncreasePhoneSpaceDim(int32 target_dim,
                               const Matrix<BaseFloat> &norm_xform);

    /// Increase the subspace dimension for speakers.  The
    /// boolean "speaker_dependent_weights" argument (for SSGMM)
    /// only makes a difference if increasing the subspace dimension
    /// from zero.
    void IncreaseSpkSpaceDim(int32 target_dim,
                             const Matrix<BaseFloat> &norm_xform,
                             bool speaker_dependent_weights);

    /// Computes (and initializes if necessary) derived vars...
    /// for now this is just the normalizers "n" and the diagonal UBM,
    /// and if we have the "u" matrix set up, also the w_jmi_
    /// quantities.
    void ComputeDerivedVars();

    /// Computes the data-independent terms in the log-likelihood computation
    /// for each Gaussian component and all substates. Eq. (31)
    void ComputeNormalizers();

    /// Computes the weights w_jmi_, which is needed for likelihood evaluation
    /// with SSGMMs.
    void ComputeWeights();

    /// Computes the LDA-like pre-transform and its inverse as well as the
    /// eigenvalues of the scatter of the means used in FMLLR estimation.
    void ComputeFmllrPreXform(const Vector<BaseFloat> &pdf_occs,
                              Matrix<BaseFloat> *xform,
                              Matrix<BaseFloat> *inv_xform,
                              Vector<BaseFloat> *diag_mean_scatter) const;

    /// Various model dimensions.
    int32 NumPdfs() const { return static_cast<int32>(pdf2group_.size()); }
    /// Number of pdf groups (relates to SCTM); <= NumPdfs().
    int32 NumGroups() const { return static_cast<int32>(group2pdf_.size()); }
    int32 Pdf2Group(int32 j2) const; // relates to SCTM.
    int32 NumSubstatesForPdf(int32 j2) const {
      KALDI_ASSERT(j2 >= 0 && j2 < NumPdfs());
      return c_[j2].Dim();
    }
    int32 NumSubstatesForGroup(int32 j1) const {
      KALDI_ASSERT(j1 >= 0 && j1 < NumGroups());
      return v_[j1].NumRows();
    }
    int32 NumGauss() const { return static_cast<int32>(M_.size()); }
    int32 PhoneSpaceDim() const { return w_.NumCols(); }
    int32 SpkSpaceDim() const { return (N_.size() > 0) ? N_[0].NumCols() : 0; }
    /// Returns 0 for a model with no Gaussians (M_ empty) instead of indexing
    /// an empty vector; this mirrors the guard used in SpkSpaceDim().
    int32 FeatureDim() const { return M_.empty() ? 0 : M_[0].NumRows(); }

    /// True if doing SSGMM.
    bool HasSpeakerDependentWeights() const { return (u_.NumRows() != 0); }

    bool HasSpeakerSpace() const { return (!N_.empty()); }

    void RemoveSpeakerSpace() { N_.clear(); u_.Resize(0, 0); w_jmi_.clear(); }

    // [SSGMM] get the quantity d_{jm}^{(s)} and cache it with
    // spk vars if necessary.  Called in accumulation code.
    BaseFloat GetDjms(int32 j1, int32 m,
                      Sgmm2PerSpkDerivedVars *spk_vars) const;

    /// Accessors
    const FullGmm & full_ubm() const { return full_ubm_; }
    const DiagGmm & diag_ubm() const { return diag_ubm_; }


    /// Templated accessors (used to accumulate in different precision)
    template<typename Real>
    void GetInvCovars(int32 gauss_index, SpMatrix<Real> *out) const;

    template<typename Real>
    void GetSubstateMean(int32 j1, int32 m, int32 i,
                         VectorBase<Real> *mean_out) const;

    template<typename Real>
    void GetNtransSigmaInv(std::vector< Matrix<Real> > *out) const;

    template<typename Real>
    void GetSubstateSpeakerMean(int32 j1, int32 substate, int32 gauss,
                                const Sgmm2PerSpkDerivedVars &spk,
                                VectorBase<Real> *mean_out) const;

    template<typename Real>
    void GetVarScaledSubstateSpeakerMean(int32 j1, int32 substate,
                                         int32 gauss,
                                         const Sgmm2PerSpkDerivedVars &spk,
                                         VectorBase<Real> *mean_out) const;

    /// Computes quantities H = M_i Sigma_i^{-1} M_i^T.
    template<class Real>
    void ComputeH(std::vector< SpMatrix<Real> > *H_i) const;

   protected:
    /// Maps from each pdf (index j2) to the corresponding group of
    /// pdfs (index j1) for SCTM.
    std::vector<int32> pdf2group_;
    std::vector<std::vector<int32> > group2pdf_; // the reverse map.

    /// These contain the "background" model associated with the subspace GMM.
    DiagGmm diag_ubm_;
    FullGmm full_ubm_;

    /// Globally shared parameters of the subspace GMM.  The various quantities
    /// are: I = number of Gaussians, D = data dimension, S = phonetic subspace
    /// dimension, T = speaker subspace dimension, J2 = number of pdfs, J1 =
    /// number of groups of pdfs (for SCTM), #mix = number of substates [of state
    /// j2 or state-group j1, depending on context].

    /// Inverse within-class (full) covariances; dim is [I][D][D].
    std::vector< SpMatrix<BaseFloat> > SigmaInv_;
    /// Phonetic-subspace projections. Dimension is [I][D][S]
    std::vector< Matrix<BaseFloat> > M_;
    /// Speaker-subspace projections. Dimension is [I][D][T]
    std::vector< Matrix<BaseFloat> > N_;
    /// Phonetic-subspace weight projection vectors.  Dimension is [I][S]
    Matrix<BaseFloat> w_;
    /// [SSGMM] Speaker-subspace weight projection vectors. Dimension is [I][T]
    Matrix<BaseFloat> u_;

    /// The parameters in a particular SGMM state.

    /// v_{jm}, per-state phonetic-subspace vectors. Dimension is [J1][#mix][S].
    std::vector< Matrix<BaseFloat> > v_;
    /// c_{jm}, mixture weights. Dimension is [J2][#mix]
    std::vector< Vector<BaseFloat> > c_;
    /// n_{jim}, per-Gaussian normalizer. Dimension is [J1][I][#mix]
    std::vector< Matrix<BaseFloat> > n_;
    /// [SSGMM] w_{jmi}, dimension is [J1][#mix][I].  Computed from w_ and v_.
    std::vector< Matrix<BaseFloat> > w_jmi_;

    // Priors for MAP adaptation of M -- keeping them here for now but they may
    // be moved somewhere else eventually
    // These are parameters of a matrix-variate normal distribution. The means are
    // the unadapted M_i, and we have 2 separate covariance matrices for the rows
    // and columns of M.
    std::vector< Matrix<BaseFloat> > M_prior_;  // Matrix-variate Gaussian mean
    SpMatrix<BaseFloat> row_cov_inv_;
    SpMatrix<BaseFloat> col_cov_inv_;

   private:
    /// Computes quasi-occupancies gamma_i from the state-level occupancies,
    /// assuming model correctness.
    void ComputeGammaI(const Vector<BaseFloat> &state_occupancies,
                       Vector<BaseFloat> *gamma_i) const;

    /// Called inside SplitSubstates(); splits substates of one group.
    void SplitSubstatesInGroup(const Vector<BaseFloat> &pdf_occupancies,
                               const Sgmm2SplitSubstatesConfig &opts,
                               const SpMatrix<BaseFloat> &sqrt_H_sm,
                               int32 j1, int32 M);

    /// Compute a subset of normalizers; used in multi-threaded implementation.
    void ComputeNormalizersInternal(int32 num_threads, int32 thread,
                                    int32 *entropy_count, double *entropy_sum);

    /// The code below is called internally from LogLikelihood() and
    /// ComponentPosteriors().  It computes the per-Gaussian log-likelihoods
    /// given each sub-state of the state.  Note: the mixture weights
    /// are not included at this point.
    inline void ComponentLogLikes(const Sgmm2PerFrameDerivedVars &per_frame_vars,
                                  int32 j1,
                                  Sgmm2PerSpkDerivedVars *spk_vars,
                                  Matrix<BaseFloat> *loglikes) const;


    /// Initializes the matrices M_ and w_.
    void InitializeMw(int32 phn_subspace_dim,
                      const Matrix<BaseFloat> &norm_xform);
    /// Initializes the matrices N_ and [if speaker_dependent_weights==true] u_
    void InitializeNu(int32 spk_subspace_dim,
                      const Matrix<BaseFloat> &norm_xform,
                      bool speaker_dependent_weights);
    void InitializeVecsAndSubstateWeights(BaseFloat self_weight);
    void InitializeCovars();  ///< initializes the within-class covariances.

    void ComputeHsmFromModel(
        const std::vector< SpMatrix<BaseFloat> > &H,
        const Vector<BaseFloat> &state_occupancies,
        SpMatrix<BaseFloat> *H_sm,
        BaseFloat max_cond) const;

    void ComputePdfMappings(); // sets up group2pdf_ from pdf2group_.

    KALDI_DISALLOW_COPY_AND_ASSIGN(AmSgmm2);
    friend class ComputeNormalizersClass;
    friend class Sgmm2Project;
    friend class EbwAmSgmm2Updater;
    friend class MleAmSgmm2Accs;
    friend class MleAmSgmm2Updater;
    friend class MleSgmm2SpeakerAccs;
    friend class AmSgmm2Functions;  // misc functions that need access.
    friend class Sgmm2Feature;
  };
  
  /// Copies the inverse covariance Sigma_i^{-1} of Gaussian "gauss_index" into
  /// *out, converting to precision Real.  *out is resized as needed.
  template<typename Real>
  inline void AmSgmm2::GetInvCovars(int32 gauss_index,
                                    SpMatrix<Real> *out) const {
    // Bounds-check against the container actually indexed, matching the
    // argument checks done by GetSubstateMean().
    KALDI_ASSERT(out != NULL && gauss_index >= 0 &&
                 static_cast<size_t>(gauss_index) < SigmaInv_.size());
    const SpMatrix<BaseFloat> &sigma_inv = SigmaInv_[gauss_index];
    out->Resize(sigma_inv.NumRows(), kUndefined);  // fully overwritten below
    out->CopyFromSp(sigma_inv);
  }
  
  
  /// Computes the speaker-independent mean M_i v_{j1,m} of Gaussian i in
  /// sub-state m of group j1 and writes it to *mean_out (precision Real).
  /// *mean_out must already have dimension FeatureDim().
  template<typename Real>
  inline void AmSgmm2::GetSubstateMean(int32 j1, int32 m, int32 i,
                                      VectorBase<Real> *mean_out) const {
    KALDI_ASSERT(mean_out != NULL);
    // Lower bounds are checked too; j1 is validated before it is passed to
    // NumSubstatesForGroup(), which indexes v_[j1].
    KALDI_ASSERT(j1 >= 0 && j1 < NumGroups() &&
                 m >= 0 && m < NumSubstatesForGroup(j1) &&
                 i >= 0 && i < NumGauss());
    KALDI_ASSERT(mean_out->Dim() == FeatureDim());
    // Compute in BaseFloat (the model's precision), then convert to Real.
    Vector<BaseFloat> mean_tmp(FeatureDim());
    mean_tmp.AddMatVec(1.0, M_[i], kNoTrans, v_[j1].Row(m), 0.0);
    mean_out->CopyFromVec(mean_tmp);
  }
  
  
  /// Writes the speaker-adapted mean of Gaussian i in sub-state m of group j1
  /// to *mean_out: the speaker-independent mean, plus the per-speaker offset
  /// row o_s(i) when a speaker vector is present.
  template<typename Real>
  inline void AmSgmm2::GetSubstateSpeakerMean(int32 j1, int32 m, int32 i,
                                              const Sgmm2PerSpkDerivedVars &spk,
                                              VectorBase<Real> *mean_out) const {
    GetSubstateMean(j1, m, i, mean_out);
    // An empty speaker vector means no speaker adaptation: the
    // speaker-independent mean is already the answer.
    if (spk.v_s.Dim() == 0)
      return;
    mean_out->AddVec(1.0, spk.o_s.Row(i));
  }
  
  /// Writes Sigma_i^{-1} times the speaker-adapted mean of Gaussian i in
  /// sub-state m of group j1 to *mean_out.  The work is done in BaseFloat and
  /// converted to Real only at the end.
  template<typename Real>
  void AmSgmm2::GetVarScaledSubstateSpeakerMean(int32 j1, int32 m, int32 i,
                                               const Sgmm2PerSpkDerivedVars &spk,
                                               VectorBase<Real> *mean_out) const {
    const int32 dim = mean_out->Dim();
    Vector<BaseFloat> spk_mean(dim);
    GetSubstateSpeakerMean(j1, m, i, spk, &spk_mean);
    Vector<BaseFloat> scaled_mean(dim);
    scaled_mean.AddSpVec(1.0, SigmaInv_[i], spk_mean, 0.0);
    mean_out->CopyFromVec(scaled_mean);
  }
  
  
  /// Computes, from the full-covariance GMM "gmm", the inverse of an LDA
  /// transform (with no dimensionality reduction) and stores it in *xform.
  /// It is used when initializing the phonetic and speaker subspaces and
  /// when increasing their dimensions (the "norm_xform" argument of
  /// AmSgmm2::IncreasePhoneSpaceDim() / AmSgmm2::IncreaseSpkSpaceDim()).
  void ComputeFeatureNormalizingTransform(const FullGmm &gmm,
                                          Matrix<BaseFloat> *xform);
  
  
  /// Gaussian-posterior entry for one time index.
  struct Sgmm2GauPostElement {
    /// Selected Gaussian indices.  The "posteriors" matrices are expressed
    /// relative to this selection, so it must be stored alongside them.
    std::vector<int32> gselect;
    /// Transition-id of each entry in "posteriors".
    std::vector<int32> tids;
    /// Posteriors over the selected Gaussians, one matrix per transition-id.
    std::vector<Matrix<BaseFloat> > posteriors;
  };
  
  
  /// Sequence of Sgmm2GauPostElement entries, indexed by time.
  class Sgmm2GauPost: public std::vector<Sgmm2GauPostElement> {
   public:
    Sgmm2GauPost() {}
    /// Creates num_frames default-constructed entries.
    explicit Sgmm2GauPost(size_t num_frames)
        : std::vector<Sgmm2GauPostElement>(num_frames) {}
    // Standard Kaldi-style Read/Write; these are what make the type usable
    // with KaldiObjectHolder.
    void Write(std::ostream &os, bool binary) const;
    void Read(std::istream &is, bool binary);
  };
  
  // Table I/O types for Sgmm2GauPost, all built on a KaldiObjectHolder.
  typedef KaldiObjectHolder<Sgmm2GauPost> Sgmm2GauPostHolder;
  typedef TableWriter<Sgmm2GauPostHolder> Sgmm2GauPostWriter;
  typedef SequentialTableReader<Sgmm2GauPostHolder> SequentialSgmm2GauPostReader;
  typedef RandomAccessTableReader<Sgmm2GauPostHolder> RandomAccessSgmm2GauPostReader;
  
  }  // namespace kaldi
  
  
  #endif  // KALDI_SGMM2_AM_SGMM2_H_