Blame view

src/transform/regtree-mllr-diag-gmm.h 6.02 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
  // transform/regtree-mllr-diag-gmm.h
  
  // Copyright 2009-2011  Saarland University;  Jan Silovsky
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  #ifndef KALDI_TRANSFORM_REGTREE_MLLR_DIAG_GMM_H_
  #define KALDI_TRANSFORM_REGTREE_MLLR_DIAG_GMM_H_
  
  #include <vector>
  
  #include "base/kaldi-common.h"
  #include "gmm/am-diag-gmm.h"
  #include "transform/transform-common.h"
  #include "transform/regression-tree.h"
  #include "util/common-utils.h"
  
  namespace kaldi {
  
  
  /// Configuration options for regression-tree MLLR transform estimation.
  struct RegtreeMllrOptions {
    /// Minimum occupancy for computing a transform.
    BaseFloat min_count;

    /// When 'true', the transforms to generate are found using the regression
    /// tree; when 'false', transforms are generated for each baseclass.
    bool use_regtree;

    /// Defaults: min_count = 1000.0, use_regtree = true.
    RegtreeMllrOptions()
        : min_count(1000.0),
          use_regtree(true) {}

    /// Registers both options on the supplied options interface under the
    /// names "mllr-min-count" and "mllr-use-regtree".
    void Register(OptionsItf *opts) {
      opts->Register(
          "mllr-min-count", &min_count,
          "Minimum count to estimate an MLLR transform.");
      opts->Register(
          "mllr-use-regtree", &use_regtree,
          "Use a regression-class tree for MLLR.");
    }
  };
  
  /// An MLLR mean transformation is an affine transformation of Gaussian means.
  class RegtreeMllrDiagGmm {
   public:
    RegtreeMllrDiagGmm() {}
  
    /// Allocates memory for transform matrix & bias vector
    void Init(int32 num_xforms, int32 dim);
  
    /// Initialize transform matrix to identity and bias vector to zero
    void SetUnit();
  
    /// Apply the transform(s) to all the Gaussian means in the model
    void TransformModel(const RegressionTree &regtree, AmDiagGmm *am);
  
    /// Get all the transformed means for a given pdf.
    void GetTransformedMeans(const RegressionTree &regtree, const AmDiagGmm &am,
                             int32 pdf_index, MatrixBase<BaseFloat> *out) const;
  
    void Write(std::ostream &out_stream, bool binary) const;
    void Read(std::istream &in_stream, bool binary);
  
    /// Mutators
    void SetParameters(const MatrixBase<BaseFloat> &mat, int32 regclass);
    void set_bclass2xforms(const std::vector<int32> &in) { bclass2xforms_ = in; }
  
    /// Accessors
    const std::vector< Matrix<BaseFloat> > xform_matrices() const {
      return xform_matrices_;
    }
  
   private:
    /// Transform matrices: size() = num_xforms_
    std::vector< Matrix<BaseFloat> > xform_matrices_;
    int32 num_xforms_;  ///< Number of transforms == xform_matrices_.size()
    /// For each baseclass index of which transform to use; -1 => no xform
    std::vector<int32> bclass2xforms_;
    int32 dim_;  ///< Dimension of feature vectors
  
    // Cannot have copy constructor and assigment operator
    KALDI_DISALLOW_COPY_AND_ASSIGN(RegtreeMllrDiagGmm);
  };
  
  inline void RegtreeMllrDiagGmm::SetParameters(const MatrixBase<BaseFloat> &mat,
                                                int32 regclass) {
    xform_matrices_[regclass].CopyFromMat(mat, kNoTrans);
  }
  
  /** Class for computing the maximum-likelihood estimates of the parameters of
   *  an acoustic model that uses diagonal Gaussian mixture models as emission
   *  densities.
   */
  class RegtreeMllrDiagGmmAccs {
   public:
    RegtreeMllrDiagGmmAccs() {}
    ~RegtreeMllrDiagGmmAccs() { DeletePointers(&baseclass_stats_); }
  
    void Init(int32 num_bclass, int32 dim);
    void SetZero();
  
    /// Accumulate stats for a single GMM in the model; returns log likelihood.
    /// This does not work with multiple feature transforms.
    BaseFloat AccumulateForGmm(const RegressionTree &regtree,
                               const AmDiagGmm &am,
                               const VectorBase<BaseFloat> &data,
                               int32 pdf_index, BaseFloat weight);
  
    /// Accumulate stats for a single Gaussian component in the model.
    void AccumulateForGaussian(const RegressionTree &regtree,
                               const AmDiagGmm &am,
                               const VectorBase<BaseFloat> &data,
                               int32 pdf_index, int32 gauss_index,
                               BaseFloat weight);
  
    void Update(const RegressionTree &regtree, const RegtreeMllrOptions &opts,
                RegtreeMllrDiagGmm *out_mllr, BaseFloat *auxf_impr,
                BaseFloat *t) const;
  
    void Write(std::ostream &out_stream, bool binary) const;
    void Read(std::istream &in_stream, bool binary, bool add);
  
    /// Accessors
    int32 Dim() const { return dim_; }
    int32 NumBaseClasses() const { return num_baseclasses_; }
    const std::vector<AffineXformStats*> &baseclass_stats() const {
      return baseclass_stats_;
    }
  
   private:
    /// Per-baseclass stats; used for accumulation
    std::vector<AffineXformStats*> baseclass_stats_;
    int32 num_baseclasses_;    ///< Number of baseclasses
    int32 dim_;    ///< Dimension of feature vectors
  
    /// Returns the MLLR objective function for a given transform and baseclass.
    BaseFloat MllrObjFunction(const Matrix<BaseFloat> &xform,
                              int32 bclass_id) const;
  
    // Cannot have copy constructor and assigment operator
    KALDI_DISALLOW_COPY_AND_ASSIGN(RegtreeMllrDiagGmmAccs);
  };
  
  // Table (archive) I/O types for RegtreeMllrDiagGmm objects.

  /// Writes RegtreeMllrDiagGmm objects to a table.
  typedef TableWriter<
      KaldiObjectHolder<RegtreeMllrDiagGmm> > RegtreeMllrDiagGmmWriter;

  /// Reads RegtreeMllrDiagGmm objects from a table sequentially.
  typedef SequentialTableReader<
      KaldiObjectHolder<RegtreeMllrDiagGmm> > RegtreeMllrDiagGmmSeqReader;

  /// Reads RegtreeMllrDiagGmm objects from a table by key.
  typedef RandomAccessTableReader<
      KaldiObjectHolder<RegtreeMllrDiagGmm> >
      RandomAccessRegtreeMllrDiagGmmReader;

  /// Mapped variant of the random-access reader.
  typedef RandomAccessTableReaderMapped<
      KaldiObjectHolder<RegtreeMllrDiagGmm> >
      RandomAccessRegtreeMllrDiagGmmReaderMapped;
  
  }  // namespace kaldi
  
  #endif  // KALDI_TRANSFORM_REGTREE_MLLR_DIAG_GMM_H_