// fstext/fstext-utils-inl.h // Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) // 2014 Telepoint Global Hosting Service, LLC. (Author: David Snyder) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_FSTEXT_FSTEXT_UTILS_INL_H_ #define KALDI_FSTEXT_FSTEXT_UTILS_INL_H_ #include #include "base/kaldi-common.h" #include "util/stl-utils.h" #include "util/text-utils.h" #include "util/kaldi-io.h" #include "fstext/factor.h" #include "fstext/pre-determinize.h" #include "fstext/determinize-star.h" #include #include #include namespace fst { template typename Arc::Label HighestNumberedOutputSymbol(const Fst &fst) { typename Arc::Label ans = 0; for (StateIterator > siter(fst); !siter.Done(); siter.Next()) { typename Arc::StateId s = siter.Value(); for (ArcIterator > aiter(fst, s); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); ans = std::max(ans, arc.olabel); } } return ans; } template typename Arc::Label HighestNumberedInputSymbol(const Fst &fst) { typename Arc::Label ans = 0; for (StateIterator > siter(fst); !siter.Done(); siter.Next()) { typename Arc::StateId s = siter.Value(); for (ArcIterator > aiter(fst, s); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); ans = std::max(ans, arc.ilabel); } } return ans; } template typename Arc::StateId NumArcs(const ExpandedFst &fst) { typedef typename Arc::StateId StateId; StateId num_arcs = 0; for (StateId s = 0; s < fst.NumStates(); s++) num_arcs += fst.NumArcs(s); return num_arcs; } template void GetOutputSymbols(const Fst &fst, bool include_eps, vector *symbols) { KALDI_ASSERT_IS_INTEGER_TYPE(I); std::set all_syms; for (StateIterator > siter(fst); !siter.Done(); siter.Next()) { typename Arc::StateId s = siter.Value(); for (ArcIterator > aiter(fst, s); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); all_syms.insert(arc.olabel); } } // Remove epsilon, if instructed. if (!include_eps && !all_syms.empty() && *all_syms.begin() == 0) all_syms.erase(0); KALDI_ASSERT(symbols != NULL); kaldi::CopySetToVector(all_syms, symbols); } template void GetInputSymbols(const Fst &fst, bool include_eps, vector *symbols) { KALDI_ASSERT_IS_INTEGER_TYPE(I); unordered_set all_syms; for (StateIterator > siter(fst); !siter.Done(); siter.Next()) { typename Arc::StateId s = siter.Value(); for (ArcIterator > aiter(fst, s); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); all_syms.insert(arc.ilabel); } } // Remove epsilon, if instructed. if (!include_eps && all_syms.count(0) != 0) all_syms.erase(0); KALDI_ASSERT(symbols != NULL); kaldi::CopySetToVector(all_syms, symbols); std::sort(symbols->begin(), symbols->end()); } template void RemoveSomeInputSymbols(const vector &to_remove, MutableFst *fst) { KALDI_ASSERT_IS_INTEGER_TYPE(I); RemoveSomeInputSymbolsMapper mapper(to_remove); Map(fst, mapper); } template class MapInputSymbolsMapper { public: Arc operator ()(const Arc &arc_in) { Arc ans = arc_in; if (ans.ilabel > 0 && ans.ilabel < static_cast((*symbol_mapping_).size())) ans.ilabel = (*symbol_mapping_)[ans.ilabel]; return ans; } MapFinalAction FinalAction() { return MAP_NO_SUPERFINAL; } MapSymbolsAction InputSymbolsAction() { return MAP_CLEAR_SYMBOLS; } MapSymbolsAction OutputSymbolsAction() { return MAP_COPY_SYMBOLS; } uint64 Properties(uint64 props) const { // Not tested. bool remove_epsilons = (symbol_mapping_->size() > 0 && (*symbol_mapping_)[0] != 0); bool add_epsilons = (symbol_mapping_->size() > 1 && *std::min_element(symbol_mapping_->begin()+1, symbol_mapping_->end()) == 0); // remove the following as we don't know now if any of them are true. uint64 props_to_remove = kAcceptor|kNotAcceptor|kIDeterministic|kNonIDeterministic| kILabelSorted|kNotILabelSorted; if (remove_epsilons) props_to_remove |= kEpsilons|kIEpsilons; if (add_epsilons) props_to_remove |= kNoEpsilons|kNoIEpsilons; uint64 props_to_add = 0; if (remove_epsilons && !add_epsilons) props_to_add |= kNoEpsilons|kNoIEpsilons; return (props & ~props_to_remove) | props_to_add; } // initialize with copy = false only if the "to_remove" argument will not be deleted // in the lifetime of this object. MapInputSymbolsMapper(const vector &to_remove, bool copy) { KALDI_ASSERT_IS_INTEGER_TYPE(I); if (copy) symbol_mapping_ = new vector (to_remove); else symbol_mapping_ = &to_remove; owned = copy; } ~MapInputSymbolsMapper() { if (owned && symbol_mapping_ != NULL) delete symbol_mapping_; } private: bool owned; const vector *symbol_mapping_; }; template void MapInputSymbols(const vector &symbol_mapping, MutableFst *fst) { KALDI_ASSERT_IS_INTEGER_TYPE(I); // false == don't copy the "symbol_mapping", retain pointer-- // safe since short-lived object. MapInputSymbolsMapper mapper(symbol_mapping, false); Map(fst, mapper); } template bool GetLinearSymbolSequence(const Fst &fst, vector *isymbols_out, vector *osymbols_out, typename Arc::Weight *tot_weight_out) { typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; Weight tot_weight = Weight::One(); vector ilabel_seq; vector olabel_seq; StateId cur_state = fst.Start(); if (cur_state == kNoStateId) { // empty sequence. if (isymbols_out != NULL) isymbols_out->clear(); if (osymbols_out != NULL) osymbols_out->clear(); if (tot_weight_out != NULL) *tot_weight_out = Weight::Zero(); return true; } while (1) { Weight w = fst.Final(cur_state); if (w != Weight::Zero()) { // is final.. tot_weight = Times(w, tot_weight); if (fst.NumArcs(cur_state) != 0) return false; if (isymbols_out != NULL) *isymbols_out = ilabel_seq; if (osymbols_out != NULL) *osymbols_out = olabel_seq; if (tot_weight_out != NULL) *tot_weight_out = tot_weight; return true; } else { if (fst.NumArcs(cur_state) != 1) return false; ArcIterator > iter(fst, cur_state); // get the only arc. const Arc &arc = iter.Value(); tot_weight = Times(arc.weight, tot_weight); if (arc.ilabel != 0) ilabel_seq.push_back(arc.ilabel); if (arc.olabel != 0) olabel_seq.push_back(arc.olabel); cur_state = arc.nextstate; } } } // see fstext-utils.h for comment. template void ConvertNbestToVector(const Fst &fst, vector > *fsts_out) { typedef typename Arc::Weight Weight; typedef typename Arc::StateId StateId; fsts_out->clear(); StateId start_state = fst.Start(); if (start_state == kNoStateId) return; // No output. size_t n_arcs = fst.NumArcs(start_state); bool start_is_final = (fst.Final(start_state) != Weight::Zero()); fsts_out->reserve(n_arcs + (start_is_final ? 1 : 0)); if (start_is_final) { fsts_out->resize(fsts_out->size() + 1); StateId start_state_out = fsts_out->back().AddState(); fsts_out->back().SetFinal(start_state_out, fst.Final(start_state)); } for (ArcIterator > start_aiter(fst, start_state); !start_aiter.Done(); start_aiter.Next()) { fsts_out->resize(fsts_out->size() + 1); VectorFst &ofst = fsts_out->back(); const Arc &first_arc = start_aiter.Value(); StateId cur_state = start_state, cur_ostate = ofst.AddState(); ofst.SetStart(cur_ostate); StateId next_ostate = ofst.AddState(); ofst.AddArc(cur_ostate, Arc(first_arc.ilabel, first_arc.olabel, first_arc.weight, next_ostate)); cur_state = first_arc.nextstate; cur_ostate = next_ostate; while (1) { size_t this_n_arcs = fst.NumArcs(cur_state); KALDI_ASSERT(this_n_arcs <= 1); // or it violates our assumptions // about the input. if (this_n_arcs == 1) { KALDI_ASSERT(fst.Final(cur_state) == Weight::Zero()); // or problem with ShortestPath. ArcIterator > aiter(fst, cur_state); const Arc &arc = aiter.Value(); next_ostate = ofst.AddState(); ofst.AddArc(cur_ostate, Arc(arc.ilabel, arc.olabel, arc.weight, next_ostate)); cur_state = arc.nextstate; cur_ostate = next_ostate; } else { KALDI_ASSERT(fst.Final(cur_state) != Weight::Zero()); // or problem with ShortestPath. ofst.SetFinal(cur_ostate, fst.Final(cur_state)); break; } } } } // see fstext-utils.sh for comment. template void NbestAsFsts(const Fst &fst, size_t n, vector > *fsts_out) { KALDI_ASSERT(n > 0); KALDI_ASSERT(fsts_out != NULL); VectorFst nbest_fst; ShortestPath(fst, &nbest_fst, n); ConvertNbestToVector(nbest_fst, fsts_out); } template void MakeLinearAcceptorWithAlternatives(const vector > &labels, MutableFst *ofst) { typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; ofst->DeleteStates(); StateId cur_state = ofst->AddState(); ofst->SetStart(cur_state); for (size_t i = 0; i < labels.size(); i++) { KALDI_ASSERT(labels[i].size() != 0); StateId next_state = ofst->AddState(); for (size_t j = 0; j < labels[i].size(); j++) { Arc arc(labels[i][j], labels[i][j], Weight::One(), next_state); ofst->AddArc(cur_state, arc); } cur_state = next_state; } ofst->SetFinal(cur_state, Weight::One()); } template void MakeLinearAcceptor(const vector &labels, MutableFst *ofst) { typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; ofst->DeleteStates(); StateId cur_state = ofst->AddState(); ofst->SetStart(cur_state); for (size_t i = 0; i < labels.size(); i++) { StateId next_state = ofst->AddState(); Arc arc(labels[i], labels[i], Weight::One(), next_state); ofst->AddArc(cur_state, arc); cur_state = next_state; } ofst->SetFinal(cur_state, Weight::One()); } template void GetSymbols(const SymbolTable &symtab, bool include_eps, vector *syms_out) { KALDI_ASSERT(syms_out != NULL); syms_out->clear(); for (SymbolTableIterator iter(symtab); !iter.Done(); iter.Next()) { if (include_eps || iter.Value() != 0) { syms_out->push_back(iter.Value()); KALDI_ASSERT(syms_out->back() == iter.Value()); // an integer-range thing. } } } template void SafeDeterminizeWrapper(MutableFst *ifst, MutableFst *ofst, float delta) { typename Arc::Label highest_sym = HighestNumberedInputSymbol(*ifst); vector extra_syms; PreDeterminize(ifst, (typename Arc::Label)(highest_sym+1), &extra_syms); DeterminizeStar(*ifst, ofst, delta); RemoveSomeInputSymbols(extra_syms, ofst); // remove the extra symbols. } template void SafeDeterminizeMinimizeWrapper(MutableFst *ifst, VectorFst *ofst, float delta) { typename Arc::Label highest_sym = HighestNumberedInputSymbol(*ifst); vector extra_syms; PreDeterminize(ifst, (typename Arc::Label)(highest_sym+1), &extra_syms); DeterminizeStar(*ifst, ofst, delta); RemoveSomeInputSymbols(extra_syms, ofst); // remove the extra symbols. RemoveEpsLocal(ofst); // this is "safe" and will never hurt. MinimizeEncoded(ofst, delta); } inline void DeterminizeStarInLog(VectorFst *fst, float delta, bool *debug_ptr, int max_states) { // DeterminizeStarInLog determinizes 'fst' in the log semiring, using // the DeterminizeStar algorithm (which also removes epsilons). ArcSort(fst, ILabelCompare()); // helps DeterminizeStar to be faster. VectorFst *fst_log = new VectorFst; // Want to determinize in log semiring. Cast(*fst, fst_log); VectorFst tmp; *fst = tmp; // make fst empty to free up memory. [actually may make no difference..] VectorFst *fst_det_log = new VectorFst; DeterminizeStar(*fst_log, fst_det_log, delta, debug_ptr, max_states); Cast(*fst_det_log, fst); delete fst_log; delete fst_det_log; } inline void DeterminizeInLog(VectorFst *fst) { // DeterminizeInLog determinizes 'fst' in the log semiring. ArcSort(fst, ILabelCompare()); // helps DeterminizeStar to be faster. VectorFst *fst_log = new VectorFst; // Want to determinize in log semiring. Cast(*fst, fst_log); VectorFst tmp; *fst = tmp; // make fst empty to free up memory. [actually may make no difference..] VectorFst *fst_det_log = new VectorFst; Determinize(*fst_log, fst_det_log); Cast(*fst_det_log, fst); delete fst_log; delete fst_det_log; } // make it inline to avoid having to put it in a .cc file. // destructive algorithm (changes ifst as well as ofst). inline void SafeDeterminizeMinimizeWrapperInLog(VectorFst *ifst, VectorFst *ofst, float delta) { VectorFst *ifst_log = new VectorFst; // Want to determinize in log semiring. Cast(*ifst, ifst_log); VectorFst *ofst_log = new VectorFst; SafeDeterminizeWrapper(ifst_log, ofst_log, delta); Cast(*ofst_log, ofst); delete ifst_log; delete ofst_log; RemoveEpsLocal(ofst); // this is "safe" and will never hurt. Do this in tropical, which is important. MinimizeEncoded(ofst, delta); // Non-deterministic minimization will fail in log semiring so do it with StdARc. } inline void SafeDeterminizeWrapperInLog(VectorFst *ifst, VectorFst *ofst, float delta) { VectorFst *ifst_log = new VectorFst; // Want to determinize in log semiring. Cast(*ifst, ifst_log); VectorFst *ofst_log = new VectorFst; SafeDeterminizeWrapper(ifst_log, ofst_log, delta); Cast(*ofst_log, ofst); delete ifst_log; delete ofst_log; } template void RemoveWeights(MutableFst *ifst) { typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; for (StateIterator > siter(*ifst); !siter.Done(); siter.Next()) { StateId s = siter.Value(); for (MutableArcIterator > aiter(ifst, s); !aiter.Done(); aiter.Next()) { Arc arc(aiter.Value()); arc.weight = Weight::One(); aiter.SetValue(arc); } if (ifst->Final(s) != Weight::Zero()) ifst->SetFinal(s, Weight::One()); } ifst->SetProperties(kUnweighted, kUnweighted); } // Used in PrecedingInputSymbolsAreSame (non-functor version), and // similar routines. template struct IdentityFunction { typedef T Arg; typedef T Result; T operator () (const T &t) const { return t; } }; template bool PrecedingInputSymbolsAreSame(bool start_is_epsilon, const Fst &fst) { IdentityFunction f; return PrecedingInputSymbolsAreSameClass(start_is_epsilon, fst, f); } template // F is functor type from labels to classes. bool PrecedingInputSymbolsAreSameClass(bool start_is_epsilon, const Fst &fst, const F &f) { typedef typename F::Result ClassType; typedef typename Arc::StateId StateId; vector classes; ClassType noClass = f(kNoLabel); if (start_is_epsilon) { StateId start_state = fst.Start(); if (start_state < 0 || start_state == kNoStateId) return true; // empty fst-- doesn't matter. classes.resize(start_state+1, noClass); classes[start_state] = 0; } for (StateIterator > siter(fst); !siter.Done(); siter.Next()) { StateId s = siter.Value(); for (ArcIterator > aiter(fst, s); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); if (classes.size() <= arc.nextstate) classes.resize(arc.nextstate+1, noClass); if (classes[arc.nextstate] == noClass) classes[arc.nextstate] = f(arc.ilabel); else if (classes[arc.nextstate] != f(arc.ilabel)) return false; } } return true; } template bool FollowingInputSymbolsAreSame(bool end_is_epsilon, const Fst &fst) { IdentityFunction f; return FollowingInputSymbolsAreSameClass(end_is_epsilon, fst, f); } template bool FollowingInputSymbolsAreSameClass(bool end_is_epsilon, const Fst &fst, const F &f) { typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; typedef typename F::Result ClassType; const ClassType noClass = f(kNoLabel), epsClass = f(0); for (StateIterator > siter(fst); !siter.Done(); siter.Next()) { StateId s = siter.Value(); ClassType c = noClass; for (ArcIterator > aiter(fst, s); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); if (c == noClass) c = f(arc.ilabel); else if (c != f(arc.ilabel)) return false; } if (end_is_epsilon && c != noClass && c != epsClass && fst.Final(s) != Weight::Zero()) return false; } return true; } template void MakePrecedingInputSymbolsSame(bool start_is_epsilon, MutableFst *fst) { IdentityFunction f; MakePrecedingInputSymbolsSameClass(start_is_epsilon, fst, f); } template void MakePrecedingInputSymbolsSameClass(bool start_is_epsilon, MutableFst *fst, const F &f) { typedef typename F::Result ClassType; typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; vector classes; ClassType noClass = f(kNoLabel); ClassType epsClass = f(0); if (start_is_epsilon) { // treat having-start-state as epsilon in-transition. StateId start_state = fst->Start(); if (start_state < 0 || start_state == kNoStateId) // empty FST. return; classes.resize(start_state+1, noClass); classes[start_state] = epsClass; } // Find bad states (states with multiple input-symbols into them). std::set bad_states; // states that we need to change. for (StateIterator > siter(*fst); !siter.Done(); siter.Next()) { StateId s = siter.Value(); for (ArcIterator > aiter(*fst, s); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); if (classes.size() <= static_cast(arc.nextstate)) classes.resize(arc.nextstate+1, noClass); if (classes[arc.nextstate] == noClass) classes[arc.nextstate] = f(arc.ilabel); else if (classes[arc.nextstate] != f(arc.ilabel)) bad_states.insert(arc.nextstate); } } if (bad_states.empty()) return; // Nothing to do. kaldi::ConstIntegerSet bad_states_ciset(bad_states); // faster lookup. // Work out list of arcs we have to change as (state, arc-offset). // Can't do the actual changes in this pass, since we have to add new // states which invalidates the iterators. vector > arcs_to_change; for (StateIterator > siter(*fst); !siter.Done(); siter.Next()) { StateId s = siter.Value(); for (ArcIterator > aiter(*fst, s); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); if (arc.ilabel != 0 && bad_states_ciset.count(arc.nextstate) != 0) arcs_to_change.push_back(std::make_pair(s, aiter.Position())); } } KALDI_ASSERT(!arcs_to_change.empty()); // since !bad_states.empty(). std::map, StateId> state_map; // state_map is a map from (bad-state, input-symbol-class) to dummy-state. for (size_t i = 0; i < arcs_to_change.size(); i++) { StateId s = arcs_to_change[i].first; ArcIterator > aiter(*fst, s); aiter.Seek(arcs_to_change[i].second); Arc arc = aiter.Value(); // Transition is non-eps transition to "bad" state. Introduce new state (or find // existing one). pair p(arc.nextstate, f(arc.ilabel)); if (state_map.count(p) == 0) { StateId newstate = state_map[p] = fst->AddState(); fst->AddArc(newstate, Arc(0, 0, Weight::One(), arc.nextstate)); } StateId dst_state = state_map[p]; arc.nextstate = dst_state; // Initialize the MutableArcIterator only now, as the call to NewState() // may have invalidated the first arc iterator. MutableArcIterator > maiter(fst, s); maiter.Seek(arcs_to_change[i].second); maiter.SetValue(arc); } } template void MakeFollowingInputSymbolsSame(bool end_is_epsilon, MutableFst *fst) { IdentityFunction f; MakeFollowingInputSymbolsSameClass(end_is_epsilon, fst, f); } template void MakeFollowingInputSymbolsSameClass(bool end_is_epsilon, MutableFst *fst, const F &f) { typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; typedef typename F::Result ClassType; vector bad_states; ClassType noClass = f(kNoLabel); ClassType epsClass = f(0); for (StateIterator > siter(*fst); !siter.Done(); siter.Next()) { StateId s = siter.Value(); ClassType c = noClass; bool bad = false; for (ArcIterator > aiter(*fst, s); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); if (c == noClass) c = f(arc.ilabel); else if (c != f(arc.ilabel)) { bad = true; break; } } if (end_is_epsilon && c != noClass && c != epsClass && fst->Final(s) != Weight::Zero()) bad = true; if (bad) bad_states.push_back(s); } vector my_arcs; for (size_t i = 0; i < bad_states.size(); i++) { StateId s = bad_states[i]; my_arcs.clear(); for (ArcIterator > aiter(*fst, s); !aiter.Done(); aiter.Next()) my_arcs.push_back(aiter.Value()); for (size_t j = 0; j < my_arcs.size(); j++) { Arc &arc = my_arcs[j]; if (arc.ilabel != 0) { StateId newstate = fst->AddState(); // Create a new state for each non-eps arc in original FST, out of each bad state. // Not as optimal as it could be, but does avoid some complicated weight-pushing // issues in which, to maintain stochasticity, we would have to know which semiring // we want to maintain stochasticity in. fst->AddArc(newstate, Arc(arc.ilabel, 0, Weight::One(), arc.nextstate)); MutableArcIterator > maiter(fst, s); maiter.Seek(j); maiter.SetValue(Arc(0, arc.olabel, arc.weight, newstate)); } } } } template VectorFst* MakeLoopFst(const vector *> &fsts) { typedef typename Arc::Weight Weight; typedef typename Arc::StateId StateId; typedef typename Arc::Label Label; VectorFst *ans = new VectorFst; StateId loop_state = ans->AddState(); // = 0. ans->SetStart(loop_state); ans->SetFinal(loop_state, Weight::One()); // "cache" is used as an optimization when some of the pointers in "fsts" // may have the same value. unordered_map *, Arc> cache; for (Label i = 0; i < static_cast