// tree/context-dep.cc // Copyright 2009-2011 Microsoft Corporation // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include "tree/context-dep.h" #include "base/kaldi-math.h" #include "tree/build-tree.h" namespace kaldi { bool ContextDependency::Compute(const std::vector &phoneseq, int32 pdf_class, int32 *pdf_id) const { KALDI_ASSERT(static_cast(phoneseq.size()) == N_); EventType event_vec; event_vec.reserve(N_+1); event_vec.push_back(std::make_pair (static_cast(kPdfClass), // -1 static_cast(pdf_class))); KALDI_COMPILE_TIME_ASSERT(kPdfClass < 0); // or it would not be sorted. for (int32 i = 0;i < N_;i++) { event_vec.push_back(std::make_pair (static_cast(i), static_cast(phoneseq[i]))); KALDI_ASSERT(static_cast(phoneseq[i]) >= 0); } KALDI_ASSERT(pdf_id != NULL); return to_pdf_->Map(event_vec, pdf_id); } ContextDependency *GenRandContextDependency(const std::vector &phone_ids, bool ensure_all_covered, std::vector *hmm_lengths) { KALDI_ASSERT(IsSortedAndUniq(phone_ids)); int32 num_phones = phone_ids.size(); int32 num_stats = 1 + (Rand() % 15) * (Rand() % 15); // up to 14^2 + 1 separate stats. int32 N = 2 + Rand() % 3; // 2, 3 or 4. int32 P = Rand() % N; float ctx_dep_prob = 0.7 + 0.3*RandUniform(); int32 max_phone = *std::max_element(phone_ids.begin(), phone_ids.end()); hmm_lengths->clear(); hmm_lengths->resize(max_phone + 1, -1); std::vector is_ctx_dep(max_phone + 1); for (int32 i = 0; i <= max_phone; i++) { (*hmm_lengths)[i] = 1 + Rand() % 3; is_ctx_dep[i] = (RandUniform() < ctx_dep_prob); // true w.p. ctx_dep_prob. } for (size_t i = 0; i < (size_t) num_phones; i++) KALDI_VLOG(2) << "For idx = " << i << ", (phone_id, hmm_length, is_ctx_dep) == " << (phone_ids[i]) << " " << ((*hmm_lengths)[phone_ids[i]]) << " " << (is_ctx_dep[phone_ids[i]]); // Generate rand stats. BuildTreeStatsType stats; size_t dim = 3 + Rand() % 20; GenRandStats(dim, num_stats, N, P, phone_ids, *hmm_lengths, is_ctx_dep, ensure_all_covered, &stats); // Now build the tree. Questions qopts; int32 num_quest = Rand() % 10, num_iters = rand () % 5; qopts.InitRand(stats, num_quest, num_iters, kAllKeysUnion); // This was tested in build-tree-utils-test.cc float thresh = 100.0 * RandUniform(); EventMap *tree = NULL; std::vector > phone_sets(phone_ids.size()); for (size_t i = 0; i < phone_ids.size(); i++) phone_sets[i].push_back(phone_ids[i]); std::vector share_roots(phone_sets.size(), true), do_split(phone_sets.size(), true); tree = BuildTree(qopts, phone_sets, *hmm_lengths, share_roots, do_split, stats, thresh, 1000, 0.0, P); DeleteBuildTreeStats(&stats); return new ContextDependency(N, P, tree); } ContextDependency *GenRandContextDependencyLarge(const std::vector &phone_ids, int N, int P, bool ensure_all_covered, std::vector *hmm_lengths) { KALDI_ASSERT(IsSortedAndUniq(phone_ids)); int32 num_phones = phone_ids.size(); int32 num_stats = 3000; // each is a separate context. float ctx_dep_prob = 0.9; KALDI_ASSERT(num_phones > 0); hmm_lengths->clear(); int32 max_phone = *std::max_element(phone_ids.begin(), phone_ids.end()); hmm_lengths->resize(max_phone + 1, -1); std::vector is_ctx_dep(max_phone + 1); for (int32 i = 0; i <= max_phone; i++) { (*hmm_lengths)[i] = 1 + Rand() % 3; is_ctx_dep[i] = (RandUniform() < ctx_dep_prob); // true w.p. ctx_dep_prob. } for (size_t i = 0; i < (size_t) num_phones; i++) { KALDI_VLOG(2) << "For idx = "<< i << ", (phone_id, hmm_length, is_ctx_dep) == " << (phone_ids[i]) << " " << ((*hmm_lengths)[phone_ids[i]]) << " " << (is_ctx_dep[phone_ids[i]]); } // Generate rand stats. BuildTreeStatsType stats; size_t dim = 3 + Rand() % 20; GenRandStats(dim, num_stats, N, P, phone_ids, *hmm_lengths, is_ctx_dep, ensure_all_covered, &stats); // Now build the tree. Questions qopts; int32 num_quest = 40, num_iters = 0; qopts.InitRand(stats, num_quest, num_iters, kAllKeysUnion); // This was tested in build-tree-utils-test.cc float thresh = 100.0 * RandUniform(); EventMap *tree = NULL; std::vector > phone_sets(phone_ids.size()); for (size_t i = 0; i < phone_ids.size(); i++) phone_sets[i].push_back(phone_ids[i]); std::vector share_roots(phone_sets.size(), true), do_split(phone_sets.size(), true); tree = BuildTree(qopts, phone_sets, *hmm_lengths, share_roots, do_split, stats, thresh, 1000, 0.0, P); DeleteBuildTreeStats(&stats); return new ContextDependency(N, P, tree); } void ContextDependency::Write (std::ostream &os, bool binary) const { WriteToken(os, binary, "ContextDependency"); WriteBasicType(os, binary, N_); WriteBasicType(os, binary, P_); WriteToken(os, binary, "ToPdf"); to_pdf_->Write(os, binary); WriteToken(os, binary, "EndContextDependency"); } void ContextDependency::Read (std::istream &is, bool binary) { if (to_pdf_) { delete to_pdf_; to_pdf_ = NULL; } ExpectToken(is, binary, "ContextDependency"); ReadBasicType(is, binary, &N_); ReadBasicType(is, binary, &P_); EventMap *to_pdf = NULL; std::string token; ReadToken(is, binary, &token); if (token == "ToLength") { // back-compat. EventMap *to_num_pdf_classes = EventMap::Read(is, binary); delete to_num_pdf_classes; ReadToken(is, binary, &token); } if (token == "ToPdf") { to_pdf = EventMap::Read(is , binary); } else { KALDI_ERR << "Got unexpected token " << token << " reading context-dependency object."; } ExpectToken(is, binary, "EndContextDependency"); to_pdf_ = to_pdf; } void ContextDependency::EnumeratePairs( const std::vector &phones, int32 self_loop_pdf_class, int32 forward_pdf_class, const std::vector &phone_window, unordered_set, PairHasher > *pairs) const { std::vector new_phone_window(phone_window); EventType vec; std::vector forward_pdfs, self_loop_pdfs; // get list of possible forward pdfs vec.clear(); for (size_t i = 0; i < N_; i++) if (phone_window[i] >= 0) vec.push_back(std::make_pair(static_cast(i), static_cast(phone_window[i]))); vec.push_back(std::make_pair(kPdfClass, static_cast(forward_pdf_class))); std::sort(vec.begin(), vec.end()); to_pdf_->MultiMap(vec, &forward_pdfs); SortAndUniq(&forward_pdfs); // get list of possible self-loop pdfs vec.clear(); for (size_t i = 0; i < N_; i++) if (phone_window[i] >= 0) vec.push_back(std::make_pair(static_cast(i), static_cast(phone_window[i]))); vec.push_back(std::make_pair(kPdfClass, static_cast(self_loop_pdf_class))); std::sort(vec.begin(), vec.end()); to_pdf_->MultiMap(vec, &self_loop_pdfs); SortAndUniq(&self_loop_pdfs); if (forward_pdfs.size() == 1 || self_loop_pdfs.size() == 1) { for (size_t m = 0; m < forward_pdfs.size(); m++) for (size_t n = 0; n < self_loop_pdfs.size(); n++) pairs->insert(std::make_pair(forward_pdfs[m], self_loop_pdfs[n])); } else { // Choose 'position' as a phone position in 'context' that's currently // -1, and that is as close as possible to the central position P. int32 position = 0; int32 min_dist = N_ - 1; for (int32 i = 0; i < N_; i++) { int32 dist = (P_ - i > 0) ? (P_ - i) : (i - P_); if (phone_window[i] == -1 && dist < min_dist) { position = i; min_dist = dist; } } KALDI_ASSERT(min_dist < N_); KALDI_ASSERT(position != P_); // The next two lines have to do with how BOS/EOS effects are handled in // phone context. Zero phone value in a non-central position (i.e. not // position P_... and 'position' will never equal P_) means 'there is no // phone here because we're at BOS or EOS'. new_phone_window[position] = 0; EnumeratePairs(phones, self_loop_pdf_class, forward_pdf_class, new_phone_window, pairs); for (size_t i = 0 ; i < phones.size(); i++) { new_phone_window[position] = phones[i]; EnumeratePairs(phones, self_loop_pdf_class, forward_pdf_class, new_phone_window, pairs); } } } void ContextDependency::GetPdfInfo( const std::vector &phones, const std::vector > > &pdf_class_pairs, std::vector > > > *pdf_info) const { KALDI_ASSERT(pdf_info != NULL); pdf_info->resize(1 + *std::max_element(phones.begin(), phones.end())); std::vector phone_window(N_, -1); EventType vec; for (size_t i = 0 ; i < phones.size(); i++) { // loop over phones int32 phone = phones[i]; (*pdf_info)[phone].resize(pdf_class_pairs[phone].size()); for (size_t j = 0; j < pdf_class_pairs[phone].size(); j++) { // loop over pdf_class pairs int32 pdf_class = pdf_class_pairs[phone][j].first, self_loop_pdf_class = pdf_class_pairs[phone][j].second; phone_window[P_] = phone; unordered_set, PairHasher > pairs; EnumeratePairs(phones, self_loop_pdf_class, pdf_class, phone_window, &pairs); unordered_set, PairHasher >::iterator iter = pairs.begin(), end = pairs.end(); for (; iter != end; ++iter) (*pdf_info)[phone][j].push_back(*iter); std::sort( ((*pdf_info)[phone][j]).begin(), ((*pdf_info)[phone][j]).end()); } } } void ContextDependency::GetPdfInfo( const std::vector &phones, const std::vector &num_pdf_classes, // indexed by phone, std::vector > > *pdf_info) const { EventType vec; KALDI_ASSERT(pdf_info != NULL); pdf_info->resize(NumPdfs()); for (size_t i = 0 ; i < phones.size(); i++) { int32 phone = phones[i]; vec.clear(); vec.push_back(std::make_pair(static_cast(P_), static_cast(phone))); // Now get length. KALDI_ASSERT(static_cast(phone) < num_pdf_classes.size()); EventAnswerType len = num_pdf_classes[phone]; for (int32 pos = 0; pos < len; pos++) { vec.resize(2); vec[0] = std::make_pair(static_cast(P_), static_cast(phone)); vec[1] = std::make_pair(kPdfClass, static_cast(pos)); std::sort(vec.begin(), vec.end()); std::vector pdfs; // pdfs that can be at this pos as this phone. to_pdf_->MultiMap(vec, &pdfs); SortAndUniq(&pdfs); if (pdfs.empty()) { KALDI_WARN << "ContextDependency::GetPdfInfo, no pdfs returned for position "<< pos << " of phone " << phone << ". Continuing but this is a serious error."; } for (size_t j = 0; j < pdfs.size(); j++) { KALDI_ASSERT(static_cast(pdfs[j]) < pdf_info->size()); (*pdf_info)[pdfs[j]].push_back(std::make_pair(phone, pos)); } } } for (size_t i = 0; i < pdf_info->size(); i++) { std::sort( ((*pdf_info)[i]).begin(), ((*pdf_info)[i]).end()); KALDI_ASSERT(IsSortedAndUniq( ((*pdf_info)[i]))); // should have no dups. } } ContextDependency* MonophoneContextDependency(const std::vector &phones, const std::vector &phone2num_pdf_classes) { std::vector > phone_sets(phones.size()); for (size_t i = 0; i < phones.size(); i++) phone_sets[i].push_back(phones[i]); std::vector share_roots(phones.size(), false); // don't share roots. // N is context size, P = position of central phone (must be 0). int32 num_leaves = 0, P = 0, N = 1; EventMap *pdf_map = GetStubMap(P, phone_sets, phone2num_pdf_classes, share_roots, &num_leaves); return new ContextDependency(N, P, pdf_map); } ContextDependency* MonophoneContextDependencyShared(const std::vector > &phone_sets, const std::vector &phone2num_pdf_classes) { std::vector share_roots(phone_sets.size(), false); // don't share roots. // N is context size, P = position of central phone (must be 0). int32 num_leaves = 0, P = 0, N = 1; EventMap *pdf_map = GetStubMap(P, phone_sets, phone2num_pdf_classes, share_roots, &num_leaves); return new ContextDependency(N, P, pdf_map); } } // end namespace kaldi.