Blame view
src/transform/regtree-mllr-diag-gmm.h
6.02 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
// transform/regtree-mllr-diag-gmm.h // Copyright 2009-2011 Saarland University; Jan Silovsky // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_TRANSFORM_REGTREE_MLLR_DIAG_GMM_H_ #define KALDI_TRANSFORM_REGTREE_MLLR_DIAG_GMM_H_ #include <vector> #include "base/kaldi-common.h" #include "gmm/am-diag-gmm.h" #include "transform/transform-common.h" #include "transform/regression-tree.h" #include "util/common-utils.h" namespace kaldi { /// Configuration variables for FMLLR transforms struct RegtreeMllrOptions { BaseFloat min_count; ///< Minimum occupancy for computing a transform /// If 'true', find transforms to generate using regression tree. /// If 'false', generate transforms for each baseclass. bool use_regtree; RegtreeMllrOptions(): min_count(1000.0), use_regtree(true) { } void Register(OptionsItf *opts) { opts->Register("mllr-min-count", &min_count, "Minimum count to estimate an MLLR transform."); opts->Register("mllr-use-regtree", &use_regtree, "Use a regression-class tree for MLLR."); } }; /// An MLLR mean transformation is an affine transformation of Gaussian means. class RegtreeMllrDiagGmm { public: RegtreeMllrDiagGmm() {} /// Allocates memory for transform matrix & bias vector void Init(int32 num_xforms, int32 dim); /// Initialize transform matrix to identity and bias vector to zero void SetUnit(); /// Apply the transform(s) to all the Gaussian means in the model void TransformModel(const RegressionTree ®tree, AmDiagGmm *am); /// Get all the transformed means for a given pdf. void GetTransformedMeans(const RegressionTree ®tree, const AmDiagGmm &am, int32 pdf_index, MatrixBase<BaseFloat> *out) const; void Write(std::ostream &out_stream, bool binary) const; void Read(std::istream &in_stream, bool binary); /// Mutators void SetParameters(const MatrixBase<BaseFloat> &mat, int32 regclass); void set_bclass2xforms(const std::vector<int32> &in) { bclass2xforms_ = in; } /// Accessors const std::vector< Matrix<BaseFloat> > xform_matrices() const { return xform_matrices_; } private: /// Transform matrices: size() = num_xforms_ std::vector< Matrix<BaseFloat> > xform_matrices_; int32 num_xforms_; ///< Number of transforms == xform_matrices_.size() /// For each baseclass index of which transform to use; -1 => no xform std::vector<int32> bclass2xforms_; int32 dim_; ///< Dimension of feature vectors // Cannot have copy constructor and assigment operator KALDI_DISALLOW_COPY_AND_ASSIGN(RegtreeMllrDiagGmm); }; inline void RegtreeMllrDiagGmm::SetParameters(const MatrixBase<BaseFloat> &mat, int32 regclass) { xform_matrices_[regclass].CopyFromMat(mat, kNoTrans); } /** Class for computing the maximum-likelihood estimates of the parameters of * an acoustic model that uses diagonal Gaussian mixture models as emission * densities. */ class RegtreeMllrDiagGmmAccs { public: RegtreeMllrDiagGmmAccs() {} ~RegtreeMllrDiagGmmAccs() { DeletePointers(&baseclass_stats_); } void Init(int32 num_bclass, int32 dim); void SetZero(); /// Accumulate stats for a single GMM in the model; returns log likelihood. /// This does not work with multiple feature transforms. BaseFloat AccumulateForGmm(const RegressionTree ®tree, const AmDiagGmm &am, const VectorBase<BaseFloat> &data, int32 pdf_index, BaseFloat weight); /// Accumulate stats for a single Gaussian component in the model. void AccumulateForGaussian(const RegressionTree ®tree, const AmDiagGmm &am, const VectorBase<BaseFloat> &data, int32 pdf_index, int32 gauss_index, BaseFloat weight); void Update(const RegressionTree ®tree, const RegtreeMllrOptions &opts, RegtreeMllrDiagGmm *out_mllr, BaseFloat *auxf_impr, BaseFloat *t) const; void Write(std::ostream &out_stream, bool binary) const; void Read(std::istream &in_stream, bool binary, bool add); /// Accessors int32 Dim() const { return dim_; } int32 NumBaseClasses() const { return num_baseclasses_; } const std::vector<AffineXformStats*> &baseclass_stats() const { return baseclass_stats_; } private: /// Per-baseclass stats; used for accumulation std::vector<AffineXformStats*> baseclass_stats_; int32 num_baseclasses_; ///< Number of baseclasses int32 dim_; ///< Dimension of feature vectors /// Returns the MLLR objective function for a given transform and baseclass. BaseFloat MllrObjFunction(const Matrix<BaseFloat> &xform, int32 bclass_id) const; // Cannot have copy constructor and assigment operator KALDI_DISALLOW_COPY_AND_ASSIGN(RegtreeMllrDiagGmmAccs); }; typedef TableWriter< KaldiObjectHolder<RegtreeMllrDiagGmm> > RegtreeMllrDiagGmmWriter; typedef RandomAccessTableReader< KaldiObjectHolder<RegtreeMllrDiagGmm> > RandomAccessRegtreeMllrDiagGmmReader; typedef RandomAccessTableReaderMapped< KaldiObjectHolder<RegtreeMllrDiagGmm> > RandomAccessRegtreeMllrDiagGmmReaderMapped; typedef SequentialTableReader< KaldiObjectHolder<RegtreeMllrDiagGmm> > RegtreeMllrDiagGmmSeqReader; } // namespace kaldi #endif // KALDI_TRANSFORM_REGTREE_MLLR_DIAG_GMM_H_ |