// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Simple concrete, mutable FST whose states and arcs are stored in STL vectors. #ifndef FST_VECTOR_FST_H_ #define FST_VECTOR_FST_H_ #include #include #include #include #include // For optional argument declarations #include #include namespace fst { template class VectorFst; template void Cast(const F &, G *); // Arcs (of type A) implemented by an STL vector per state. M specifies Arc // allocator (default declared in fst-decl.h). template */> class VectorState { public: using Arc = A; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; using ArcAllocator = M; using StateAllocator = typename ArcAllocator::template rebind>::other; // Provide STL allocator for arcs. explicit VectorState(const ArcAllocator &alloc) : final_(Weight::Zero()), niepsilons_(0), noepsilons_(0), arcs_(alloc) {} VectorState(const VectorState &state, const ArcAllocator &alloc) : final_(state.Final()), niepsilons_(state.NumInputEpsilons()), noepsilons_(state.NumOutputEpsilons()), arcs_(state.arcs_.begin(), state.arcs_.end(), alloc) {} void Reset() { final_ = Weight::Zero(); niepsilons_ = 0; noepsilons_ = 0; arcs_.clear(); } Weight Final() const { return final_; } size_t NumInputEpsilons() const { return niepsilons_; } size_t NumOutputEpsilons() const { return noepsilons_; } size_t NumArcs() const { return arcs_.size(); } const Arc &GetArc(size_t n) const { return arcs_[n]; } const Arc *Arcs() const { return !arcs_.empty() ? &arcs_[0] : nullptr; } Arc *MutableArcs() { return !arcs_.empty() ? &arcs_[0] : nullptr; } void ReserveArcs(size_t n) { arcs_.reserve(n); } void SetFinal(Weight weight) { final_ = std::move(weight); } void SetNumInputEpsilons(size_t n) { niepsilons_ = n; } void SetNumOutputEpsilons(size_t n) { noepsilons_ = n; } void AddArc(const Arc &arc) { if (arc.ilabel == 0) ++niepsilons_; if (arc.olabel == 0) ++noepsilons_; arcs_.push_back(arc); } void SetArc(const Arc &arc, size_t n) { if (arcs_[n].ilabel == 0) --niepsilons_; if (arcs_[n].olabel == 0) --noepsilons_; if (arc.ilabel == 0) ++niepsilons_; if (arc.olabel == 0) ++noepsilons_; arcs_[n] = arc; } void DeleteArcs() { niepsilons_ = 0; noepsilons_ = 0; arcs_.clear(); } void DeleteArcs(size_t n) { for (size_t i = 0; i < n; ++i) { if (arcs_.back().ilabel == 0) --niepsilons_; if (arcs_.back().olabel == 0) --noepsilons_; arcs_.pop_back(); } } // For state class allocation. void *operator new(size_t size, StateAllocator *alloc) { return alloc->allocate(1); } // For state destruction and memory freeing. static void Destroy(VectorState *state, StateAllocator *alloc) { if (state) { state->~VectorState(); alloc->deallocate(state, 1); } } private: Weight final_; // Final weight. size_t niepsilons_; // # of input epsilons size_t noepsilons_; // # of output epsilons std::vector arcs_; // Arc container. }; namespace internal { // States are implemented by STL vectors, templated on the // State definition. This does not manage the Fst properties. template class VectorFstBaseImpl : public FstImpl { public: using State = S; using Arc = typename State::Arc; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; VectorFstBaseImpl() : start_(kNoStateId) {} ~VectorFstBaseImpl() override { for (StateId s = 0; s < states_.size(); ++s) { State::Destroy(states_[s], &state_alloc_); } } StateId Start() const { return start_; } Weight Final(StateId state) const { return states_[state]->Final(); } StateId NumStates() const { return states_.size(); } size_t NumArcs(StateId state) const { return states_[state]->NumArcs(); } size_t NumInputEpsilons(StateId state) const { return GetState(state)->NumInputEpsilons(); } size_t NumOutputEpsilons(StateId state) const { return GetState(state)->NumOutputEpsilons(); } void SetStart(StateId state) { start_ = state; } void SetFinal(StateId state, Weight weight) { states_[state]->SetFinal(std::move(weight)); } StateId AddState() { states_.push_back(new (&state_alloc_) State(arc_alloc_)); return states_.size() - 1; } StateId AddState(State *state) { states_.push_back(state); return states_.size() - 1; } void AddArc(StateId state, const Arc &arc) { states_[state]->AddArc(arc); } void DeleteStates(const std::vector &dstates) { std::vector newid(states_.size(), 0); for (StateId i = 0; i < dstates.size(); ++i) newid[dstates[i]] = kNoStateId; StateId nstates = 0; for (StateId state = 0; state < states_.size(); ++state) { if (newid[state] != kNoStateId) { newid[state] = nstates; if (state != nstates) states_[nstates] = states_[state]; ++nstates; } else { State::Destroy(states_[state], &state_alloc_); } } states_.resize(nstates); for (StateId state = 0; state < states_.size(); ++state) { auto *arcs = states_[state]->MutableArcs(); size_t narcs = 0; auto nieps = states_[state]->NumInputEpsilons(); auto noeps = states_[state]->NumOutputEpsilons(); for (size_t i = 0; i < states_[state]->NumArcs(); ++i) { const auto t = newid[arcs[i].nextstate]; if (t != kNoStateId) { arcs[i].nextstate = t; if (i != narcs) arcs[narcs] = arcs[i]; ++narcs; } else { if (arcs[i].ilabel == 0) --nieps; if (arcs[i].olabel == 0) --noeps; } } states_[state]->DeleteArcs(states_[state]->NumArcs() - narcs); states_[state]->SetNumInputEpsilons(nieps); states_[state]->SetNumOutputEpsilons(noeps); } if (Start() != kNoStateId) SetStart(newid[Start()]); } void DeleteStates() { for (StateId state = 0; state < states_.size(); ++state) { State::Destroy(states_[state], &state_alloc_); } states_.clear(); SetStart(kNoStateId); } void DeleteArcs(StateId state, size_t n) { states_[state]->DeleteArcs(n); } void DeleteArcs(StateId state) { states_[state]->DeleteArcs(); } State *GetState(StateId state) { return states_[state]; } const State *GetState(StateId state) const { return states_[state]; } void SetState(StateId state, State *vstate) { states_[state] = vstate; } void ReserveStates(StateId n) { states_.reserve(n); } void ReserveArcs(StateId state, size_t n) { states_[state]->ReserveArcs(n); } // Provide information needed for generic state iterator. void InitStateIterator(StateIteratorData *data) const { data->base = nullptr; data->nstates = states_.size(); } // Provide information needed for generic arc iterator. void InitArcIterator(StateId state, ArcIteratorData *data) const { data->base = nullptr; data->narcs = states_[state]->NumArcs(); data->arcs = states_[state]->Arcs(); data->ref_count = nullptr; } private: std::vector states_; // States represenation. StateId start_; // Initial state. typename State::StateAllocator state_alloc_; // For state allocation. typename State::ArcAllocator arc_alloc_; // For arc allocation. VectorFstBaseImpl(const VectorFstBaseImpl &) = delete; VectorFstBaseImpl &operator=(const VectorFstBaseImpl &) = delete; }; // This is a VectorFstBaseImpl container that holds VectorStates and manages FST // properties. template class VectorFstImpl : public VectorFstBaseImpl { public: using State = S; using Arc = typename State::Arc; using Label = typename Arc::Label; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; using FstImpl::SetInputSymbols; using FstImpl::SetOutputSymbols; using FstImpl::SetType; using FstImpl::SetProperties; using FstImpl::Properties; using VectorFstBaseImpl::Start; using VectorFstBaseImpl::NumStates; using VectorFstBaseImpl::GetState; using VectorFstBaseImpl::ReserveArcs; friend class MutableArcIterator>; using BaseImpl = VectorFstBaseImpl; VectorFstImpl() { SetType("vector"); SetProperties(kNullProperties | kStaticProperties); } explicit VectorFstImpl(const Fst &fst); static VectorFstImpl *Read(std::istream &strm, const FstReadOptions &opts); void SetStart(StateId state) { BaseImpl::SetStart(state); SetProperties(SetStartProperties(Properties())); } void SetFinal(StateId state, Weight weight) { const auto old_weight = BaseImpl::Final(state); const auto properties = SetFinalProperties(Properties(), old_weight, weight); BaseImpl::SetFinal(state, std::move(weight)); SetProperties(properties); } StateId AddState() { const auto state = BaseImpl::AddState(); SetProperties(AddStateProperties(Properties())); return state; } void AddArc(StateId state, const Arc &arc) { auto *vstate = GetState(state); const auto *parc = vstate->NumArcs() == 0 ? nullptr : &(vstate->GetArc(vstate->NumArcs() - 1)); SetProperties(AddArcProperties(Properties(), state, arc, parc)); BaseImpl::AddArc(state, arc); } void DeleteStates(const std::vector &dstates) { BaseImpl::DeleteStates(dstates); SetProperties(DeleteStatesProperties(Properties())); } void DeleteStates() { BaseImpl::DeleteStates(); SetProperties(DeleteAllStatesProperties(Properties(), kStaticProperties)); } void DeleteArcs(StateId state, size_t n) { BaseImpl::DeleteArcs(state, n); SetProperties(DeleteArcsProperties(Properties())); } void DeleteArcs(StateId state) { BaseImpl::DeleteArcs(state); SetProperties(DeleteArcsProperties(Properties())); } // Properties always true of this FST class static constexpr uint64 kStaticProperties = kExpanded | kMutable; private: // Minimum file format version supported. static constexpr int kMinFileVersion = 2; }; template constexpr uint64 VectorFstImpl::kStaticProperties; template constexpr int VectorFstImpl::kMinFileVersion; template VectorFstImpl::VectorFstImpl(const Fst &fst) { SetType("vector"); SetInputSymbols(fst.InputSymbols()); SetOutputSymbols(fst.OutputSymbols()); BaseImpl::SetStart(fst.Start()); if (fst.Properties(kExpanded, false)) { BaseImpl::ReserveStates(CountStates(fst)); } for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { const auto state = siter.Value(); BaseImpl::AddState(); BaseImpl::SetFinal(state, fst.Final(state)); ReserveArcs(state, fst.NumArcs(state)); for (ArcIterator> aiter(fst, state); !aiter.Done(); aiter.Next()) { const auto &arc = aiter.Value(); BaseImpl::AddArc(state, arc); } } SetProperties(fst.Properties(kCopyProperties, false) | kStaticProperties); } template VectorFstImpl *VectorFstImpl::Read(std::istream &strm, const FstReadOptions &opts) { std::unique_ptr> impl(new VectorFstImpl()); FstHeader hdr; if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return nullptr; impl->BaseImpl::SetStart(hdr.Start()); if (hdr.NumStates() != kNoStateId) impl->ReserveStates(hdr.NumStates()); StateId state = 0; for (; hdr.NumStates() == kNoStateId || state < hdr.NumStates(); ++state) { Weight weight; if (!weight.Read(strm)) break; impl->BaseImpl::AddState(); auto *vstate = impl->GetState(state); vstate->SetFinal(weight); int64 narcs; ReadType(strm, &narcs); if (!strm) { LOG(ERROR) << "VectorFst::Read: Read failed: " << opts.source; return nullptr; } impl->ReserveArcs(state, narcs); for (int64 i = 0; i < narcs; ++i) { Arc arc; ReadType(strm, &arc.ilabel); ReadType(strm, &arc.olabel); arc.weight.Read(strm); ReadType(strm, &arc.nextstate); if (!strm) { LOG(ERROR) << "VectorFst::Read: Read failed: " << opts.source; return nullptr; } impl->BaseImpl::AddArc(state, arc); } } if (hdr.NumStates() != kNoStateId && state != hdr.NumStates()) { LOG(ERROR) << "VectorFst::Read: Unexpected end of file: " << opts.source; return nullptr; } return impl.release(); } } // namespace internal // Simple concrete, mutable FST. This class attaches interface to implementation // and handles reference counting, delegating most methods to ImplToMutableFst. // Also supports ReserveStates and ReserveArcs methods (cf. STL vector methods). // The second optional template argument gives the State definition. template */> class VectorFst : public ImplToMutableFst> { public: using Arc = A; using StateId = typename Arc::StateId; using State = S; using Impl = internal::VectorFstImpl; friend class StateIterator>; friend class ArcIterator>; friend class MutableArcIterator>; template friend void Cast(const F &, G *); VectorFst() : ImplToMutableFst(std::make_shared()) {} explicit VectorFst(const Fst &fst) : ImplToMutableFst(std::make_shared(fst)) {} VectorFst(const VectorFst &fst, bool safe = false) : ImplToMutableFst(fst) {} // Get a copy of this VectorFst. See Fst<>::Copy() for further doc. VectorFst *Copy(bool safe = false) const override { return new VectorFst(*this, safe); } VectorFst &operator=(const VectorFst &fst) { SetImpl(fst.GetSharedImpl()); return *this; } VectorFst &operator=(const Fst &fst) override { if (this != &fst) SetImpl(std::make_shared(fst)); return *this; } // Reads a VectorFst from an input stream, returning nullptr on error. static VectorFst *Read(std::istream &strm, const FstReadOptions &opts) { auto *impl = Impl::Read(strm, opts); return impl ? new VectorFst(std::shared_ptr(impl)) : nullptr; } // Read a VectorFst from a file, returning nullptr on error; empty filename // reads from standard input. static VectorFst *Read(const string &filename) { auto *impl = ImplToExpandedFst>::Read(filename); return impl ? new VectorFst(std::shared_ptr(impl)) : nullptr; } bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { return WriteFst(*this, strm, opts); } bool Write(const string &filename) const override { return Fst::WriteFile(filename); } template static bool WriteFst(const FST &fst, std::ostream &strm, const FstWriteOptions &opts); void InitStateIterator(StateIteratorData *data) const override { GetImpl()->InitStateIterator(data); } void InitArcIterator(StateId s, ArcIteratorData *data) const override { GetImpl()->InitArcIterator(s, data); } inline void InitMutableArcIterator(StateId s, MutableArcIteratorData *) override; using ImplToMutableFst>::ReserveArcs; using ImplToMutableFst>::ReserveStates; private: using ImplToMutableFst>::GetImpl; using ImplToMutableFst>::MutateCheck; using ImplToMutableFst>::SetImpl; explicit VectorFst(std::shared_ptr impl) : ImplToMutableFst(impl) {} }; // Writes FST to file in Vector format, potentially with a pass over the machine // before writing to compute number of states. template template bool VectorFst::WriteFst(const FST &fst, std::ostream &strm, const FstWriteOptions &opts) { static constexpr int file_version = 2; bool update_header = true; FstHeader hdr; hdr.SetStart(fst.Start()); hdr.SetNumStates(kNoStateId); size_t start_offset = 0; if (fst.Properties(kExpanded, false) || opts.stream_write || (start_offset = strm.tellp()) != -1) { hdr.SetNumStates(CountStates(fst)); update_header = false; } const auto properties = fst.Properties(kCopyProperties, false) | Impl::kStaticProperties; internal::FstImpl::WriteFstHeader(fst, strm, opts, file_version, "vector", properties, &hdr); StateId num_states = 0; for (StateIterator siter(fst); !siter.Done(); siter.Next()) { const auto s = siter.Value(); fst.Final(s).Write(strm); const int64 narcs = fst.NumArcs(s); WriteType(strm, narcs); for (ArcIterator aiter(fst, s); !aiter.Done(); aiter.Next()) { const auto &arc = aiter.Value(); WriteType(strm, arc.ilabel); WriteType(strm, arc.olabel); arc.weight.Write(strm); WriteType(strm, arc.nextstate); } ++num_states; } strm.flush(); if (!strm) { LOG(ERROR) << "VectorFst::Write: Write failed: " << opts.source; return false; } if (update_header) { hdr.SetNumStates(num_states); return internal::FstImpl::UpdateFstHeader( fst, strm, opts, file_version, "vector", properties, &hdr, start_offset); } else { if (num_states != hdr.NumStates()) { LOG(ERROR) << "Inconsistent number of states observed during write"; return false; } } return true; } // Specialization for VectorFst; see generic version in fst.h for sample usage // (but use the VectorFst type instead). This version should inline. template class StateIterator> { public: using StateId = typename Arc::StateId; explicit StateIterator(const VectorFst &fst) : nstates_(fst.GetImpl()->NumStates()), s_(0) {} bool Done() const { return s_ >= nstates_; } StateId Value() const { return s_; } void Next() { ++s_; } void Reset() { s_ = 0; } private: const StateId nstates_; StateId s_; }; // Specialization for VectorFst; see generic version in fst.h for sample usage // (but use the VectorFst type instead). This version should inline. template class ArcIterator> { public: using StateId = typename Arc::StateId; ArcIterator(const VectorFst &fst, StateId s) : arcs_(fst.GetImpl()->GetState(s)->Arcs()), narcs_(fst.GetImpl()->GetState(s)->NumArcs()), i_(0) {} bool Done() const { return i_ >= narcs_; } const Arc &Value() const { return arcs_[i_]; } void Next() { ++i_; } void Reset() { i_ = 0; } void Seek(size_t a) { i_ = a; } size_t Position() const { return i_; } constexpr uint32 Flags() const { return kArcValueFlags; } void SetFlags(uint32, uint32) {} private: const Arc *arcs_; size_t narcs_; size_t i_; }; // Specialization for VectorFst; see generic version in mutable-fst.h for sample // usage (but use the VectorFst type instead). This version should inline. template class MutableArcIterator> : public MutableArcIteratorBase { public: using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; MutableArcIterator(VectorFst *fst, StateId s) : i_(0) { fst->MutateCheck(); state_ = fst->GetMutableImpl()->GetState(s); properties_ = &fst->GetImpl()->properties_; } bool Done() const final { return i_ >= state_->NumArcs(); } const Arc &Value() const final { return state_->GetArc(i_); } void Next() final { ++i_; } size_t Position() const final { return i_; } void Reset() final { i_ = 0; } void Seek(size_t a) final { i_ = a; } void SetValue(const Arc &arc) final { const auto &oarc = state_->GetArc(i_); if (oarc.ilabel != oarc.olabel) *properties_ &= ~kNotAcceptor; if (oarc.ilabel == 0) { *properties_ &= ~kIEpsilons; if (oarc.olabel == 0) *properties_ &= ~kEpsilons; } if (oarc.olabel == 0) *properties_ &= ~kOEpsilons; if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One()) { *properties_ &= ~kWeighted; } state_->SetArc(arc, i_); if (arc.ilabel != arc.olabel) { *properties_ |= kNotAcceptor; *properties_ &= ~kAcceptor; } if (arc.ilabel == 0) { *properties_ |= kIEpsilons; *properties_ &= ~kNoIEpsilons; if (arc.olabel == 0) { *properties_ |= kEpsilons; *properties_ &= ~kNoEpsilons; } } if (arc.olabel == 0) { *properties_ |= kOEpsilons; *properties_ &= ~kNoOEpsilons; } if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) { *properties_ |= kWeighted; *properties_ &= ~kUnweighted; } *properties_ &= kSetArcProperties | kAcceptor | kNotAcceptor | kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | kWeighted | kUnweighted; } uint32 Flags() const final { return kArcValueFlags; } void SetFlags(uint32, uint32) final {} private: State *state_; uint64 *properties_; size_t i_; }; // Provides information needed for the generic mutable arc iterator. template inline void VectorFst::InitMutableArcIterator( StateId s, MutableArcIteratorData *data) { data->base = new MutableArcIterator>(this, s); } // A useful alias when using StdArc. using StdVectorFst = VectorFst; } // namespace fst #endif // FST_VECTOR_FST_H_