// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Class to add a matcher to an FST. #ifndef FST_MATCHER_FST_H_ #define FST_MATCHER_FST_H_ #include #include #include #include #include namespace fst { // Writeable matchers have the same interface as Matchers (as defined in // matcher.h) along with the following additional methods: // // template // class Matcher { // public: // using FST = F; // ... // using MatcherData = ...; // Initialization data. // // // Constructor with additional argument for external initialization data; // // matcher increments its reference count on construction and decrements // // the reference count, and deletes once the reference count has reached // // zero. // Matcher(const FST &fst, MatchType type, MatcherData *data); // // // Returns pointer to initialization data that can be passed to a Matcher // // constructor. // MatcherData *GetData() const; // }; // The matcher initialization data class must also provide the following // interface: // // class MatcherData { // public: // // Required copy constructor. // MatcherData(const MatcherData &); // // // Required I/O methods. // static MatcherData *Read(std::istream &istrm, const FstReadOptions &opts); // bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const; // }; // Trivial (no-op) MatcherFst initializer functor. template class NullMatcherFstInit { public: using MatcherData = typename M::MatcherData; using Data = AddOnPair; using Impl = internal::AddOnImpl; explicit NullMatcherFstInit(std::shared_ptr *) {} }; // Class adding a matcher to an FST type. Creates a new FST whose name is given // by N. An optional functor Init can be used to initialize the FST. The Data // template parameter allows the user to select the type of the add-on. template < class F, class M, const char *Name, class Init = NullMatcherFstInit, class Data = AddOnPair> class MatcherFst : public ImplToExpandedFst> { public: using FST = F; using Arc = typename FST::Arc; using StateId = typename Arc::StateId; using FstMatcher = M; using MatcherData = typename FstMatcher::MatcherData; using Impl = internal::AddOnImpl; using D = Data; friend class StateIterator>; friend class ArcIterator>; MatcherFst() : ImplToExpandedFst(std::make_shared(FST(), Name)) {} explicit MatcherFst(const FST &fst, std::shared_ptr data = nullptr) : ImplToExpandedFst(data ? CreateImpl(fst, Name, data) : CreateDataAndImpl(fst, Name)) {} explicit MatcherFst(const Fst &fst) : ImplToExpandedFst(CreateDataAndImpl(fst, Name)) {} // See Fst<>::Copy() for doc. MatcherFst(const MatcherFst &fst, bool safe = false) : ImplToExpandedFst(fst, safe) {} // Get a copy of this MatcherFst. See Fst<>::Copy() for further doc. MatcherFst *Copy( bool safe = false) const override { return new MatcherFst(*this, safe); } // Read a MatcherFst from an input stream; return nullptr on error static MatcherFst *Read( std::istream &strm, const FstReadOptions &opts) { auto *impl = Impl::Read(strm, opts); return impl ? new MatcherFst( std::shared_ptr(impl)) : nullptr; } // Read a MatcherFst from a file; return nullptr on error // Empty filename reads from standard input static MatcherFst *Read( const string &filename) { auto *impl = ImplToExpandedFst::Read(filename); return impl ? new MatcherFst( std::shared_ptr(impl)) : nullptr; } bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { return GetImpl()->Write(strm, opts); } bool Write(const string &filename) const override { return Fst::WriteFile(filename); } void InitStateIterator(StateIteratorData *data) const override { return GetImpl()->InitStateIterator(data); } void InitArcIterator(StateId s, ArcIteratorData *data) const override { return GetImpl()->InitArcIterator(s, data); } FstMatcher *InitMatcher(MatchType match_type) const override { return new FstMatcher(&GetFst(), match_type, GetSharedData(match_type)); } const FST &GetFst() const { return GetImpl()->GetFst(); } const Data *GetAddOn() const { return GetImpl()->GetAddOn(); } std::shared_ptr GetSharedAddOn() const { return GetImpl()->GetSharedAddOn(); } const MatcherData *GetData(MatchType match_type) const { const auto *data = GetAddOn(); return match_type == MATCH_INPUT ? data->First() : data->Second(); } std::shared_ptr GetSharedData(MatchType match_type) const { const auto *data = GetAddOn(); return match_type == MATCH_INPUT ? data->SharedFirst() : data->SharedSecond(); } protected: using ImplToFst>::GetImpl; static std::shared_ptr CreateDataAndImpl(const FST &fst, const string &name) { FstMatcher imatcher(fst, MATCH_INPUT); FstMatcher omatcher(fst, MATCH_OUTPUT); return CreateImpl(fst, name, std::make_shared(imatcher.GetSharedData(), omatcher.GetSharedData())); } static std::shared_ptr CreateDataAndImpl(const Fst &fst, const string &name) { FST result(fst); return CreateDataAndImpl(result, name); } static std::shared_ptr CreateImpl(const FST &fst, const string &name, std::shared_ptr data) { auto impl = std::make_shared(fst, name); impl->SetAddOn(data); Init init(&impl); return impl; } explicit MatcherFst(std::shared_ptr impl) : ImplToExpandedFst(impl) {} private: MatcherFst &operator=(const MatcherFst &) = delete; }; // Specialization for MatcherFst. template class StateIterator> : public StateIterator { public: explicit StateIterator(const MatcherFst &fst) : StateIterator(fst.GetImpl()->GetFst()) {} }; // Specialization for MatcherFst. template class ArcIterator> : public ArcIterator { public: using StateId = typename FST::Arc::StateId; ArcIterator(const MatcherFst &fst, typename FST::Arc::StateId s) : ArcIterator(fst.GetImpl()->GetFst(), s) {} }; // Specialization for MatcherFst. template class Matcher> { public: using FST = MatcherFst; using Arc = typename F::Arc; using Label = typename Arc::Label; using StateId = typename Arc::StateId; Matcher(const FST &fst, MatchType match_type) : matcher_(fst.InitMatcher(match_type)) {} Matcher(const Matcher &matcher) : matcher_(matcher.matcher_->Copy()) {} Matcher *Copy() const { return new Matcher(*this); } MatchType Type(bool test) const { return matcher_->Type(test); } void SetState(StateId s) { matcher_->SetState(s); } bool Find(Label label) { return matcher_->Find(label); } bool Done() const { return matcher_->Done(); } const Arc &Value() const { return matcher_->Value(); } void Next() { matcher_->Next(); } uint64 Properties(uint64 props) const { return matcher_->Properties(props); } uint32 Flags() const { return matcher_->Flags(); } private: std::unique_ptr matcher_; }; // Specialization for MatcherFst. template class LookAheadMatcher> { public: using FST = MatcherFst; using Arc = typename F::Arc; using Label = typename Arc::Label; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; LookAheadMatcher(const FST &fst, MatchType match_type) : matcher_(fst.InitMatcher(match_type)) {} LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false) : matcher_(matcher.matcher_->Copy(safe)) {} // General matcher methods. LookAheadMatcher *Copy(bool safe = false) const { return new LookAheadMatcher(*this, safe); } MatchType Type(bool test) const { return matcher_->Type(test); } void SetState(StateId s) { matcher_->SetState(s); } bool Find(Label label) { return matcher_->Find(label); } bool Done() const { return matcher_->Done(); } const Arc &Value() const { return matcher_->Value(); } void Next() { matcher_->Next(); } const FST &GetFst() const { return matcher_->GetFst(); } uint64 Properties(uint64 props) const { return matcher_->Properties(props); } uint32 Flags() const { return matcher_->Flags(); } bool LookAheadLabel(Label label) const { return matcher_->LookAheadLabel(label); } bool LookAheadFst(const Fst &fst, StateId s) { return matcher_->LookAheadFst(fst, s); } Weight LookAheadWeight() const { return matcher_->LookAheadWeight(); } bool LookAheadPrefix(Arc *arc) const { return matcher_->LookAheadPrefix(arc); } void InitLookAheadFst(const Fst &fst, bool copy = false) { matcher_->InitLookAheadFst(fst, copy); } private: std::unique_ptr matcher_; }; // Useful aliases when using StdArc. extern const char arc_lookahead_fst_type[]; using StdArcLookAheadFst = MatcherFst, ArcLookAheadMatcher>>, arc_lookahead_fst_type>; extern const char ilabel_lookahead_fst_type[]; extern const char olabel_lookahead_fst_type[]; constexpr auto ilabel_lookahead_flags = kInputLookAheadMatcher | kLookAheadWeight | kLookAheadPrefix | kLookAheadEpsilons | kLookAheadNonEpsilonPrefix; constexpr auto olabel_lookahead_flags = kOutputLookAheadMatcher | kLookAheadWeight | kLookAheadPrefix | kLookAheadEpsilons | kLookAheadNonEpsilonPrefix; using StdILabelLookAheadFst = MatcherFst< ConstFst, LabelLookAheadMatcher>, ilabel_lookahead_flags, FastLogAccumulator>, ilabel_lookahead_fst_type, LabelLookAheadRelabeler>; using StdOLabelLookAheadFst = MatcherFst< ConstFst, LabelLookAheadMatcher>, olabel_lookahead_flags, FastLogAccumulator>, olabel_lookahead_fst_type, LabelLookAheadRelabeler>; } // namespace fst #endif // FST_MATCHER_FST_H_