// transform/regtree-mllr-diag-gmm.cc // Copyright 2009-2011 Saarland University; Jan Silovsky // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include using std::pair; #include using std::vector; #include "transform/regtree-mllr-diag-gmm.h" namespace kaldi { void RegtreeMllrDiagGmm::Init(int32 num_xforms, int32 dim) { if (num_xforms == 0) { // empty transform xform_matrices_.clear(); dim_ = 0; // non-zero dimension is meaningless with empty transform num_xforms_ = 0; bclass2xforms_.clear(); } else { KALDI_ASSERT(dim != 0); // if not empty, dim = 0 is meaningless dim_ = dim; num_xforms_ = num_xforms; xform_matrices_.resize(num_xforms); vector< Matrix >::iterator xform_itr = xform_matrices_.begin(), xform_itr_end = xform_matrices_.end(); for (; xform_itr != xform_itr_end; ++xform_itr) { xform_itr->Resize(dim, dim+1); xform_itr->SetUnit(); } } } void RegtreeMllrDiagGmm::SetUnit() { vector< Matrix >::iterator xform_itr = xform_matrices_.begin(), xform_itr_end = xform_matrices_.end(); for (; xform_itr != xform_itr_end; ++xform_itr) { xform_itr->SetUnit(); } } void RegtreeMllrDiagGmm::TransformModel(const RegressionTree ®tree, AmDiagGmm *am) { KALDI_ASSERT(static_cast(bclass2xforms_.size()) == regtree.NumBaseclasses()); Vector extended_mean(dim_+1), xformed_mean(dim_); for (int32 bclass_index = 0, num_bclasses = regtree.NumBaseclasses(); bclass_index < num_bclasses; ++bclass_index) { int32 xform_index; if ((xform_index = bclass2xforms_[bclass_index]) > -1) { KALDI_ASSERT(xform_index < num_xforms_); const vector< pair > &bclass = regtree.GetBaseclass(bclass_index); for (vector< pair >::const_iterator itr = bclass.begin(), end = bclass.end(); itr != end; ++itr) { SubVector tmp_mean(extended_mean.Range(0, dim_)); am->GetGaussianMean(itr->first, itr->second, &tmp_mean); extended_mean(dim_) = 1.0; xformed_mean.AddMatVec(1.0, xform_matrices_[xform_index], kNoTrans, extended_mean, 0.0); am->SetGaussianMean(itr->first, itr->second, xformed_mean); } // end iterating over Gaussians in baseclass } // else keep the means untransformed } // end iterating over all baseclasses am->ComputeGconsts(); } void RegtreeMllrDiagGmm::GetTransformedMeans(const RegressionTree ®tree, const AmDiagGmm &am, int32 pdf_index, MatrixBase *out) const { KALDI_ASSERT(static_cast(bclass2xforms_.size()) == regtree.NumBaseclasses()); int32 num_gauss = am.GetPdf(pdf_index).NumGauss(); KALDI_ASSERT(out->NumRows() == num_gauss && out->NumCols() == dim_); Vector extended_mean(dim_+1); extended_mean(dim_) = 1.0; for (int32 gauss_index = 0; gauss_index < num_gauss; gauss_index++) { int32 bclass_index = regtree.Gauss2BaseclassId(pdf_index, gauss_index); int32 xform_index = bclass2xforms_[bclass_index]; if (xform_index > -1) { // use a transform KALDI_ASSERT(xform_index < num_xforms_); SubVector tmp_mean(extended_mean.Range(0, dim_)); am.GetGaussianMean(pdf_index, gauss_index, &tmp_mean); SubVector out_row(out->Row(gauss_index)); out_row.AddMatVec(1.0, xform_matrices_[xform_index], kNoTrans, extended_mean, 0.0); } else { // Copy untransformed mean SubVector out_row(out->Row(gauss_index)); am.GetGaussianMean(pdf_index, gauss_index, &out_row); } } } void RegtreeMllrDiagGmm::Write(std::ostream &out, bool binary) const { WriteToken(out, binary, ""); WriteToken(out, binary, ""); WriteBasicType(out, binary, num_xforms_); WriteToken(out, binary, ""); WriteBasicType(out, binary, dim_); vector< Matrix >::const_iterator xform_itr = xform_matrices_.begin(), xform_itr_end = xform_matrices_.end(); for (; xform_itr != xform_itr_end; ++xform_itr) { WriteToken(out, binary, ""); xform_itr->Write(out, binary); } WriteToken(out, binary, ""); WriteIntegerVector(out, binary, bclass2xforms_); WriteToken(out, binary, ""); } void RegtreeMllrDiagGmm::Read(std::istream &in, bool binary) { ExpectToken(in, binary, ""); ExpectToken(in, binary, ""); ReadBasicType(in, binary, &num_xforms_); ExpectToken(in, binary, ""); ReadBasicType(in, binary, &dim_); KALDI_ASSERT(num_xforms_ >= 0 && dim_ >= 0); // can be 0 for empty xform xform_matrices_.resize(num_xforms_); vector< Matrix >::iterator xform_itr = xform_matrices_.begin(), xform_itr_end = xform_matrices_.end(); for (; xform_itr != xform_itr_end; ++xform_itr) { ExpectToken(in, binary, ""); xform_itr->Read(in, binary); KALDI_ASSERT(xform_itr->NumRows() == (xform_itr->NumCols() - 1) && xform_itr->NumRows() == dim_); } ExpectToken(in, binary, ""); ReadIntegerVector(in, binary, &bclass2xforms_); ExpectToken(in, binary, ""); } // ************************************************************************ void RegtreeMllrDiagGmmAccs::Init(int32 num_bclass, int32 dim) { if (num_bclass == 0) { // empty stats DeletePointers(&baseclass_stats_); baseclass_stats_.clear(); num_baseclasses_ = 0; dim_ = 0; // non-zero dimension is meaningless in empty stats } else { KALDI_ASSERT(dim != 0); // if not empty, dim = 0 is meaningless num_baseclasses_ = num_bclass; dim_ = dim; baseclass_stats_.resize(num_baseclasses_); for (vector::iterator it = baseclass_stats_.begin(), end = baseclass_stats_.end(); it != end; ++it) { *it = new AffineXformStats(); (*it)->Init(dim_, dim_); } } } void RegtreeMllrDiagGmmAccs::SetZero() { for (vector::iterator it = baseclass_stats_.begin(), end = baseclass_stats_.end(); it != end; ++it) { (*it)->SetZero(); } } BaseFloat RegtreeMllrDiagGmmAccs::AccumulateForGmm( const RegressionTree ®tree, const AmDiagGmm &am, const VectorBase &data, int32 pdf_index, BaseFloat weight) { const DiagGmm &pdf = am.GetPdf(pdf_index); int32 num_comp = static_cast(pdf.NumGauss()); Vector posterior(num_comp); BaseFloat loglike = pdf.ComponentPosteriors(data, &posterior); posterior.Scale(weight); Vector posterior_d(posterior); Vector data_d(data); Vector inv_var_x(dim_); Vector extended_mean(dim_+1); SpMatrix mean_scatter(dim_+1); for (int32 m = 0; m < num_comp; m++) { unsigned bclass = regtree.Gauss2BaseclassId(pdf_index, m); inv_var_x.CopyFromVec(pdf.inv_vars().Row(m)); inv_var_x.MulElements(data_d); // Using SubVector to stop compiler warning SubVector tmp_mean(extended_mean, 0, dim_); pdf.GetComponentMean(m, &tmp_mean); // modifies extended_mean extended_mean(dim_) = 1.0; mean_scatter.SetZero(); mean_scatter.AddVec2(1.0, extended_mean); baseclass_stats_[bclass]->beta_ += posterior_d(m); baseclass_stats_[bclass]->K_.AddVecVec(posterior_d(m), inv_var_x, extended_mean); vector< SpMatrix > &G = baseclass_stats_[bclass]->G_; for (int32 d = 0; d < dim_; d++) G[d].AddSp((posterior_d(m) * pdf.inv_vars()(m, d)), mean_scatter); } return loglike; } void RegtreeMllrDiagGmmAccs::AccumulateForGaussian( const RegressionTree ®tree, const AmDiagGmm &am, const VectorBase &data, int32 pdf_index, int32 gauss_index, BaseFloat weight) { const DiagGmm &pdf = am.GetPdf(pdf_index); Vector data_d(data); Vector inv_var_x(dim_); Vector extended_mean(dim_+1); double weight_d = static_cast(weight); unsigned bclass = regtree.Gauss2BaseclassId(pdf_index, gauss_index); inv_var_x.CopyFromVec(pdf.inv_vars().Row(gauss_index)); inv_var_x.MulElements(data_d); // Using SubVector to stop compiler warning SubVector tmp_mean(extended_mean, 0, dim_); pdf.GetComponentMean(gauss_index, &tmp_mean); // modifies extended_mean extended_mean(dim_) = 1.0; SpMatrix mean_scatter(dim_+1); mean_scatter.AddVec2(1.0, extended_mean); baseclass_stats_[bclass]->beta_ += weight_d; baseclass_stats_[bclass]->K_.AddVecVec(weight_d, inv_var_x, extended_mean); vector< SpMatrix > &G = baseclass_stats_[bclass]->G_; for (int32 d = 0; d < dim_; d++) G[d].AddSp((weight_d * pdf.inv_vars()(gauss_index, d)), mean_scatter); } void RegtreeMllrDiagGmmAccs::Write(std::ostream &out, bool binary) const { WriteToken(out, binary, ""); WriteToken(out, binary, ""); WriteBasicType(out, binary, num_baseclasses_); WriteToken(out, binary, ""); WriteBasicType(out, binary, dim_); WriteToken(out, binary, ""); vector::const_iterator itr = baseclass_stats_.begin(), end = baseclass_stats_.end(); for ( ; itr != end; ++itr) (*itr)->Write(out, binary); WriteToken(out, binary, ""); } void RegtreeMllrDiagGmmAccs::Read(std::istream &in, bool binary, bool add) { ExpectToken(in, binary, ""); ExpectToken(in, binary, ""); ReadBasicType(in, binary, &num_baseclasses_); ExpectToken(in, binary, ""); ReadBasicType(in, binary, &dim_); KALDI_ASSERT(num_baseclasses_ > 0 && dim_ > 0); baseclass_stats_.resize(num_baseclasses_); ExpectToken(in, binary, ""); vector::iterator itr = baseclass_stats_.begin(), end = baseclass_stats_.end(); for ( ; itr != end; ++itr) { *itr = new AffineXformStats(); (*itr)->Init(dim_, dim_); (*itr)->Read(in, binary, add); } ExpectToken(in, binary, ""); } static void ComputeMllrMatrix(const Matrix &K, const vector< SpMatrix > &G, Matrix *out) { int32 dim = G.size(); Matrix tmp_out(dim, dim+1); for (int32 d = 0; d < dim; d++) { if (G[d].Cond() > 1.0e+9) { KALDI_WARN << "Dim " << d << ": Badly conditioned stats. Setting MLLR " << "transform to unit."; tmp_out.SetUnit(); break; } SpMatrix inv_g(G[d]); // KALDI_LOG << "Dim " << d << ": G: max = " << inv_g.Max() << ", min = " // << inv_g.Min() << ", log det = " << inv_g.LogDet(NULL) // << ", cond = " << inv_g.Cond(); inv_g.Invert(); // KALDI_LOG << "Inv G: max = " << inv_g.Max() << ", min = " << inv_g.Min() // << ", log det = " << inv_g.LogDet(NULL) << ", cond = " // << inv_g.Cond(); tmp_out.Row(d).AddSpVec(1.0, inv_g, K.Row(d), 0.0); } out->CopyFromMat(tmp_out, kNoTrans); } static BaseFloat MllrAuxFunction(const Matrix &xform, const AffineXformStats &stats) { int32 dim = stats.G_.size(); Matrix xform_d(xform); Vector xform_row_g(dim + 1); SubMatrix A(xform_d, 0, dim, 0, dim); double obj = TraceMatMat(xform_d, stats.K_, kTrans); for (int32 d = 0; d < dim; d++) { xform_row_g.AddSpVec(1.0, stats.G_[d], xform_d.Row(d), 0.0); obj -= 0.5 * VecVec(xform_row_g, xform_d.Row(d)); } return obj; } void RegtreeMllrDiagGmmAccs::Update(const RegressionTree ®tree, const RegtreeMllrOptions &opts, RegtreeMllrDiagGmm *out_mllr, BaseFloat *auxf_impr, BaseFloat *t) const { BaseFloat tot_auxf_impr = 0, tot_t = 0; Matrix xform_mat(dim_, dim_ + 1); if (opts.use_regtree) { // estimate transforms using a regression tree vector regclass_stats; vector base2regclass; bool update_xforms = regtree.GatherStats(baseclass_stats_, opts.min_count, &base2regclass, ®class_stats); out_mllr->set_bclass2xforms(base2regclass); // If update_xforms == true, none should be negative, else all should be -1 if (update_xforms) { out_mllr->Init(regclass_stats.size(), dim_); for (int32 rclass_index = 0, num_rclass = regclass_stats.size(); rclass_index < num_rclass; ++rclass_index) { KALDI_ASSERT(regclass_stats[rclass_index]->beta_ >= opts.min_count); xform_mat.SetUnit(); BaseFloat obj_old = MllrAuxFunction(xform_mat, *(regclass_stats[rclass_index])); ComputeMllrMatrix(regclass_stats[rclass_index]->K_, regclass_stats[rclass_index]->G_, &xform_mat); out_mllr->SetParameters(xform_mat, rclass_index); BaseFloat obj_new = MllrAuxFunction(xform_mat, *(regclass_stats[rclass_index])); KALDI_LOG << "MLLR: regclass " << (rclass_index) << ": Objective function impr per frame is " << ((obj_new - obj_old)/regclass_stats[rclass_index]->beta_) << " over " << regclass_stats[rclass_index]->beta_ << " frames."; KALDI_ASSERT(obj_new >= obj_old - (std::abs(obj_new)+std::abs(obj_old))*1.0e-05); tot_t += regclass_stats[rclass_index]->beta_; tot_auxf_impr += obj_new - obj_old; } } else { out_mllr->Init(1, dim_); // Use a unit transform at the root. } DeletePointers(®class_stats); // end of estimation using regression tree } else { // estimate 1 transform per baseclass (if enough count) out_mllr->Init(num_baseclasses_, dim_); vector base2xforms(num_baseclasses_, -1); for (int32 bclass_index = 0; bclass_index < num_baseclasses_; ++bclass_index) { if (baseclass_stats_[bclass_index]->beta_ > opts.min_count) { base2xforms[bclass_index] = bclass_index; xform_mat.SetUnit(); BaseFloat obj_old = MllrAuxFunction(xform_mat, *(baseclass_stats_[bclass_index])); ComputeMllrMatrix(baseclass_stats_[bclass_index]->K_, baseclass_stats_[bclass_index]->G_, &xform_mat); out_mllr->SetParameters(xform_mat, bclass_index); BaseFloat obj_new = MllrAuxFunction(xform_mat, *(baseclass_stats_[bclass_index])); KALDI_LOG << "MLLR: base-class " << (bclass_index) << ": Auxiliary function impr per frame is " << ((obj_new-obj_old)/baseclass_stats_[bclass_index]->beta_); KALDI_ASSERT(obj_new >= obj_old - (std::abs(obj_new)+std::abs(obj_old))*1.0e-05); tot_t += baseclass_stats_[bclass_index]->beta_; tot_auxf_impr += obj_new - obj_old; } else { KALDI_WARN << "For baseclass " << (bclass_index) << " count = " << (baseclass_stats_[bclass_index]->beta_) << " < " << opts.min_count << ": not updating MLLR"; tot_t += baseclass_stats_[bclass_index]->beta_; } } // end looping over all baseclasses out_mllr->set_bclass2xforms(base2xforms); } // end of estimating one transform per baseclass if (auxf_impr != NULL) *auxf_impr = tot_auxf_impr; if (t != NULL) *t = tot_t; } } // namespace kaldi