Blame view
src/gmm/full-gmm.h
9.39 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
// gmm/full-gmm.h // Copyright 2009-2011 Jan Silovsky; // Saarland University (Author: Arnab Ghoshal); // Microsoft Corporation // 2012 Arnab Ghoshal // 2013 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_GMM_FULL_GMM_H_ #define KALDI_GMM_FULL_GMM_H_ #include <utility> #include <vector> #include "base/kaldi-common.h" #include "gmm/model-common.h" #include "matrix/matrix-lib.h" namespace kaldi { class DiagGmm; class FullGmmNormal; // a simplified representation, see full-gmm-normal.h /// Definition for Gaussian Mixture Model with full covariances class FullGmm { /// this makes it a little easier to modify the internals friend class FullGmmNormal; public: /// Empty constructor. FullGmm() : valid_gconsts_(false) {} explicit FullGmm(const FullGmm &gmm): valid_gconsts_(false) { CopyFromFullGmm(gmm); } FullGmm(int32 nMix, int32 dim): valid_gconsts_(false) { Resize(nMix, dim); } /// Resizes arrays to this dim. Does not initialize data. void Resize(int32 nMix, int32 dim); /// Returns the number of mixture components in the GMM int32 NumGauss() const { return weights_.Dim(); } /// Returns the dimensionality of the Gaussian mean vectors int32 Dim() const { return means_invcovars_.NumCols(); } /// Copies from given FullGmm void CopyFromFullGmm(const FullGmm &fullgmm); /// Copies from given DiagGmm void CopyFromDiagGmm(const DiagGmm &diaggmm); /// Returns the log-likelihood of a data point (vector) given the GMM BaseFloat LogLikelihood(const VectorBase<BaseFloat> &data) const; /// Outputs the per-component contributions to the /// log-likelihood void LogLikelihoods(const VectorBase<BaseFloat> &data, Vector<BaseFloat> *loglikes) const; /// Outputs the per-component log-likelihoods of a subset of mixture /// components. Note: indices.size() will equal loglikes->Dim() at output. /// loglikes[i] will correspond to the log-likelihood of the Gaussian /// indexed indices[i]. void LogLikelihoodsPreselect(const VectorBase<BaseFloat> &data, const std::vector<int32> &indices, Vector<BaseFloat> *loglikes) const; /// Get gaussian selection information for one frame. Returns log-like for /// this frame. Output is the best "num_gselect" indices, sorted from best to /// worst likelihood. If "num_gselect" > NumGauss(), sets it to NumGauss(). BaseFloat GaussianSelection(const VectorBase<BaseFloat> &data, int32 num_gselect, std::vector<int32> *output) const; /// Get gaussian selection information for one frame. Returns log-like for /// this frame. Output is the best "num_gselect" indices that were /// preselected, sorted from best to worst likelihood. If "num_gselect" > /// NumGauss(), sets it to NumGauss(). BaseFloat GaussianSelectionPreselect(const VectorBase<BaseFloat> &data, const std::vector<int32> &preselect, int32 num_gselect, std::vector<int32> *output) const; /// Computes the posterior probabilities of all Gaussian components given /// a data point. Returns the log-likehood of the data given the GMM. BaseFloat ComponentPosteriors(const VectorBase<BaseFloat> &data, VectorBase<BaseFloat> *posterior) const; /// Computes the contribution log-likelihood of a data point from a single /// Gaussian component. NOTE: Currently we make no guarantees about what /// happens if one of the variances is zero. BaseFloat ComponentLogLikelihood(const VectorBase<BaseFloat> &data, int32 comp_id) const; /// Sets the gconsts. Returns the number that are "invalid" e.g. because of /// zero weights or variances. int32 ComputeGconsts(); /// Merge the components and remember the order in which the components were /// merged (flat list of pairs) void Split(int32 target_components, float perturb_factor, std::vector<int32> *history = NULL); /// Perturbs the component means with a random vector multiplied by the /// pertrub factor. void Perturb(float perturb_factor); /// Merge the components and remember the order in which the components were /// merged (flat list of pairs) void Merge(int32 target_components, std::vector<int32> *history = NULL); /// Merge the components and remember the order in which the components were /// merged (flat list of pairs); this version only considers merging /// pairs in "preselect_pairs" (or their descendants after merging). /// This is for efficiency, for large models. Returns the delta likelihood. BaseFloat MergePreselect(int32 target_components, const std::vector<std::pair<int32, int32> > &preselect_pairs); void Write(std::ostream &os, bool binary) const; void Read(std::istream &is, bool binary); /// this = rho x source + (1-rho) x this void Interpolate(BaseFloat rho, const FullGmm &source, GmmFlagsType flags = kGmmAll); /// Const accessors const Vector<BaseFloat> &gconsts() const { return gconsts_; } const Vector<BaseFloat> &weights() const { return weights_; } const Matrix<BaseFloat> &means_invcovars() const { return means_invcovars_; } const std::vector<SpMatrix<BaseFloat> > &inv_covars() const { return inv_covars_; } /// Non-const accessors Matrix<BaseFloat> &means_invcovars() { return means_invcovars_; } std::vector<SpMatrix<BaseFloat> > &inv_covars() { return inv_covars_; } /// Mutators for both float or double template<class Real> void SetWeights(const Vector<Real> &w); ///< Set mixure weights /// Use SetMeans to update only the Gaussian means (and not variances) template<class Real> void SetMeans(const Matrix<Real> &m); /// Use SetInvCovarsAndMeans if updating both means and (inverse) covariances template<class Real> void SetInvCovarsAndMeans(const std::vector<SpMatrix<Real> > &invcovars, const Matrix<Real> &means); /// Use this if setting both, in the class's native format. template<class Real> void SetInvCovarsAndMeansInvCovars(const std::vector<SpMatrix<Real> > &invcovars, const Matrix<Real> &means_invcovars); /// Set the (inverse) covariances and recompute means_invcovars_ template<class Real> void SetInvCovars(const std::vector<SpMatrix<Real> > &v); /// Accessor for covariances. template<class Real> void GetCovars(std::vector<SpMatrix<Real> > *v) const; /// Accessor for means. template<class Real> void GetMeans(Matrix<Real> *m) const; /// Accessor for covariances and means template<class Real> void GetCovarsAndMeans(std::vector< SpMatrix<Real> > *covars, Matrix<Real> *means) const; /// Mutators for single component, supports float or double /// Removes single component from model void RemoveComponent(int32 gauss, bool renorm_weights); /// Removes multiple components from model; "gauss" must not have dups. void RemoveComponents(const std::vector<int32> &gauss, bool renorm_weights); /// Accessor for component mean template<class Real> void GetComponentMean(int32 gauss, VectorBase<Real> *out) const; private: /// Equals log(weight) - 0.5 * (log det(var) + mean'*inv(var)*mean) Vector<BaseFloat> gconsts_; bool valid_gconsts_; ///< Recompute gconsts_ if false Vector<BaseFloat> weights_; ///< weights (not log). std::vector<SpMatrix<BaseFloat> > inv_covars_; ///< Inverse covariances Matrix<BaseFloat> means_invcovars_; ///< Means times inverse covariances /// Resizes arrays to this dim. Does not initialize data. void ResizeInvCovars(int32 nMix, int32 dim); // merged_components_logdet computes logdet for merged components // f1, f2 are first-order stats (normalized by zero-order stats) // s1, s2 are second-order stats (normalized by zero-order stats) BaseFloat MergedComponentsLogdet(BaseFloat w1, BaseFloat w2, const VectorBase<BaseFloat> &f1, const VectorBase<BaseFloat> &f2, const SpMatrix<BaseFloat> &s1, const SpMatrix<BaseFloat> &s2) const; const FullGmm &operator=(const FullGmm &other); // Disallow assignment. }; /// ostream operator that calls FullGmm::Write() std::ostream & operator << (std::ostream & rOut, const kaldi::FullGmm &gmm); /// istream operator that calls FullGmm::Read() std::istream & operator >> (std::istream & rIn, kaldi::FullGmm &gmm); } // End namespace kaldi #include "gmm/full-gmm-inl.h" // templated functions. #endif // KALDI_GMM_FULL_GMM_H_ |