// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Functions and classes that implemement epsilon-removal. #ifndef FST_RMEPSILON_H_ #define FST_RMEPSILON_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace fst { template class RmEpsilonOptions : public ShortestDistanceOptions> { public: using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; bool connect; // Connect output Weight weight_threshold; // Pruning weight threshold. StateId state_threshold; // Pruning state threshold. explicit RmEpsilonOptions(Queue *queue, float delta = kShortestDelta, bool connect = true, Weight weight_threshold = Weight::Zero(), StateId state_threshold = kNoStateId) : ShortestDistanceOptions>( queue, EpsilonArcFilter(), kNoStateId, delta), connect(connect), weight_threshold(std::move(weight_threshold)), state_threshold(state_threshold) {} }; namespace internal { // Computation state of the epsilon-removal algorithm. template class RmEpsilonState { public: using Label = typename Arc::Label; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; RmEpsilonState(const Fst &fst, std::vector *distance, const RmEpsilonOptions &opts) : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true), expand_id_(0) {} void Expand(StateId s); std::vector &Arcs() { return arcs_; } const Weight &Final() const { return final_; } bool Error() const { return sd_state_.Error(); } private: struct Element { Label ilabel; Label olabel; StateId nextstate; Element() {} Element(Label ilabel, Label olabel, StateId nexstate) : ilabel(ilabel), olabel(olabel), nextstate(nexstate) {} }; struct ElementHash { public: size_t operator()(const Element &element) const { static constexpr size_t prime0 = 7853; static constexpr size_t prime1 = 7867; return static_cast(element.nextstate) + static_cast(element.ilabel) * prime0 + static_cast(element.olabel) * prime1; } }; class ElementEqual { public: bool operator()(const Element &e1, const Element &e2) const { return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) && (e1.nextstate == e2.nextstate); } }; using ElementMap = std::unordered_map, ElementHash, ElementEqual>; const Fst &fst_; // Distance from state being expanded in epsilon-closure. std::vector *distance_; // Shortest distance algorithm computation state. internal::ShortestDistanceState> sd_state_; // Maps an element to a pair corresponding to a position in the arcs vector // of the state being expanded. The element corresopnds to the position in // the arcs_ vector if p.first is equal to the state being expanded. ElementMap element_map_; EpsilonArcFilter eps_filter_; std::stack eps_queue_; // Queue used to visit the epsilon-closure. std::vector visited_; // True if the state has been visited. std::forward_list visited_states_; // List of visited states. std::vector arcs_; // Arcs of state being expanded. Weight final_; // Final weight of state being expanded. StateId expand_id_; // Unique ID for each call to Expand RmEpsilonState(const RmEpsilonState &) = delete; RmEpsilonState &operator=(const RmEpsilonState &) = delete; }; template void RmEpsilonState::Expand(typename Arc::StateId source) { final_ = Weight::Zero(); arcs_.clear(); sd_state_.ShortestDistance(source); if (sd_state_.Error()) return; eps_queue_.push(source); while (!eps_queue_.empty()) { const auto state = eps_queue_.top(); eps_queue_.pop(); while (visited_.size() <= state) visited_.push_back(false); if (visited_[state]) continue; visited_[state] = true; visited_states_.push_front(state); for (ArcIterator> aiter(fst_, state); !aiter.Done(); aiter.Next()) { auto arc = aiter.Value(); arc.weight = Times((*distance_)[state], arc.weight); if (eps_filter_(arc)) { while (visited_.size() <= arc.nextstate) visited_.push_back(false); if (!visited_[arc.nextstate]) eps_queue_.push(arc.nextstate); } else { const Element element(arc.ilabel, arc.olabel, arc.nextstate); auto insert_result = element_map_.insert( std::make_pair(element, std::make_pair(expand_id_, arcs_.size()))); if (insert_result.second) { arcs_.push_back(arc); } else { if (insert_result.first->second.first == expand_id_) { auto &weight = arcs_[insert_result.first->second.second].weight; weight = Plus(weight, arc.weight); } else { insert_result.first->second.first = expand_id_; insert_result.first->second.second = arcs_.size(); arcs_.push_back(arc); } } } } final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state))); } while (!visited_states_.empty()) { visited_[visited_states_.front()] = false; visited_states_.pop_front(); } ++expand_id_; } } // namespace internal // Removes epsilon-transitions (when both the input and output label are an // epsilon) from a transducer. The result will be an equivalent FST that has no // such epsilon transitions. This version modifies its input. It allows fine // control via the options argument; see below for a simpler interface. // // The distance vector will be used to hold the shortest distances during the // epsilon-closure computation. The state queue discipline and convergence delta // are taken in the options argument. template void RmEpsilon(MutableFst *fst, std::vector *distance, const RmEpsilonOptions &opts) { using Label = typename Arc::Label; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; if (fst->Start() == kNoStateId) return; // noneps_in[s] will be set to true iff s admits a non-epsilon incoming // transition or is the start state. std::vector noneps_in(fst->NumStates(), false); noneps_in[fst->Start()] = true; for (size_t i = 0; i < fst->NumStates(); ++i) { for (ArcIterator> aiter(*fst, i); !aiter.Done(); aiter.Next()) { const auto &arc = aiter.Value(); if (arc.ilabel != 0 || arc.olabel != 0) { noneps_in[arc.nextstate] = true; } } } // States sorted in topological order when (acyclic) or generic topological // order (cyclic). std::vector states; states.reserve(fst->NumStates()); if (fst->Properties(kTopSorted, false) & kTopSorted) { for (size_t i = 0; i < fst->NumStates(); i++) states.push_back(i); } else if (fst->Properties(kAcyclic, false) & kAcyclic) { std::vector order; bool acyclic; TopOrderVisitor top_order_visitor(&order, &acyclic); DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter()); // Sanity check: should be acyclic if property bit is set. if (!acyclic) { FSTERROR() << "RmEpsilon: Inconsistent acyclic property bit"; fst->SetProperties(kError, kError); return; } states.resize(order.size()); for (StateId i = 0; i < order.size(); i++) states[order[i]] = i; } else { uint64 props; std::vector scc; SccVisitor scc_visitor(&scc, nullptr, nullptr, &props); DfsVisit(*fst, &scc_visitor, EpsilonArcFilter()); std::vector first(scc.size(), kNoStateId); std::vector next(scc.size(), kNoStateId); for (StateId i = 0; i < scc.size(); i++) { if (first[scc[i]] != kNoStateId) next[i] = first[scc[i]]; first[scc[i]] = i; } for (StateId i = 0; i < first.size(); i++) { for (auto j = first[i]; j != kNoStateId; j = next[j]) { states.push_back(j); } } } internal::RmEpsilonState rmeps_state(*fst, distance, opts); while (!states.empty()) { const auto state = states.back(); states.pop_back(); if (!noneps_in[state] && (opts.connect || opts.weight_threshold != Weight::Zero() || opts.state_threshold != kNoStateId)) { continue; } rmeps_state.Expand(state); fst->SetFinal(state, rmeps_state.Final()); fst->DeleteArcs(state); auto &arcs = rmeps_state.Arcs(); fst->ReserveArcs(state, arcs.size()); while (!arcs.empty()) { fst->AddArc(state, arcs.back()); arcs.pop_back(); } } if (opts.connect || opts.weight_threshold != Weight::Zero() || opts.state_threshold != kNoStateId) { for (size_t s = 0; s < fst->NumStates(); ++s) { if (!noneps_in[s]) fst->DeleteArcs(s); } } if (rmeps_state.Error()) fst->SetProperties(kError, kError); fst->SetProperties( RmEpsilonProperties(fst->Properties(kFstProperties, false)), kFstProperties); if (opts.weight_threshold != Weight::Zero() || opts.state_threshold != kNoStateId) { Prune(fst, opts.weight_threshold, opts.state_threshold); } if (opts.connect && opts.weight_threshold == Weight::Zero() && opts.state_threshold == kNoStateId) { Connect(fst); } } // Removes epsilon-transitions (when both the input and output label // are an epsilon) from a transducer. The result will be an equivalent // FST that has no such epsilon transitions. This version modifies its // input. It has a simplified interface; see above for a version that // allows finer control. // // Complexity: // // - Time: // // Unweighted: O(v^2 + ve). // Acyclic: O(v^2 + V e). // Tropical semiring: O(v^2 log V + ve). // General: exponential. // // - Space: O(vE) // // where v is the number of states visited and e is the number of arcs visited. // // For more information, see: // // Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization // algorithms for weighted transducers. International Journal of Computer // Science 13(1): 129-143. template void RmEpsilon(MutableFst *fst, bool connect = true, typename Arc::Weight weight_threshold = Arc::Weight::Zero(), typename Arc::StateId state_threshold = kNoStateId, float delta = kShortestDelta) { using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; std::vector distance; AutoQueue state_queue(*fst, &distance, EpsilonArcFilter()); RmEpsilonOptions> opts( &state_queue, delta, connect, weight_threshold, state_threshold); RmEpsilon(fst, &distance, opts); } struct RmEpsilonFstOptions : CacheOptions { float delta; explicit RmEpsilonFstOptions(const CacheOptions &opts, float delta = kShortestDelta) : CacheOptions(opts), delta(delta) {} explicit RmEpsilonFstOptions(float delta = kShortestDelta) : delta(delta) {} }; namespace internal { // Implementation of delayed RmEpsilonFst. template class RmEpsilonFstImpl : public CacheImpl { public: using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; using Store = DefaultCacheStore; using State = typename Store::State; using FstImpl::Properties; using FstImpl::SetType; using FstImpl::SetProperties; using FstImpl::SetInputSymbols; using FstImpl::SetOutputSymbols; using CacheBaseImpl>::HasArcs; using CacheBaseImpl>::HasFinal; using CacheBaseImpl>::HasStart; using CacheBaseImpl>::PushArc; using CacheBaseImpl>::SetArcs; using CacheBaseImpl>::SetFinal; using CacheBaseImpl>::SetStart; RmEpsilonFstImpl(const Fst &fst, const RmEpsilonFstOptions &opts) : CacheImpl(opts), fst_(fst.Copy()), delta_(opts.delta), rmeps_state_( *fst_, &distance_, RmEpsilonOptions>(&queue_, delta_, false)) { SetType("rmepsilon"); SetProperties( RmEpsilonProperties(fst.Properties(kFstProperties, false), true), kCopyProperties); SetInputSymbols(fst.InputSymbols()); SetOutputSymbols(fst.OutputSymbols()); } RmEpsilonFstImpl(const RmEpsilonFstImpl &impl) : CacheImpl(impl), fst_(impl.fst_->Copy(true)), delta_(impl.delta_), rmeps_state_( *fst_, &distance_, RmEpsilonOptions>(&queue_, delta_, false)) { SetType("rmepsilon"); SetProperties(impl.Properties(), kCopyProperties); SetInputSymbols(impl.InputSymbols()); SetOutputSymbols(impl.OutputSymbols()); } StateId Start() { if (!HasStart()) SetStart(fst_->Start()); return CacheImpl::Start(); } Weight Final(StateId s) { if (!HasFinal(s)) Expand(s); return CacheImpl::Final(s); } size_t NumArcs(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumArcs(s); } size_t NumInputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumInputEpsilons(s); } size_t NumOutputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumOutputEpsilons(s); } uint64 Properties() const override { return Properties(kFstProperties); } // Sets error if found and returns other FST impl properties. uint64 Properties(uint64 mask) const override { if ((mask & kError) && (fst_->Properties(kError, false) || rmeps_state_.Error())) { SetProperties(kError, kError); } return FstImpl::Properties(mask); } void InitArcIterator(StateId s, ArcIteratorData *data) { if (!HasArcs(s)) Expand(s); CacheImpl::InitArcIterator(s, data); } void Expand(StateId s) { rmeps_state_.Expand(s); SetFinal(s, rmeps_state_.Final()); auto &arcs = rmeps_state_.Arcs(); while (!arcs.empty()) { PushArc(s, arcs.back()); arcs.pop_back(); } SetArcs(s); } private: std::unique_ptr> fst_; float delta_; std::vector distance_; FifoQueue queue_; internal::RmEpsilonState> rmeps_state_; }; } // namespace internal // Removes epsilon-transitions (when both the input and output label are an // epsilon) from a transducer. The result will be an equivalent FST that has no // such epsilon transitions. This version is a // delayed FST. // // Complexity: // // - Time: // Unweighted: O(v^2 + ve). // General: exponential. // // - Space: O(vE) // // where v is the number of states visited and e is the number of arcs visited. // Constant time to visit an input state or arc is assumed and exclusive of // caching. // // For more information, see: // // Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization // algorithms for weighted transducers. International Journal of Computer // Science 13(1): 129-143. // // This class attaches interface to implementation and handles // reference counting, delegating most methods to ImplToFst. template class RmEpsilonFst : public ImplToFst> { public: using Arc = A; using StateId = typename Arc::StateId; using Store = DefaultCacheStore; using State = typename Store::State; using Impl = internal::RmEpsilonFstImpl; friend class ArcIterator>; friend class StateIterator>; explicit RmEpsilonFst(const Fst &fst) : ImplToFst(std::make_shared(fst, RmEpsilonFstOptions())) {} RmEpsilonFst(const Fst &fst, const RmEpsilonFstOptions &opts) : ImplToFst(std::make_shared(fst, opts)) {} // See Fst<>::Copy() for doc. RmEpsilonFst(const RmEpsilonFst &fst, bool safe = false) : ImplToFst(fst, safe) {} // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc. RmEpsilonFst *Copy(bool safe = false) const override { return new RmEpsilonFst(*this, safe); } inline void InitStateIterator(StateIteratorData *data) const override; void InitArcIterator(StateId s, ArcIteratorData *data) const override { GetMutableImpl()->InitArcIterator(s, data); } private: using ImplToFst::GetImpl; using ImplToFst::GetMutableImpl; RmEpsilonFst &operator=(const RmEpsilonFst &) = delete; }; // Specialization for RmEpsilonFst. template class StateIterator> : public CacheStateIterator> { public: explicit StateIterator(const RmEpsilonFst &fst) : CacheStateIterator>(fst, fst.GetMutableImpl()) {} }; // Specialization for RmEpsilonFst. template class ArcIterator> : public CacheArcIterator> { public: using StateId = typename Arc::StateId; ArcIterator(const RmEpsilonFst &fst, StateId s) : CacheArcIterator>(fst.GetMutableImpl(), s) { if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); } }; template inline void RmEpsilonFst::InitStateIterator( StateIteratorData *data) const { data->base = new StateIterator>(*this); } // Useful alias when using StdArc. using StdRmEpsilonFst = RmEpsilonFst; } // namespace fst #endif // FST_RMEPSILON_H_