// fstext/trivial-factor-weight.h

// Copyright 2009-2011  Microsoft Corporation

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
//
//
// This is a modified file from the OpenFST Library v1.2.7 available at
// http://www.openfst.org and released under the Apache License Version 2.0.
//
//
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google.com (Cyril Allauzen)

#ifndef KALDI_FSTEXT_TRIVIAL_FACTOR_WEIGHT_H_
#define KALDI_FSTEXT_TRIVIAL_FACTOR_WEIGHT_H_

// TrivialFactorWeight.h  This is an extension to factor-weight.h in the OpenFst
// code.  It is a version of FactorWeight that creates separate states (with
// input epsilons) rather than pushing the factors forward.
This is for // converting from Gallic FSTs, where you want the result to be a bit more // trivial with input epsilons inserted where there are multiple output symbols. // This has the advantage that it always works, for any input (also I just // prefer this approach). #include using std::unordered_map; #include #include #include #include #include #include namespace fst { template struct TrivialFactorWeightOptions : CacheOptions { typedef typename Arc::Label Label; float delta; Label extra_ilabel; // input label of extra arcs Label extra_olabel; // output label of extra arcs TrivialFactorWeightOptions(const CacheOptions &opts, float d, Label il = 0, Label ol = 0) : CacheOptions(opts), delta(d), extra_ilabel(il), extra_olabel(ol) {} explicit TrivialFactorWeightOptions( float d, Label il = 0, Label ol = 0) : delta(d), extra_ilabel(il), extra_olabel(ol) {} TrivialFactorWeightOptions(): delta(kDelta), extra_ilabel(0), extra_olabel(0) {} }; namespace internal { // Implementation class for TrivialFactorWeight template class TrivialFactorWeightFstImpl : public CacheImpl { public: using CacheImpl::PushArc; using FstImpl::SetType; using FstImpl::SetProperties; using FstImpl::Properties; using FstImpl::SetInputSymbols; using FstImpl::SetOutputSymbols; using CacheBaseImpl< CacheState >::HasStart; using CacheBaseImpl< CacheState >::HasFinal; using CacheBaseImpl< CacheState >::HasArcs; typedef A Arc; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef F FactorIterator; typedef DefaultCacheStore Store; typedef typename Store::State State; struct Element { Element() {} Element(StateId s, Weight w) : state(s), weight(w) {} StateId state; // Input state Id Weight weight; // Residual weight }; TrivialFactorWeightFstImpl(const Fst &fst, const TrivialFactorWeightOptions &opts) : CacheImpl(opts), fst_(fst.Copy()), delta_(opts.delta), extra_ilabel_(opts.extra_ilabel), extra_olabel_(opts.extra_olabel) { SetType("factor-weight"); 
uint64 props = fst.Properties(kFstProperties, false); SetProperties(FactorWeightProperties(props), kCopyProperties); SetInputSymbols(fst.InputSymbols()); SetOutputSymbols(fst.OutputSymbols()); } TrivialFactorWeightFstImpl(const TrivialFactorWeightFstImpl &impl) : CacheImpl(impl), fst_(impl.fst_->Copy(true)), delta_(impl.delta_), extra_ilabel_(impl.extra_ilabel_), extra_olabel_(impl.extra_olabel_) { SetType("factor-weight"); SetProperties(impl.Properties(), kCopyProperties); SetInputSymbols(impl.InputSymbols()); SetOutputSymbols(impl.OutputSymbols()); } StateId Start() { if (!HasStart()) { StateId s = fst_->Start(); if (s == kNoStateId) return kNoStateId; StateId start = this->FindState(Element(fst_->Start(), Weight::One())); this->SetStart(start); } return CacheImpl::Start(); } Weight Final(StateId s) { if (!HasFinal(s)) { const Element &e = elements_[s]; Weight w; if (e.state == kNoStateId) { // extra state inserted to represent final weights. FactorIterator fit(e.weight); if (fit.Done()) { // cannot be factored. w = e.weight; // so it's final } else { w = Weight::Zero(); // need another transition. } } else { if (e.weight != Weight::One()) { // Not a real state. w = Weight::Zero(); } else { // corresponds to a "real" state. w = fst_->Final(e.state); FactorIterator fit(w); if (!fit.Done()) // we would have intermediate states representing this final state. w = Weight::Zero(); } } this->SetFinal(s, w); return w; } else { return CacheImpl::Final(s); } } size_t NumArcs(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumArcs(s); } size_t NumInputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumInputEpsilons(s); } size_t NumOutputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumOutputEpsilons(s); } void InitArcIterator(StateId s, ArcIteratorData *data) { if (!HasArcs(s)) Expand(s); CacheImpl::InitArcIterator(s, data); } // Find state corresponding to an element. Create new state // if element not found. 
StateId FindState(const Element &e) { typename ElementMap::iterator eit = element_map_.find(e); if (eit != element_map_.end()) { return (*eit).second; } else { StateId s = elements_.size(); elements_.push_back(e); element_map_.insert(pair(e, s)); return s; } } // Computes the outgoing transitions from a state, creating new destination // states as needed. void Expand(StateId s) { CHECK(static_cast(s) < elements_.size()); Element e = elements_[s]; if (e.weight != Weight::One()) { FactorIterator fit(e.weight); if (fit.Done()) { // Cannot be factored-> create a link to dest state directly if (e.state != kNoStateId) { StateId dest = FindState(Element(e.state, Weight::One())); PushArc(s, Arc(extra_ilabel_, extra_olabel_, e.weight, dest)); } // else we're done. This is a final state. } else { // Can be factored. const pair &p = fit.Value(); StateId dest = FindState(Element(e.state, p.second.Quantize(delta_))); PushArc(s, Arc(extra_ilabel_, extra_olabel_, p.first, dest)); } } else { // Unit weight. This corresponds to a "real" state. CHECK(e.state != kNoStateId); for (ArcIterator< Fst > ait(*fst_, e.state); !ait.Done(); ait.Next()) { const A &arc = ait.Value(); FactorIterator fit(arc.weight); if (fit.Done()) { // cannot be factored->just link directly to dest. StateId dest = FindState(Element(arc.nextstate, Weight::One())); PushArc(s, Arc(arc.ilabel, arc.olabel, arc.weight, dest)); } else { const pair &p = fit.Value(); StateId dest = FindState(Element(arc.nextstate, p.second.Quantize(delta_))); PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, dest)); } } // See if we have to add arcs for final-states [only if final-weight is factorable]. 
Weight final_w = fst_->Final(e.state); if (final_w != Weight::Zero()) { FactorIterator fit(final_w); if (!fit.Done()) { const pair &p = fit.Value(); StateId dest = FindState(Element(kNoStateId, p.second.Quantize(delta_))); PushArc(s, Arc(extra_ilabel_, extra_olabel_, p.first, dest)); } } } this->SetArcs(s); } private: // Equality function for Elements, assume weights have been quantized. class ElementEqual { public: bool operator()(const Element &x, const Element &y) const { return x.state == y.state && x.weight == y.weight; } }; // Hash function for Elements to Fst states. class ElementKey { public: size_t operator()(const Element &x) const { return static_cast(x.state * kPrime + x.weight.Hash()); } private: static const int kPrime = 7853; }; typedef unordered_map ElementMap; std::unique_ptr> fst_; float delta_; uint32 mode_; // factoring arc and/or final weights Label extra_ilabel_; // ilabel of arc created when factoring final w's Label extra_olabel_; // olabel of arc created when factoring final w's vector elements_; // mapping Fst state to Elements ElementMap element_map_; // mapping Elements to Fst state }; } // namespace internal /// TrivialFactorWeightFst takes as template parameter a FactorIterator as /// defined above. The result of weight factoring is a transducer /// equivalent to the input whose path weights have been factored /// according to the FactorIterator. States and transitions will be /// added as necessary. /// This algorithm differs from the one implemented in FactorWeightFst /// in that it does not attempt to push the extra weight forward to the /// next state: it uses a sequence of "extra" intermediate state, and /// outputs the remaining weight right away. This ensures that it will /// always succeed, even for Gallic representations of FSTs that have cycles /// with more output than input symbols. /// Note that the code below was modified from factor-weight.h by just /// search-and-replacing "FactorWeight" by "TrivialFactorWeight". 
template class TrivialFactorWeightFst : public ImplToFst> { public: friend class ArcIterator< TrivialFactorWeightFst >; friend class StateIterator< TrivialFactorWeightFst >; typedef A Arc; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef DefaultCacheStore Store; typedef typename Store::State State; typedef internal::TrivialFactorWeightFstImpl Impl; explicit TrivialFactorWeightFst(const Fst &fst) : ImplToFst(std::make_shared(fst, TrivialFactorWeightOptions())) {} TrivialFactorWeightFst(const Fst &fst, const TrivialFactorWeightOptions &opts) : ImplToFst(std::make_shared(fst, opts)) {} // See Fst<>::Copy() for doc. TrivialFactorWeightFst(const TrivialFactorWeightFst &fst, bool copy) : ImplToFst(fst, copy) {} // Get a copy of this TrivialFactorWeightFst. See Fst<>::Copy() for further doc. TrivialFactorWeightFst *Copy(bool copy = false) const override { return new TrivialFactorWeightFst(*this, copy); } inline void InitStateIterator(StateIteratorData *data) const override; void InitArcIterator(StateId s, ArcIteratorData *data) const override { GetMutableImpl()->InitArcIterator(s, data); } private: using ImplToFst::GetImpl; using ImplToFst::GetMutableImpl; TrivialFactorWeightFst &operator=(const TrivialFactorWeightFst &fst) = delete; }; // Specialization for TrivialFactorWeightFst. template class StateIterator< TrivialFactorWeightFst > : public CacheStateIterator< TrivialFactorWeightFst > { public: explicit StateIterator(const TrivialFactorWeightFst &fst) : CacheStateIterator< TrivialFactorWeightFst >(fst, fst.GetMutableImpl()) {} }; // Specialization for TrivialFactorWeightFst. 
template class ArcIterator< TrivialFactorWeightFst > : public CacheArcIterator< TrivialFactorWeightFst > { public: typedef typename A::StateId StateId; ArcIterator(const TrivialFactorWeightFst &fst, StateId s) : CacheArcIterator< TrivialFactorWeightFst>(fst.GetMutableImpl(), s) { if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); } }; template inline void TrivialFactorWeightFst::InitStateIterator( StateIteratorData *data) const { data->base = new StateIterator< TrivialFactorWeightFst >(*this); } } // namespace fst #endif