// gmmbin/gmm-init-mono.cc // Copyright 2009-2011 Microsoft Corporation // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include "base/kaldi-common.h" #include "util/common-utils.h" #include "gmm/am-diag-gmm.h" #include "hmm/hmm-topology.h" #include "hmm/transition-model.h" #include "tree/context-dep.h" namespace kaldi { // This function reads a file like: // 1 2 3 // 4 5 // 6 7 8 // where each line is a list of integer id's of phones (that should have their pdfs shared). void ReadSharedPhonesList(std::string rxfilename, std::vector > *list_out) { list_out->clear(); Input input(rxfilename); std::istream &is = input.Stream(); std::string line; while (std::getline(is, line)) { list_out->push_back(std::vector()); if (!SplitStringToIntegers(line, " \t\r", true, &(list_out->back()))) KALDI_ERR << "Bad line in shared phones list: " << line << " (reading " << PrintableRxfilename(rxfilename) << ")"; std::sort(list_out->rbegin()->begin(), list_out->rbegin()->end()); if (!IsSortedAndUniq(*(list_out->rbegin()))) KALDI_ERR << "Bad line in shared phones list (repeated phone): " << line << " (reading " << PrintableRxfilename(rxfilename) << ")"; } } } // end namespace kaldi int main(int argc, char *argv[]) { try { using namespace kaldi; using kaldi::int32; const char *usage = "Initialize monophone GMM.\n" "Usage: gmm-init-mono \n" "e.g.: \n" " gmm-init-mono topo 39 mono.mdl mono.tree\n"; bool binary = true; std::string train_feats; std::string shared_phones_rxfilename; BaseFloat perturb_factor = 0.0; ParseOptions po(usage); po.Register("binary", &binary, "Write output in binary mode"); po.Register("train-feats", &train_feats, "rspecifier for training features [used to set mean and variance]"); po.Register("shared-phones", &shared_phones_rxfilename, "rxfilename containing, on each line, a list of phones whose pdfs should be shared."); po.Register("perturb-factor", &perturb_factor, "Perturb the means using this fraction of standard deviation."); po.Read(argc, argv); if (po.NumArgs() != 4) { po.PrintUsage(); exit(1); } std::string topo_filename = po.GetArg(1); int dim = atoi(po.GetArg(2).c_str()); KALDI_ASSERT(dim> 0 && dim < 10000); std::string model_filename = po.GetArg(3); std::string tree_filename = po.GetArg(4); Vector glob_inv_var(dim); glob_inv_var.Set(1.0); Vector glob_mean(dim); glob_mean.Set(1.0); if (train_feats != "") { double count = 0.0; Vector var_stats(dim); Vector mean_stats(dim); SequentialDoubleMatrixReader feat_reader(train_feats); for (; !feat_reader.Done(); feat_reader.Next()) { const Matrix &mat = feat_reader.Value(); for (int32 i = 0; i < mat.NumRows(); i++) { count += 1.0; var_stats.AddVec2(1.0, mat.Row(i)); mean_stats.AddVec(1.0, mat.Row(i)); } } if (count == 0) { KALDI_ERR << "no features were seen."; } var_stats.Scale(1.0/count); mean_stats.Scale(1.0/count); var_stats.AddVec2(-1.0, mean_stats); if (var_stats.Min() <= 0.0) KALDI_ERR << "bad variance"; var_stats.InvertElements(); glob_inv_var.CopyFromVec(var_stats); glob_mean.CopyFromVec(mean_stats); } HmmTopology topo; bool binary_in; Input ki(topo_filename, &binary_in); topo.Read(ki.Stream(), binary_in); const std::vector &phones = topo.GetPhones(); std::vector phone2num_pdf_classes (1+phones.back()); for (size_t i = 0; i < phones.size(); i++) phone2num_pdf_classes[phones[i]] = topo.NumPdfClasses(phones[i]); // Now the tree [not really a tree at this point]: ContextDependency *ctx_dep = NULL; if (shared_phones_rxfilename == "") { // No sharing of phones: standard approach. ctx_dep = MonophoneContextDependency(phones, phone2num_pdf_classes); } else { std::vector > shared_phones; ReadSharedPhonesList(shared_phones_rxfilename, &shared_phones); // ReadSharedPhonesList crashes on error. ctx_dep = MonophoneContextDependencyShared(shared_phones, phone2num_pdf_classes); } int32 num_pdfs = ctx_dep->NumPdfs(); AmDiagGmm am_gmm; DiagGmm gmm; gmm.Resize(1, dim); { // Initialize the gmm. Matrix inv_var(1, dim); inv_var.Row(0).CopyFromVec(glob_inv_var); Matrix mu(1, dim); mu.Row(0).CopyFromVec(glob_mean); Vector weights(1); weights.Set(1.0); gmm.SetInvVarsAndMeans(inv_var, mu); gmm.SetWeights(weights); gmm.ComputeGconsts(); } for (int i = 0; i < num_pdfs; i++) am_gmm.AddPdf(gmm); if (perturb_factor != 0.0) { for (int i = 0; i < num_pdfs; i++) am_gmm.GetPdf(i).Perturb(perturb_factor); } // Now the transition model: TransitionModel trans_model(*ctx_dep, topo); { Output ko(model_filename, binary); trans_model.Write(ko.Stream(), binary); am_gmm.Write(ko.Stream(), binary); } // Now write the tree. ctx_dep->Write(Output(tree_filename, binary).Stream(), binary); delete ctx_dep; return 0; } catch(const std::exception &e) { std::cerr << e.what(); return -1; } }