// fstext/context-fst.cc // Copyright 2018 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include "fstext/context-fst.h" #include "base/kaldi-error.h" namespace fst { InverseContextFst::InverseContextFst( Label subsequential_symbol, const vector& phones, const vector& disambig_syms, int32 context_width, int32 central_position): context_width_(context_width), central_position_(central_position), phone_syms_(phones), disambig_syms_(disambig_syms), subsequential_symbol_(subsequential_symbol) { { // This block checks the inputs. KALDI_ASSERT(subsequential_symbol != 0 && disambig_syms_.count(subsequential_symbol) == 0 && phone_syms_.count(subsequential_symbol) == 0); if (phone_syms_.empty()) KALDI_WARN << "Context FST created but there are no phone symbols: probably " "input FST was empty."; KALDI_ASSERT(phone_syms_.count(0) == 0 && disambig_syms_.count(0) == 0 && central_position_ >= 0 && central_position_ < context_width_); for (size_t i = 0; i < phones.size(); i++) { KALDI_ASSERT(disambig_syms_.count(phones[i]) == 0); } } // empty vector, will be the ilabel_info vector that corresponds to epsilon, // in case our FST needs to output epsilons. vector empty_vec; Label epsilon_label = FindLabel(empty_vec); // epsilon_vec is the phonetic context window we have at the very start of a // sequence, meaning "no real phones have been seen yet". vector epsilon_vec(context_width_ - 1, 0); StateId start_state = FindState(epsilon_vec); KALDI_ASSERT(epsilon_label == 0 && start_state == 0); if (context_width_ > central_position_ + 1 && !disambig_syms_.empty()) { // We add a symbol whose sequence representation is [ 0 ], and whose // symbol-id is 1. This is treated as a disambiguation symbol, we call it // #-1 in printed form. It is necessary to ensure that all determinizable // LG's will have determinizable CLG's. The problem it fixes is quite // subtle-- it relates to reordering of disambiguation symbols (they appear // earlier in CLG than in LG, relative to phones), and the fact that if a // disambig symbol appears at the very start of a sequence in CLG, it's not // clear exatly where it appeared on the corresponding sequence at the input // of LG. vector pseudo_eps_vec; pseudo_eps_vec.push_back(0); pseudo_eps_symbol_= FindLabel(pseudo_eps_vec); KALDI_ASSERT(pseudo_eps_symbol_ == 1); } else { pseudo_eps_symbol_ = 0; // use actual epsilon. } } void InverseContextFst::ShiftSequenceLeft(Label label, std::vector *phone_seq) { if (!phone_seq->empty()) { phone_seq->erase(phone_seq->begin()); phone_seq->push_back(label); } } void InverseContextFst::GetFullPhoneSequence( const std::vector &seq, Label label, std::vector *full_phone_sequence) { int32 context_width = context_width_; full_phone_sequence->reserve(context_width); full_phone_sequence->insert(full_phone_sequence->end(), seq.begin(), seq.end()); full_phone_sequence->push_back(label); for (int32 i = central_position_ + 1; i < context_width; i++) { if ((*full_phone_sequence)[i] == subsequential_symbol_) { (*full_phone_sequence)[i] = 0; } } } InverseContextFst::Weight InverseContextFst::Final(StateId s) { KALDI_ASSERT(static_cast(s) < state_seqs_.size()); const vector &phone_context = state_seqs_[s]; KALDI_ASSERT(phone_context.size() == context_width_ - 1); bool has_final_prob; if (central_position_ < context_width_ - 1) { has_final_prob = (phone_context[central_position_] == subsequential_symbol_); // if phone_context[central_position_] != subsequential_symbol_ then we have // pending phones-in-context that we still need to output, so we need to // consume more subsequential symbols before we can terminate. } else { has_final_prob = true; } return has_final_prob ? Weight::One() : Weight::Zero(); } bool InverseContextFst::GetArc(StateId s, Label ilabel, Arc *arc) { KALDI_ASSERT(ilabel != 0 && static_cast(s) < state_seqs_.size() && state_seqs_[s].size() == context_width_ - 1); if (IsDisambigSymbol(ilabel)) { // A disambiguation-symbol self-loop arc. CreateDisambigArc(s, ilabel, arc); return true; } else if (IsPhoneSymbol(ilabel)) { const vector &seq = state_seqs_[s]; if (!seq.empty() && seq.back() == subsequential_symbol_) { return false; // A real phone is not allowed to follow the subsequential // symbol. } // next_seq will be 'seq' shifted left by 1, with 'ilabel' appended. vector next_seq(seq); ShiftSequenceLeft(ilabel, &next_seq); // full-seq will be the full context window of size context_width_. vector full_seq; GetFullPhoneSequence(seq, ilabel, &full_seq); StateId next_s = FindState(next_seq); CreatePhoneOrEpsArc(s, next_s, ilabel, full_seq, arc); return true; } else if (ilabel == subsequential_symbol_) { const vector &seq = state_seqs_[s]; if (central_position_ + 1 == context_width_ || seq[central_position_] == subsequential_symbol_) { // We already had "enough" subsequential symbols in a row and don't want to // accept any more, or we'd be making the subsequential symbol the central phone. return false; } // full-seq will be the full context window of size context_width_. vector full_seq; GetFullPhoneSequence(seq, ilabel, &full_seq); vector next_seq(seq); ShiftSequenceLeft(ilabel, &next_seq); StateId next_s = FindState(next_seq); CreatePhoneOrEpsArc(s, next_s, ilabel, full_seq, arc); return true; } else { KALDI_ERR << "ContextFst: CreateArc, invalid ilabel supplied [confusion " << "about phone list or disambig symbols?]: " << ilabel; } return false; // won't get here. suppress compiler error. } void InverseContextFst::CreateDisambigArc(StateId s, Label ilabel, Arc *arc) { // Creates a self-loop arc corresponding to the disambiguation symbol. vector label_info; // This will be a vector containing just [ -olabel ]. label_info.push_back(-ilabel); // olabel is a disambiguation symbol. Use its negative // so we can more easily distinguish them from phones. Label olabel = FindLabel(label_info); arc->ilabel = ilabel; arc->olabel = olabel; arc->weight = Weight::One(); arc->nextstate = s; // self-loop. } void InverseContextFst::CreatePhoneOrEpsArc(StateId src, StateId dest, Label ilabel, const vector &phone_seq, Arc *arc) { KALDI_PARANOID_ASSERT(phone_seq[central_position_] != subsequential_symbol_); arc->ilabel = ilabel; arc->weight = Weight::One(); arc->nextstate = dest; if (phone_seq[central_position_] == 0) { // This can happen at the beginning of the graph. In this case we don't // output a real phone, we createdt an epsilon arc (but sometimes we need to // use a special disambiguation symbol instead of epsilon). arc->olabel = pseudo_eps_symbol_; } else { // We have a phone in the central position. arc->olabel = FindLabel(phone_seq); } } StdArc::StateId InverseContextFst::FindState(const vector &seq) { // Finds state-id corresponding to this vector of phones. Inserts it if // necessary. KALDI_ASSERT(static_cast(seq.size()) == context_width_ - 1); VectorToStateMap::const_iterator iter = state_map_.find(seq); if (iter == state_map_.end()) { // Not already in map. StateId this_state_id = (StateId)state_seqs_.size(); state_seqs_.push_back(seq); state_map_[seq] = this_state_id; return this_state_id; } else { return iter->second; } } StdArc::Label InverseContextFst::FindLabel(const vector &label_vec) { // Finds the ilabel corresponding to this vector (creates a new ilabel if // necessary). VectorToLabelMap::const_iterator iter = ilabel_map_.find(label_vec); if (iter == ilabel_map_.end()) { // Not already in map. Label this_label = ilabel_info_.size(); ilabel_info_.push_back(label_vec); ilabel_map_[label_vec] = this_label; return this_label; } else { return iter->second; } } void ComposeContext(const vector &disambig_syms_in, int32 context_width, int32 central_position, VectorFst *ifst, VectorFst *ofst, vector > *ilabels_out, bool project_ifst) { KALDI_ASSERT(ifst != NULL && ofst != NULL); KALDI_ASSERT(context_width > 0); KALDI_ASSERT(central_position >= 0); KALDI_ASSERT(central_position < context_width); vector disambig_syms(disambig_syms_in); std::sort(disambig_syms.begin(), disambig_syms.end()); vector all_syms; GetInputSymbols(*ifst, false/*no eps*/, &all_syms); std::sort(all_syms.begin(), all_syms.end()); vector phones; for (size_t i = 0; i < all_syms.size(); i++) if (!std::binary_search(disambig_syms.begin(), disambig_syms.end(), all_syms[i])) phones.push_back(all_syms[i]); // Get subsequential symbol that does not clash with // any disambiguation symbol or symbol in the FST. int32 subseq_sym = 1; if (!all_syms.empty()) subseq_sym = std::max(subseq_sym, all_syms.back() + 1); if (!disambig_syms.empty()) subseq_sym = std::max(subseq_sym, disambig_syms.back() + 1); // if central_position == context_width-1, it's left-context, and no // subsequential symbol is needed. if (central_position != context_width-1) { AddSubsequentialLoop(subseq_sym, ifst); if (project_ifst) { fst::Project(ifst, fst::PROJECT_INPUT); } } InverseContextFst inv_c(subseq_sym, phones, disambig_syms, context_width, central_position); // The following statement is equivalent to the following // (if FSTs had the '*' operator for composition): // (*ofst) = inv(inv_c) * (*ifst) ComposeDeterministicOnDemandInverse(*ifst, &inv_c, ofst); inv_c.SwapIlabelInfo(ilabels_out); } void AddSubsequentialLoop(StdArc::Label subseq_symbol, MutableFst *fst) { typedef StdArc Arc; typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; vector final_states; for (StateIterator > siter(*fst); !siter.Done(); siter.Next()) { StateId s = siter.Value(); if (fst->Final(s) != Weight::Zero()) final_states.push_back(s); } StateId superfinal = fst->AddState(); Arc arc(subseq_symbol, 0, Weight::One(), superfinal); fst->AddArc(superfinal, arc); // loop at superfinal. fst->SetFinal(superfinal, Weight::One()); for (size_t i = 0; i < final_states.size(); i++) { StateId s = final_states[i]; fst->AddArc(s, Arc(subseq_symbol, 0, fst->Final(s), superfinal)); // No, don't remove the final-weights of the original states.. // this is so we can add the subsequential loop in cases where // there is no context, and it won't hurt. // fst->SetFinal(s, Weight::Zero()); arc.nextstate = final_states[i]; } } void WriteILabelInfo(std::ostream &os, bool binary, const vector > &info) { int32 size = info.size(); kaldi::WriteBasicType(os, binary, size); for (int32 i = 0; i < size; i++) { kaldi::WriteIntegerVector(os, binary, info[i]); } } void ReadILabelInfo(std::istream &is, bool binary, vector > *info) { int32 size = info->size(); kaldi::ReadBasicType(is, binary, &size); info->resize(size); for (int32 i = 0; i < size; i++) { kaldi::ReadIntegerVector(is, binary, &((*info)[i])); } } SymbolTable *CreateILabelInfoSymbolTable(const vector > &info, const SymbolTable &phones_symtab, std::string separator, std::string initial_disambig) { // e.g. separator = "/", initial-disambig="#-1" KALDI_ASSERT(!info.empty() && info[0].empty()); SymbolTable *ans = new SymbolTable("ilabel-info-symtab"); int64 s = ans->AddSymbol(phones_symtab.Find(static_cast(0))); assert(s == 0); for (size_t i = 1; i < info.size(); i++) { if (info[i].size() == 0) { KALDI_ERR << "Invalid ilabel-info"; } if (info[i].size() == 1 && info[i][0] <= 0) { if (info[i][0] == 0) { // special symbol at start that we want to call #-1. s = ans->AddSymbol(initial_disambig); if (s != i) { KALDI_ERR << "Disambig symbol " << initial_disambig << " already in vocab"; } } else { std::string disambig_sym = phones_symtab.Find(-info[i][0]); if (disambig_sym == "") { KALDI_ERR << "Disambig symbol " << -info[i][0] << " not in phone symbol-table"; } s = ans->AddSymbol(disambig_sym); if (s != i) { KALDI_ERR << "Disambig symbol " << disambig_sym << " already in vocab"; } } } else { // is a phone-context-window. std::string newsym; for (size_t j = 0; j < info[i].size(); j++) { std::string phonesym = phones_symtab.Find(info[i][j]); if (phonesym == "") { KALDI_ERR << "Symbol " << info[i][j] << " not in phone symbol-table"; } if (j != 0) newsym += separator; newsym += phonesym; } int64 s = ans->AddSymbol(newsym); if (s != static_cast(i)) { KALDI_ERR << "Some problem with duplicate symbols"; } } } return ans; } } // end namespace fst