// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Classes for building, storing and representing log-linear models as FSTs. #ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_H_ #define FST_EXTENSIONS_LINEAR_LINEAR_FST_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace fst { // Forward declaration of the specialized matcher for both // LinearTaggerFst and LinearClassifierFst. template class LinearFstMatcherTpl; namespace internal { // Implementation class for on-the-fly generated LinearTaggerFst with // special optimization in matching. template class LinearTaggerFstImpl : public CacheImpl { public: using FstImpl::SetType; using FstImpl::SetProperties; using FstImpl::SetInputSymbols; using FstImpl::SetOutputSymbols; using FstImpl::WriteHeader; using CacheBaseImpl>::PushArc; using CacheBaseImpl>::HasArcs; using CacheBaseImpl>::HasFinal; using CacheBaseImpl>::HasStart; using CacheBaseImpl>::SetArcs; using CacheBaseImpl>::SetFinal; using CacheBaseImpl>::SetStart; typedef A Arc; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef typename Collection::SetIterator NGramIterator; // Constructs an empty FST by default. LinearTaggerFstImpl() : CacheImpl(CacheOptions()), data_(std::make_shared>()), delay_(0) { SetType("linear-tagger"); } // Constructs the FST with given data storage and symbol // tables. // // TODO(wuke): when there is no constraint on output we can delay // less than `data->MaxFutureSize` positions.
// NOTE(review): in this copy the header's line breaks and the contents of
// template argument lists appear to have been stripped (e.g. `CacheImpl`,
// `std::vector` and `LinearFstData` are written without arguments), and some
// private helpers are only partially visible; compare against upstream
// OpenFst before editing code here.
// The constructor below stores `data` in `data_`, which is a shared_ptr (see
// the default constructor's std::make_shared and Read()), so the impl becomes
// an owner of `data`. `delay_` is taken from data->MaxFutureSize(), and
// ReserveStubSpace() is called after the members are set.
LinearTaggerFstImpl(const LinearFstData *data, const SymbolTable *isyms, const SymbolTable *osyms, CacheOptions opts) : CacheImpl(opts), data_(data), delay_(data->MaxFutureSize()) { SetType("linear-tagger"); SetProperties(kILabelSorted, kFstProperties); SetInputSymbols(isyms); SetOutputSymbols(osyms); ReserveStubSpace(); } // Copy by sharing the underlying data storage. LinearTaggerFstImpl(const LinearTaggerFstImpl &impl) : CacheImpl(impl), data_(impl.data_), delay_(impl.delay_) { SetType("linear-tagger"); SetProperties(impl.Properties(), kCopyProperties); SetInputSymbols(impl.InputSymbols()); SetOutputSymbols(impl.OutputSymbols()); ReserveStubSpace(); } StateId Start() { if (!HasStart()) { StateId start = FindStartState(); SetStart(start); } return CacheImpl::Start(); } Weight Final(StateId s) { if (!HasFinal(s)) { state_stub_.clear(); FillState(s, &state_stub_); if (CanBeFinal(state_stub_)) SetFinal(s, data_->FinalWeight(InternalBegin(state_stub_), InternalEnd(state_stub_))); else SetFinal(s, Weight::Zero()); } return CacheImpl::Final(s); } size_t NumArcs(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumArcs(s); } size_t NumInputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumInputEpsilons(s); } size_t NumOutputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumOutputEpsilons(s); } void InitArcIterator(StateId s, ArcIteratorData *data) { if (!HasArcs(s)) Expand(s); CacheImpl::InitArcIterator(s, data); } // Computes the outgoing transitions from a state, creating new // destination states as needed. void Expand(StateId s); // Appends to `arcs` all out-going arcs from state `s` that matches `label` as // the input label.
// MatchInput() is what LinearFstMatcherTpl::Find() calls. It rebuilds the
// state tuple of `s` and appends only the arcs whose input is `ilabel`
// (0 selects the flushing epsilon transition). Unlike Expand(), it does not
// expand and cache every arc with SetArcs(). Destination states are still
// created through MakeArc()/FindState().
void MatchInput(StateId s, Label ilabel, std::vector *arcs); static LinearTaggerFstImpl *Read(std::istream &strm, const FstReadOptions &opts); bool Write(std::ostream &strm, // NOLINT const FstWriteOptions &opts) const { FstHeader header; header.SetStart(kNoStateId); WriteHeader(strm, opts, kFileVersion, &header); data_->Write(strm); if (!strm) { LOG(ERROR) << "LinearTaggerFst::Write: Write failed: " << opts.source; return false; } return true; } private: static const int kMinFileVersion; static const int kFileVersion; // A collection of functions to access parts of the state tuple. A // state tuple is a vector of `Label`s with two parts: // [buffer] [internal]. // // - [buffer] is a buffer of observed input labels with length // `delay_`. `LinearFstData::kStartOfSentence` // (resp. `LinearFstData::kEndOfSentence`) are used as // paddings when the buffer has fewer than `delay_` elements, which // can only appear as the prefix (resp. suffix) of the buffer. // // - [internal] is the internal state tuple for `LinearFstData` typename std::vector::kStartOfSentence); // Append internal states data_->EncodeStartState(&state_stub_); return FindState(state_stub_); } // Tests whether the buffer in `(begin, end)` is empty. bool IsEmptyBuffer(typename std::vector::kEndOfSentence => // buffer[i+x] == LinearFstData::kEndOfSentence // - buffer[i] == LinearFstData::kStartOfSentence => // buffer[i-x] == LinearFstData::kStartOfSentence return delay_ == 0 || *(end - 1) == LinearFstData::kStartOfSentence || *begin == LinearFstData::kEndOfSentence; } // Tests whether the given state tuple can be a final state. A state // is final iff there is no observed input in the buffer.
// NOTE(review): the definitions between the [internal] description above and
// CanBeFinal() are only partially visible in this copy. BufferBegin/BufferEnd,
// InternalBegin/InternalEnd, FindState, FillState and ReserveStubSpace are
// used in this file but not shown; verify against upstream.
// From what is visible:
// - IsEmptyBuffer() treats the buffer as empty when delay_ == 0, when its last
//   slot is kStartOfSentence, or when its first slot is kEndOfSentence.
// - ShiftBuffer() (below) returns `ilabel` itself when delay_ == 0. Otherwise
//   it writes `ilabel` into the last buffer slot of `*next_stub_` and returns
//   the first label in the buffer of `state`.
// - MakeArc() calls data_->TakeTransition(), looks up the destination with
//   FindState(*next_stub_), then resizes `*next_stub_` back to delay_. On the
//   emitted arc it maps kEndOfSentence input and kStartOfSentence output to 0.
bool CanBeFinal(const std::vector::kMinFileVersion = 1; template const int LinearTaggerFstImpl::kFileVersion = 1; template inline typename A::Label LinearTaggerFstImpl::ShiftBuffer( const std::vector::kEndOfSentence); if (delay_ == 0) { DCHECK_GT(ilabel, 0); return ilabel; } else { (*next_stub_)[BufferEnd(*next_stub_) - next_stub_->begin() - 1] = ilabel; return *BufferBegin(state); } } template inline A LinearTaggerFstImpl::MakeArc(const std::vector::kEndOfSentence); DCHECK(olabel > 0 || olabel == LinearFstData::kStartOfSentence); Weight weight(Weight::One()); data_->TakeTransition(BufferEnd(state), InternalBegin(state), InternalEnd(state), ilabel, olabel, next_stub_, &weight); StateId nextstate = FindState(*next_stub_); // Restore `next_stub_` to its size before the call next_stub_->resize(delay_); // In the actual arc, we use epsilons instead of boundaries. return A(ilabel == LinearFstData::kEndOfSentence ? 0 : ilabel, olabel == LinearFstData::kStartOfSentence ? 0 : olabel, weight, nextstate); } template inline void LinearTaggerFstImpl::ExpandArcs(StateId s, const std::vector::kStartOfSentence) { // This happens when input is shorter than `delay_`. PushArc(s, MakeArc(state, ilabel, LinearFstData::kStartOfSentence, next_stub_)); } else { std::pair::const_iterator, typename std::vector::const_iterator> range = data_->PossibleOutputLabels(obs_ilabel); for (typename std::vector::const_iterator it = range.first; it != range.second; ++it) PushArc(s, MakeArc(state, ilabel, *it, next_stub_)); } } // TODO(wuke): this has much in duplicate with `ExpandArcs()` template inline void LinearTaggerFstImpl::AppendArcs(StateId /*s*/, const std::vector::kStartOfSentence) { // This happens when input is shorter than `delay_`.
arcs->push_back( MakeArc(state, ilabel, LinearFstData::kStartOfSentence, next_stub_)); } else { std::pair::const_iterator, typename std::vector::const_iterator> range = data_->PossibleOutputLabels(obs_ilabel); for (typename std::vector::const_iterator it = range.first; it != range.second; ++it) arcs->push_back(MakeArc(state, ilabel, *it, next_stub_)); } } template void LinearTaggerFstImpl::Expand(StateId s) { VLOG(3) << "Expand " << s; state_stub_.clear(); FillState(s, &state_stub_); // Precompute the first `delay_ - 1` elements in the buffer of // next states, which are identical for different input/output. next_stub_.clear(); next_stub_.resize(delay_); if (delay_ > 0) std::copy(BufferBegin(state_stub_) + 1, BufferEnd(state_stub_), next_stub_.begin()); // Epsilon transition for flushing out the next observed input if (!IsEmptyBuffer(BufferBegin(state_stub_), BufferEnd(state_stub_))) ExpandArcs(s, state_stub_, LinearFstData::kEndOfSentence, &next_stub_); // Non-epsilon input when we haven't flushed if (delay_ == 0 || *(BufferEnd(state_stub_) - 1) != LinearFstData::kEndOfSentence) for (Label ilabel = data_->MinInputLabel(); ilabel <= data_->MaxInputLabel(); ++ilabel) ExpandArcs(s, state_stub_, ilabel, &next_stub_); SetArcs(s); } template void LinearTaggerFstImpl::MatchInput(StateId s, Label ilabel, std::vector *arcs) { state_stub_.clear(); FillState(s, &state_stub_); // Precompute the first `delay_ - 1` elements in the buffer of // next states, which are identical for different input/output.
// This mirrors Expand() but produces only the arcs for the requested label:
// - ilabel == 0 yields the flushing (kEndOfSentence) arcs, and only when the
//   buffer is not empty.
// - Any other label is accepted only while the buffer does not end in
//   kEndOfSentence (or when delay_ == 0).
// The arcs go to `arcs`, not to the cache.
next_stub_.clear(); next_stub_.resize(delay_); if (delay_ > 0) std::copy(BufferBegin(state_stub_) + 1, BufferEnd(state_stub_), next_stub_.begin()); if (ilabel == 0) { // Epsilon transition for flushing out the next observed input if (!IsEmptyBuffer(BufferBegin(state_stub_), BufferEnd(state_stub_))) AppendArcs(s, state_stub_, LinearFstData::kEndOfSentence, &next_stub_, arcs); } else { // Non-epsilon input when we haven't flushed if (delay_ == 0 || *(BufferEnd(state_stub_) - 1) != LinearFstData::kEndOfSentence) AppendArcs(s, state_stub_, ilabel, &next_stub_, arcs); } } template inline LinearTaggerFstImpl *LinearTaggerFstImpl::Read( std::istream &strm, const FstReadOptions &opts) { // NOLINT std::unique_ptr> impl(new LinearTaggerFstImpl()); FstHeader header; if (!impl->ReadHeader(strm, opts, kMinFileVersion, &header)) { return nullptr; } impl->data_ = std::shared_ptr>(LinearFstData::Read(strm)); if (!impl->data_) { return nullptr; } impl->delay_ = impl->data_->MaxFutureSize(); impl->ReserveStubSpace(); return impl.release(); } } // namespace internal // This class attaches interface to implementation and handles // reference counting, delegating most methods to ImplToFst.
// LinearTaggerFst is the public FST type over LinearTaggerFstImpl, which it
// holds through ImplToFst.
// - States and arcs are computed on demand by the impl, and InitMatcher()
//   hands out a LinearFstMatcherTpl.
// - Constructing it from an arbitrary Fst is a LOG(FATAL).
// - Read()/Write() with an empty filename use std::cin / std::cout. Read()
//   returns nullptr if the file cannot be opened or the impl fails to read.
// - The StateIterator/ArcIterator specializations below go through
//   GetMutableImpl(); the ArcIterator calls Expand(s) first if the arcs of
//   `s` are not cached yet.
// NOTE(review): template argument lists appear stripped in this copy.
template class LinearTaggerFst : public ImplToFst> { public: friend class ArcIterator>; friend class StateIterator>; friend class LinearFstMatcherTpl>; typedef A Arc; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef DefaultCacheStore Store; typedef typename Store::State State; using Impl = internal::LinearTaggerFstImpl; LinearTaggerFst() : ImplToFst(std::make_shared()) {} explicit LinearTaggerFst(LinearFstData *data, const SymbolTable *isyms = nullptr, const SymbolTable *osyms = nullptr, CacheOptions opts = CacheOptions()) : ImplToFst(std::make_shared(data, isyms, osyms, opts)) {} explicit LinearTaggerFst(const Fst &fst) : ImplToFst(std::make_shared()) { LOG(FATAL) << "LinearTaggerFst: no constructor from arbitrary FST."; } // See Fst<>::Copy() for doc. LinearTaggerFst(const LinearTaggerFst &fst, bool safe = false) : ImplToFst(fst, safe) {} // Get a copy of this LinearTaggerFst. See Fst<>::Copy() for further doc. LinearTaggerFst *Copy(bool safe = false) const override { return new LinearTaggerFst(*this, safe); } inline void InitStateIterator(StateIteratorData *data) const override; void InitArcIterator(StateId s, ArcIteratorData *data) const override { GetMutableImpl()->InitArcIterator(s, data); } MatcherBase *InitMatcher(MatchType match_type) const override { return new LinearFstMatcherTpl>(this, match_type); } static LinearTaggerFst *Read(const string &filename) { if (!filename.empty()) { std::ifstream strm(filename, std::ios_base::in | std::ios_base::binary); if (!strm) { LOG(ERROR) << "LinearTaggerFst::Read: Can't open file: " << filename; return nullptr; } return Read(strm, FstReadOptions(filename)); } else { return Read(std::cin, FstReadOptions("standard input")); } } static LinearTaggerFst *Read(std::istream &in, // NOLINT const FstReadOptions &opts) { auto *impl = Impl::Read(in, opts); return impl ?
new LinearTaggerFst(std::shared_ptr(impl)) : nullptr; } bool Write(const string &filename) const override { if (!filename.empty()) { std::ofstream strm(filename, std::ios_base::out | std::ios_base::binary); if (!strm) { LOG(ERROR) << "LinearTaggerFst::Write: Can't open file: " << filename; return false; } return Write(strm, FstWriteOptions(filename)); } else { return Write(std::cout, FstWriteOptions("standard output")); } } bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { return GetImpl()->Write(strm, opts); } private: using ImplToFst::GetImpl; using ImplToFst::GetMutableImpl; explicit LinearTaggerFst(std::shared_ptr impl) : ImplToFst(impl) {} void operator=(const LinearTaggerFst &fst) = delete; }; // Specialization for LinearTaggerFst. template class StateIterator> : public CacheStateIterator> { public: explicit StateIterator(const LinearTaggerFst &fst) : CacheStateIterator>(fst, fst.GetMutableImpl()) {} }; // Specialization for LinearTaggerFst. template class ArcIterator> : public CacheArcIterator> { public: using StateId = typename Arc::StateId; ArcIterator(const LinearTaggerFst &fst, StateId s) : CacheArcIterator>(fst.GetMutableImpl(), s) { if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); } }; template inline void LinearTaggerFst::InitStateIterator( StateIteratorData *data) const { data->base = new StateIterator>(*this); } namespace internal { // Implementation class for on-the-fly generated LinearClassifierFst with // special optimization in matching.
// LinearClassifierFstImpl builds the classifier FST lazily. Per Expand():
// - The start state has one arc per class `pred` in 1..num_classes_, with
//   input 0 and output `pred`.
// - Every other state has, for each input label in
//   [MinInputLabel(), MaxInputLabel()], an arc with output 0. Its weight is
//   accumulated by GroupTransition() over that class's num_groups_ groups.
// - The state tuple is [prediction] followed by one internal state per group
//   (`next_stub_` is sized 1 + num_groups_).
// NOTE(review): the data constructor divides NumGroups() by `num_classes`, so
// callers must pass a positive class count. Read() rejects a zero count
// coming from the stream.
// NOTE(review): template argument lists and some private helpers appear to be
// missing from this copy; compare against upstream OpenFst.
template class LinearClassifierFstImpl : public CacheImpl { public: using FstImpl::SetType; using FstImpl::SetProperties; using FstImpl::SetInputSymbols; using FstImpl::SetOutputSymbols; using FstImpl::WriteHeader; using CacheBaseImpl>::PushArc; using CacheBaseImpl>::HasArcs; using CacheBaseImpl>::HasFinal; using CacheBaseImpl>::HasStart; using CacheBaseImpl>::SetArcs; using CacheBaseImpl>::SetFinal; using CacheBaseImpl>::SetStart; typedef A Arc; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef typename Collection::SetIterator NGramIterator; // Constructs an empty FST by default. LinearClassifierFstImpl() : CacheImpl(CacheOptions()), data_(std::make_shared>()) { SetType("linear-classifier"); num_classes_ = 0; num_groups_ = 0; } // Constructs the FST with given data storage, number of classes and // symbol tables. LinearClassifierFstImpl(const LinearFstData *data, size_t num_classes, const SymbolTable *isyms, const SymbolTable *osyms, CacheOptions opts) : CacheImpl(opts), data_(data), num_classes_(num_classes), num_groups_(data_->NumGroups() / num_classes_) { SetType("linear-classifier"); SetProperties(kILabelSorted, kFstProperties); SetInputSymbols(isyms); SetOutputSymbols(osyms); ReserveStubSpace(); } // Copy by sharing the underlying data storage.
LinearClassifierFstImpl(const LinearClassifierFstImpl &impl) : CacheImpl(impl), data_(impl.data_), num_classes_(impl.num_classes_), num_groups_(impl.num_groups_) { SetType("linear-classifier"); SetProperties(impl.Properties(), kCopyProperties); SetInputSymbols(impl.InputSymbols()); SetOutputSymbols(impl.OutputSymbols()); ReserveStubSpace(); } StateId Start() { if (!HasStart()) { StateId start = FindStartState(); SetStart(start); } return CacheImpl::Start(); } Weight Final(StateId s) { if (!HasFinal(s)) { state_stub_.clear(); FillState(s, &state_stub_); SetFinal(s, FinalWeight(state_stub_)); } return CacheImpl::Final(s); } size_t NumArcs(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumArcs(s); } size_t NumInputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumInputEpsilons(s); } size_t NumOutputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumOutputEpsilons(s); } void InitArcIterator(StateId s, ArcIteratorData *data) { if (!HasArcs(s)) Expand(s); CacheImpl::InitArcIterator(s, data); } // Computes the outgoing transitions from a state, creating new // destination states as needed. void Expand(StateId s); // Appends to `arcs` all out-going arcs from state `s` that matches // `label` as the input label. void MatchInput(StateId s, Label ilabel, std::vector *arcs); static LinearClassifierFstImpl *Read(std::istream &strm, const FstReadOptions &opts); bool Write(std::ostream &strm, const FstWriteOptions &opts) const { FstHeader header; header.SetStart(kNoStateId); WriteHeader(strm, opts, kFileVersion, &header); data_->Write(strm); WriteType(strm, num_classes_); if (!strm) { LOG(ERROR) << "LinearClassifierFst::Write: Write failed: " << opts.source; return false; } return true; } private: static const int kMinFileVersion; static const int kFileVersion; // A collection of functions to access parts of the state tuple. A // state tuple is a vector of `Label`s with two parts: // [prediction] [internal].
// // - [prediction] is a single label of the predicted class. A state // must have a positive class label, unless it is the start state. // // - [internal] is the internal state tuple for `LinearFstData` of // the given class; or kNoTrieNodeId's if in start state. Label &Prediction(std::vector &) = delete; }; template const int LinearClassifierFstImpl::kMinFileVersion = 0; template const int LinearClassifierFstImpl::kFileVersion = 0; template void LinearClassifierFstImpl::Expand(StateId s) { VLOG(3) << "Expand " << s; state_stub_.clear(); FillState(s, &state_stub_); next_stub_.clear(); next_stub_.resize(1 + num_groups_); if (IsStartState(state_stub_)) { // Make prediction for (Label pred = 1; pred <= num_classes_; ++pred) { Prediction(next_stub_) = pred; for (int i = 0; i < num_groups_; ++i) InternalAt(next_stub_, i) = data_->GroupStartState(GroupId(pred, i)); PushArc(s, A(0, pred, Weight::One(), FindState(next_stub_))); } } else { Label pred = Prediction(state_stub_); DCHECK_GT(pred, 0); DCHECK_LE(pred, num_classes_); for (Label ilabel = data_->MinInputLabel(); ilabel <= data_->MaxInputLabel(); ++ilabel) { Prediction(next_stub_) = pred; Weight weight = Weight::One(); for (int i = 0; i < num_groups_; ++i) InternalAt(next_stub_, i) = data_->GroupTransition(GroupId(pred, i), InternalAt(state_stub_, i), ilabel, pred, &weight); PushArc(s, A(ilabel, 0, weight, FindState(next_stub_))); } } SetArcs(s); } template void LinearClassifierFstImpl::MatchInput(StateId s, Label ilabel, std::vector *arcs) { state_stub_.clear(); FillState(s, &state_stub_); next_stub_.clear(); next_stub_.resize(1 + num_groups_); if (IsStartState(state_stub_)) { // Make prediction if `ilabel` is epsilon.
// From the start state only epsilon (ilabel == 0) matches, producing the
// prediction arcs. From any other state only a non-epsilon label matches,
// producing a single arc. Unlike the MinInputLabel()..MaxInputLabel() loop in
// Expand(), MatchInput() does no range check on `ilabel`.
if (ilabel == 0) { for (Label pred = 1; pred <= num_classes_; ++pred) { Prediction(next_stub_) = pred; for (int i = 0; i < num_groups_; ++i) InternalAt(next_stub_, i) = data_->GroupStartState(GroupId(pred, i)); arcs->push_back(A(0, pred, Weight::One(), FindState(next_stub_))); } } } else if (ilabel != 0) { Label pred = Prediction(state_stub_); Weight weight = Weight::One(); Prediction(next_stub_) = pred; for (int i = 0; i < num_groups_; ++i) InternalAt(next_stub_, i) = data_->GroupTransition( GroupId(pred, i), InternalAt(state_stub_, i), ilabel, pred, &weight); arcs->push_back(A(ilabel, 0, weight, FindState(next_stub_))); } } /* Reads an impl in the layout produced by Write(): FST header, LinearFstData, then the class count. Returns nullptr on any read failure, on a zero class count, or when the group count is not a multiple of the class count. */ template inline LinearClassifierFstImpl *LinearClassifierFstImpl::Read( std::istream &strm, const FstReadOptions &opts) { std::unique_ptr> impl( new LinearClassifierFstImpl()); FstHeader header; if (!impl->ReadHeader(strm, opts, kMinFileVersion, &header)) { return nullptr; } impl->data_ = std::shared_ptr>(LinearFstData::Read(strm)); if (!impl->data_) { return nullptr; } ReadType(strm, &impl->num_classes_); if (!strm) { return nullptr; } /* A zero class count (e.g. from a corrupt or hostile stream) would make the integer division below undefined behavior, so reject it before dividing. */ if (impl->num_classes_ == 0) { FSTERROR() << "LinearClassifierFst::Read: Number of classes must be positive"; return nullptr; } impl->num_groups_ = impl->data_->NumGroups() / impl->num_classes_; if (impl->num_groups_ * impl->num_classes_ != impl->data_->NumGroups()) { FSTERROR() << "Total number of feature groups is not a multiple of the " "number of classes: num groups = " << impl->data_->NumGroups() << ", num classes = " << impl->num_classes_; return nullptr; } impl->ReserveStubSpace(); return impl.release(); } } // namespace internal // This class attaches interface to implementation and handles // reference counting, delegating most methods to ImplToFst.
template class LinearClassifierFst : public ImplToFst> { public: friend class ArcIterator>; friend class StateIterator>; friend class LinearFstMatcherTpl>; typedef A Arc; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef DefaultCacheStore Store; typedef typename Store::State State; using Impl = internal::LinearClassifierFstImpl; LinearClassifierFst() : ImplToFst(std::make_shared()) {} explicit LinearClassifierFst(LinearFstData *data, size_t num_classes, const SymbolTable *isyms = nullptr, const SymbolTable *osyms = nullptr, CacheOptions opts = CacheOptions()) : ImplToFst( std::make_shared(data, num_classes, isyms, osyms, opts)) {} explicit LinearClassifierFst(const Fst &fst) : ImplToFst(std::make_shared()) { LOG(FATAL) << "LinearClassifierFst: no constructor from arbitrary FST."; } // See Fst<>::Copy() for doc. LinearClassifierFst(const LinearClassifierFst &fst, bool safe = false) : ImplToFst(fst, safe) {} // Get a copy of this LinearClassifierFst. See Fst<>::Copy() for further doc. LinearClassifierFst *Copy(bool safe = false) const override { return new LinearClassifierFst(*this, safe); } inline void InitStateIterator(StateIteratorData *data) const override; void InitArcIterator(StateId s, ArcIteratorData *data) const override { GetMutableImpl()->InitArcIterator(s, data); } MatcherBase *InitMatcher(MatchType match_type) const override { return new LinearFstMatcherTpl>(this, match_type); } static LinearClassifierFst *Read(const string &filename) { if (!filename.empty()) { std::ifstream strm(filename, std::ios_base::in | std::ios_base::binary); if (!strm) { LOG(ERROR) << "LinearClassifierFst::Read: Can't open file: " << filename; return nullptr; } return Read(strm, FstReadOptions(filename)); } else { return Read(std::cin, FstReadOptions("standard input")); } } static LinearClassifierFst *Read(std::istream &in, const FstReadOptions &opts) { auto *impl = Impl::Read(in, opts); return impl ? 
new LinearClassifierFst(std::shared_ptr(impl)) : nullptr; } bool Write(const string &filename) const override { if (!filename.empty()) { std::ofstream strm(filename, std::ios_base::out | std::ios_base::binary); if (!strm) { LOG(ERROR) << "ProdLmFst::Write: Can't open file: " << filename; return false; } return Write(strm, FstWriteOptions(filename)); } else { return Write(std::cout, FstWriteOptions("standard output")); } } bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { return GetImpl()->Write(strm, opts); } private: using ImplToFst::GetImpl; using ImplToFst::GetMutableImpl; explicit LinearClassifierFst(std::shared_ptr impl) : ImplToFst(impl) {} void operator=(const LinearClassifierFst &fst) = delete; }; // Specialization for LinearClassifierFst. template class StateIterator> : public CacheStateIterator> { public: explicit StateIterator(const LinearClassifierFst &fst) : CacheStateIterator>(fst, fst.GetMutableImpl()) {} }; // Specialization for LinearClassifierFst. template class ArcIterator> : public CacheArcIterator> { public: using StateId = typename Arc::StateId; ArcIterator(const LinearClassifierFst &fst, StateId s) : CacheArcIterator>(fst.GetMutableImpl(), s) { if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); } }; template inline void LinearClassifierFst::InitStateIterator( StateIteratorData *data) const { data->base = new StateIterator>(*this); } // Specialized Matcher for LinearFsts. This matcher only supports // matching from the input side. This is intentional because comparing // the scores of different input sequences with the same output // sequence is meaningless in a discriminative model. template class LinearFstMatcherTpl : public MatcherBase { public: typedef typename F::Arc Arc; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; typedef typename Arc::StateId StateId; typedef F FST; // This makes a copy of the FST. 
LinearFstMatcherTpl(const FST &fst, MatchType match_type) : owned_fst_(fst.Copy()), fst_(*owned_fst_), match_type_(match_type), s_(kNoStateId), current_loop_(false), loop_(kNoLabel, 0, Weight::One(), kNoStateId), cur_arc_(0), error_(false) { switch (match_type_) { case MATCH_INPUT: case MATCH_OUTPUT: case MATCH_NONE: break; default: FSTERROR() << "LinearFstMatcherTpl: Bad match type"; match_type_ = MATCH_NONE; error_ = true; } } // This doesn't copy the FST. LinearFstMatcherTpl(const FST *fst, MatchType match_type) : fst_(*fst), match_type_(match_type), s_(kNoStateId), current_loop_(false), loop_(kNoLabel, 0, Weight::One(), kNoStateId), cur_arc_(0), error_(false) { switch (match_type_) { case MATCH_INPUT: case MATCH_OUTPUT: case MATCH_NONE: break; default: FSTERROR() << "LinearFstMatcherTpl: Bad match type"; match_type_ = MATCH_NONE; error_ = true; } } // This makes a copy of the FST. LinearFstMatcherTpl(const LinearFstMatcherTpl &matcher, bool safe = false) : owned_fst_(matcher.fst_.Copy(safe)), fst_(*owned_fst_), match_type_(matcher.match_type_), s_(kNoStateId), current_loop_(false), loop_(matcher.loop_), cur_arc_(0), error_(matcher.error_) {} LinearFstMatcherTpl *Copy(bool safe = false) const override { return new LinearFstMatcherTpl(*this, safe); } MatchType Type(bool /*test*/) const override { // `MATCH_INPUT` is the only valid type return match_type_ == MATCH_INPUT ? 
match_type_ : MATCH_NONE; } void SetState(StateId s) final { if (s_ == s) return; s_ = s; // `MATCH_INPUT` is the only valid type if (match_type_ != MATCH_INPUT) { FSTERROR() << "LinearFstMatcherTpl: Bad match type"; error_ = true; } loop_.nextstate = s; } bool Find(Label label) final { if (error_) { current_loop_ = false; return false; } current_loop_ = label == 0; if (label == kNoLabel) label = 0; arcs_.clear(); cur_arc_ = 0; fst_.GetMutableImpl()->MatchInput(s_, label, &arcs_); return current_loop_ || !arcs_.empty(); } bool Done() const final { return !(current_loop_ || cur_arc_ < arcs_.size()); } const Arc &Value() const final { return current_loop_ ? loop_ : arcs_[cur_arc_]; } void Next() final { if (current_loop_) current_loop_ = false; else ++cur_arc_; } ssize_t Priority(StateId s) final { return kRequirePriority; } const FST &GetFst() const override { return fst_; } uint64 Properties(uint64 props) const override { if (error_) props |= kError; return props; } uint32 Flags() const override { return kRequireMatch; } private: std::unique_ptr owned_fst_; const FST &fst_; MatchType match_type_; // Type of match to perform. StateId s_; // Current state. bool current_loop_; // Current arc is the implicit loop. Arc loop_; // For non-consuming symbols. // All out-going arcs matching the label in last Find() call. std::vector arcs_; size_t cur_arc_; // Index to the arc that `Value()` should return. bool error_; // Error encountered. }; } // namespace fst #endif // FST_EXTENSIONS_LINEAR_LINEAR_FST_H_