// gmm/am-diag-gmm.cc // Copyright 2012 Arnab Ghoshal Johns Hopkins University (Author: Daniel Povey) Karel Vesely // Copyright 2009-2011 Saarland University; Microsoft Corporation; // Georg Stemmer // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include #include #include using std::vector; #include "gmm/am-diag-gmm.h" #include "util/stl-utils.h" #include "tree/clusterable-classes.h" #include "tree/cluster-utils.h" namespace kaldi { AmDiagGmm::~AmDiagGmm() { DeletePointers(&densities_); } void AmDiagGmm::Init(const DiagGmm &proto, int32 num_pdfs) { if (densities_.size() != 0) { KALDI_WARN << "Init() called on a non-empty object. Contents will be " "overwritten"; DeletePointers(&densities_); } if (num_pdfs == 0) { KALDI_WARN << "Init() called with number of pdfs = 0. Will do nothing."; return; } densities_.resize(num_pdfs, NULL); for (vector::iterator itr = densities_.begin(), end = densities_.end(); itr != end; ++itr) { *itr = new DiagGmm(); (*itr)->CopyFromDiagGmm(proto); } } void AmDiagGmm::AddPdf(const DiagGmm &gmm) { if (densities_.size() != 0) // not the first gmm KALDI_ASSERT(gmm.Dim() == this->Dim()); DiagGmm *gmm_ptr = new DiagGmm(); gmm_ptr->CopyFromDiagGmm(gmm); densities_.push_back(gmm_ptr); } void AmDiagGmm::RemovePdf(int32 pdf_index) { KALDI_ASSERT(static_cast(pdf_index) < densities_.size()); delete densities_[pdf_index]; densities_.erase(densities_.begin() + pdf_index); } int32 AmDiagGmm::NumGauss() const { int32 ans = 0; for (size_t i = 0; i < densities_.size(); i++) ans += densities_[i]->NumGauss(); return ans; } void AmDiagGmm::CopyFromAmDiagGmm(const AmDiagGmm &other) { if (densities_.size() != 0) { DeletePointers(&densities_); } densities_.resize(other.NumPdfs(), NULL); for (int32 i = 0, end = densities_.size(); i < end; i++) { densities_[i] = new DiagGmm(); densities_[i]->CopyFromDiagGmm(*other.densities_[i]); } } int32 AmDiagGmm::ComputeGconsts() { int32 num_bad = 0; for (std::vector::iterator itr = densities_.begin(), end = densities_.end(); itr != end; ++itr) { num_bad += (*itr)->ComputeGconsts(); } if (num_bad > 0) KALDI_WARN << "Found " << num_bad << " Gaussian components."; return num_bad; } void AmDiagGmm::SplitByCount(const Vector &state_occs, int32 target_components, float perturb_factor, BaseFloat power, BaseFloat min_count) { int32 gauss_at_start = NumGauss(); std::vector targets; GetSplitTargets(state_occs, target_components, power, min_count, &targets); for (int32 i = 0; i < NumPdfs(); i++) { if (densities_[i]->NumGauss() < targets[i]) densities_[i]->Split(targets[i], perturb_factor); } KALDI_LOG << "Split " << NumPdfs() << " states with target = " << target_components << ", power = " << power << ", perturb_factor = " << perturb_factor << " and min_count = " << min_count << ", split #Gauss from " << gauss_at_start << " to " << NumGauss(); } void AmDiagGmm::MergeByCount(const Vector &state_occs, int32 target_components, BaseFloat power, BaseFloat min_count) { int32 gauss_at_start = NumGauss(); std::vector targets; GetSplitTargets(state_occs, target_components, power, min_count, &targets); for (int32 i = 0; i < NumPdfs(); i++) { if (targets[i] == 0) targets[i] = 1; // can't merge below 1. if (densities_[i]->NumGauss() > targets[i]) densities_[i]->Merge(targets[i]); } KALDI_LOG << "Merged " << NumPdfs() << " states with target = " << target_components << ", power = " << power << " and min_count = " << min_count << ", merged from " << gauss_at_start << " to " << NumGauss(); } void AmDiagGmm::Read(std::istream &in_stream, bool binary) { int32 num_pdfs, dim; ExpectToken(in_stream, binary, ""); ReadBasicType(in_stream, binary, &dim); ExpectToken(in_stream, binary, ""); ReadBasicType(in_stream, binary, &num_pdfs); KALDI_ASSERT(num_pdfs > 0); densities_.reserve(num_pdfs); for (int32 i = 0; i < num_pdfs; i++) { densities_.push_back(new DiagGmm()); densities_.back()->Read(in_stream, binary); KALDI_ASSERT(densities_.back()->Dim() == dim); } } void AmDiagGmm::Write(std::ostream &out_stream, bool binary) const { int32 dim = this->Dim(); if (dim == 0) { KALDI_WARN << "Trying to write empty AmDiagGmm object."; } WriteToken(out_stream, binary, ""); WriteBasicType(out_stream, binary, dim); WriteToken(out_stream, binary, ""); WriteBasicType(out_stream, binary, static_cast(densities_.size())); for (std::vector::const_iterator it = densities_.begin(), end = densities_.end(); it != end; ++it) { (*it)->Write(out_stream, binary); } } void UbmClusteringOptions::Check() { if (ubm_num_gauss > intermediate_num_gauss) KALDI_ERR << "Invalid parameters: --ubm-num_gauss=" << ubm_num_gauss << " > --intermediate-num_gauss=" << intermediate_num_gauss; if (ubm_num_gauss > max_am_gauss) KALDI_ERR << "Invalid parameters: --ubm-num_gauss=" << ubm_num_gauss << " > --max-am-gauss=" << max_am_gauss; if (ubm_num_gauss <= 0) KALDI_ERR << "Invalid parameters: --ubm-num_gauss=" << ubm_num_gauss; if (cluster_varfloor <= 0) KALDI_ERR << "Invalid parameters: --cluster-varfloor=" << cluster_varfloor; if (reduce_state_factor <= 0 || reduce_state_factor > 1) KALDI_ERR << "Invalid parameters: --reduce-state-factor=" << reduce_state_factor; } void ClusterGaussiansToUbm(const AmDiagGmm &am, const Vector &state_occs, UbmClusteringOptions opts, DiagGmm *ubm_out) { opts.Check(); // Make sure the various # of Gaussians make sense. if (am.NumGauss() > opts.max_am_gauss) { KALDI_LOG << "ClusterGaussiansToUbm: first reducing num-gauss from " << am.NumGauss() << " to " << opts.max_am_gauss; AmDiagGmm tmp_am; tmp_am.CopyFromAmDiagGmm(am); BaseFloat power = 1.0, min_count = 1.0; // Make the power 1, which I feel // is appropriate to the way we're doing the overall clustering procedure. tmp_am.MergeByCount(state_occs, opts.max_am_gauss, power, min_count); if (tmp_am.NumGauss() > opts.max_am_gauss) { KALDI_LOG << "Clustered down to " << tmp_am.NumGauss() << "; will not cluster further"; opts.max_am_gauss = tmp_am.NumGauss(); } ClusterGaussiansToUbm(tmp_am, state_occs, opts, ubm_out); return; } int32 num_pdfs = static_cast(am.NumPdfs()), dim = am.Dim(), num_clust_states = static_cast(opts.reduce_state_factor*num_pdfs); Vector tmp_mean(dim); Vector tmp_var(dim); DiagGmm tmp_gmm; vector states; states.reserve(num_pdfs); // NOT resize(); uses push_back. // Replace the GMM for each state with a single Gaussian. KALDI_VLOG(1) << "Merging densities to 1 Gaussian per state."; for (int32 pdf_index = 0; pdf_index < num_pdfs; pdf_index++) { KALDI_VLOG(3) << "Merging Gausians for state : " << pdf_index; tmp_gmm.CopyFromDiagGmm(am.GetPdf(pdf_index)); tmp_gmm.Merge(1); tmp_gmm.GetComponentMean(0, &tmp_mean); tmp_gmm.GetComponentVariance(0, &tmp_var); tmp_var.AddVec2(1.0, tmp_mean); // make it x^2 stats. // It may cause problems downstream if we add states with zero weights (see // KALDI_ASSERT(weight > 0) below), so we put in a very small floor. // These states with tiny weights will later get merged into other states. BaseFloat this_weight = 1.0e-10 + state_occs(pdf_index); tmp_mean.Scale(this_weight); tmp_var.Scale(this_weight); states.push_back(new GaussClusterable(tmp_mean, tmp_var, opts.cluster_varfloor, this_weight)); } // Bottom-up clustering of the Gaussians corresponding to each state, which // gives a partial clustering of states in the 'state_clusters' vector. vector state_clusters; KALDI_VLOG(1) << "Creating " << num_clust_states << " clusters of states."; ClusterBottomUp(states, std::numeric_limits::max(), num_clust_states, NULL /*actual clusters not needed*/, &state_clusters /*get the cluster assignments*/); DeletePointers(&states); // For each cluster of states, create a pool of all the Gaussians in those // states, weighted by the state occupancies. This is done so that initially // only the Gaussians corresponding to "similar" states (similarity as // determined by the previous clustering) are merged. vector< vector > state_clust_gauss; state_clust_gauss.resize(num_clust_states); for (int32 pdf_index = 0; pdf_index < num_pdfs; pdf_index++) { int32 current_cluster = state_clusters[pdf_index]; for (int32 num_gauss = am.GetPdf(pdf_index).NumGauss(), gauss_index = 0; gauss_index < num_gauss; ++gauss_index) { am.GetGaussianMean(pdf_index, gauss_index, &tmp_mean); am.GetGaussianVariance(pdf_index, gauss_index, &tmp_var); tmp_var.AddVec2(1.0, tmp_mean); // make it x^2 stats. // adding 1.0e-10 to the weight will prevent problems later on, see // the line KALDI_ASSERT(weight > 0.0). BaseFloat this_weight = (1.0e-10 + state_occs(pdf_index)) * (am.GetPdf(pdf_index).weights())(gauss_index); tmp_mean.Scale(this_weight); tmp_var.Scale(this_weight); state_clust_gauss[current_cluster].push_back(new GaussClusterable( tmp_mean, tmp_var, opts.cluster_varfloor, this_weight)); } } // This is an unlikely operating scenario, no need to handle this in a more // optimized fashion. if (opts.intermediate_num_gauss > am.NumGauss()) { KALDI_WARN << "Intermediate num_gauss " << opts.intermediate_num_gauss << " is more than num-gauss " << am.NumGauss() << ", reducing it to " << am.NumGauss(); opts.intermediate_num_gauss = am.NumGauss(); } // The compartmentalized clusterer used below does not merge compartments. if (opts.intermediate_num_gauss < num_clust_states) { KALDI_WARN << "Intermediate num_gauss " << opts.intermediate_num_gauss << " is less than # of preclustered states " << num_clust_states << ", increasing it to " << num_clust_states; opts.intermediate_num_gauss = num_clust_states; } KALDI_VLOG(1) << "Merging from " << am.NumGauss() << " Gaussians in the " << "acoustic model, down to " << opts.intermediate_num_gauss << " Gaussians."; vector< vector > gauss_clusters_out; ClusterBottomUpCompartmentalized(state_clust_gauss, std::numeric_limits::max(), opts.intermediate_num_gauss, &gauss_clusters_out, NULL); for (int32 clust_index = 0; clust_index < num_clust_states; clust_index++) DeletePointers(&state_clust_gauss[clust_index]); // Next, put the remaining clustered Gaussians into a single GMM. KALDI_VLOG(1) << "Putting " << opts.intermediate_num_gauss << " Gaussians " << "into a single GMM for final merge step."; Matrix tmp_means(opts.intermediate_num_gauss, dim); Matrix tmp_vars(opts.intermediate_num_gauss, dim); Vector tmp_weights(opts.intermediate_num_gauss); Vector tmp_vec(dim); int32 gauss_index = 0; for (int32 clust_index = 0; clust_index < num_clust_states; clust_index++) { for (int32 i = gauss_clusters_out[clust_index].size()-1; i >=0; --i) { GaussClusterable *this_cluster = static_cast( gauss_clusters_out[clust_index][i]); BaseFloat weight = this_cluster->count(); KALDI_ASSERT(weight > 0.0); tmp_weights(gauss_index) = weight; tmp_vec.CopyFromVec(this_cluster->x_stats()); tmp_vec.Scale(1.0 / weight); tmp_means.CopyRowFromVec(tmp_vec, gauss_index); tmp_vec.CopyFromVec(this_cluster->x2_stats()); tmp_vec.Scale(1.0 / weight); tmp_vec.AddVec2(-1.0, tmp_means.Row(gauss_index)); // x^2 stats to var. tmp_vars.CopyRowFromVec(tmp_vec, gauss_index); gauss_index++; } DeletePointers(&(gauss_clusters_out[clust_index])); } tmp_gmm.Resize(opts.intermediate_num_gauss, dim); tmp_weights.Scale(1.0/tmp_weights.Sum()); tmp_gmm.SetWeights(tmp_weights); tmp_vars.InvertElements(); // need inverse vars... tmp_gmm.SetInvVarsAndMeans(tmp_vars, tmp_means); // Finally, cluster to the desired number of Gaussians in the UBM. if (opts.ubm_num_gauss < tmp_gmm.NumGauss()) { tmp_gmm.Merge(opts.ubm_num_gauss); KALDI_VLOG(1) << "Merged down to " << tmp_gmm.NumGauss() << " Gaussians."; } else { KALDI_WARN << "Not merging Gaussians since " << opts.ubm_num_gauss << " < " << tmp_gmm.NumGauss(); } ubm_out->CopyFromDiagGmm(tmp_gmm); } } // namespace kaldi