// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Functions and classes for the recursive replacement of FSTs. #ifndef FST_REPLACE_H_ #define FST_REPLACE_H_ #include #include #include #include #include #include #include #include #include // For optional argument declarations. #include #include #include #include #include namespace fst { // Replace state tables have the form: // // template // class ReplaceStateTable { // public: // using Label = typename Arc::Label Label; // using StateId = typename Arc::StateId; // // using PrefixId = P; // using StateTuple = ReplaceStateTuple; // using StackPrefix = ReplaceStackPrefix; // // // Required constructor. // ReplaceStateTable( // const std::vector *>> &fst_list, // Label root); // // // Required copy constructor that does not copy state. // ReplaceStateTable(const ReplaceStateTable &table); // // // Looks up state ID by tuple, adding it if it doesn't exist. // StateId FindState(const StateTuple &tuple); // // // Looks up state tuple by ID. // const StateTuple &Tuple(StateId id) const; // // // Lookus up prefix ID by stack prefix, adding it if it doesn't exist. // PrefixId FindPrefixId(const StackPrefix &stack_prefix); // // // Looks up stack prefix by ID. // const StackPrefix &GetStackPrefix(PrefixId id) const; // }; // Tuple that uniquely defines a state in replace. template struct ReplaceStateTuple { using StateId = S; using PrefixId = P; ReplaceStateTuple(PrefixId prefix_id = -1, StateId fst_id = kNoStateId, StateId fst_state = kNoStateId) : prefix_id(prefix_id), fst_id(fst_id), fst_state(fst_state) {} PrefixId prefix_id; // Index in prefix table. StateId fst_id; // Current FST being walked. StateId fst_state; // Current state in FST being walked (not to be // confused with the thse StateId of the combined FST). }; // Equality of replace state tuples. template inline bool operator==(const ReplaceStateTuple &x, const ReplaceStateTuple &y) { return x.prefix_id == y.prefix_id && x.fst_id == y.fst_id && x.fst_state == y.fst_state; } // Functor returning true for tuples corresponding to states in the root FST. template class ReplaceRootSelector { public: bool operator()(const ReplaceStateTuple &tuple) const { return tuple.prefix_id == 0; } }; // Functor for fingerprinting replace state tuples. template class ReplaceFingerprint { public: explicit ReplaceFingerprint(const std::vector *size_array) : size_array_(size_array) {} uint64 operator()(const ReplaceStateTuple &tuple) const { return tuple.prefix_id * size_array_->back() + size_array_->at(tuple.fst_id - 1) + tuple.fst_state; } private: const std::vector *size_array_; }; // Useful when the fst_state uniquely define the tuple. template class ReplaceFstStateFingerprint { public: uint64 operator()(const ReplaceStateTuple &tuple) const { return tuple.fst_state; } }; // A generic hash function for replace state tuples. template class ReplaceHash { public: size_t operator()(const ReplaceStateTuple& t) const { static constexpr size_t prime0 = 7853; static constexpr size_t prime1 = 7867; return t.prefix_id + t.fst_id * prime0 + t.fst_state * prime1; } }; // Container for stack prefix. template class ReplaceStackPrefix { public: struct PrefixTuple { PrefixTuple(Label fst_id = kNoLabel, StateId nextstate = kNoStateId) : fst_id(fst_id), nextstate(nextstate) {} Label fst_id; StateId nextstate; }; ReplaceStackPrefix() {} ReplaceStackPrefix(const ReplaceStackPrefix &other) : prefix_(other.prefix_) {} void Push(StateId fst_id, StateId nextstate) { prefix_.push_back(PrefixTuple(fst_id, nextstate)); } void Pop() { prefix_.pop_back(); } const PrefixTuple &Top() const { return prefix_[prefix_.size() - 1]; } size_t Depth() const { return prefix_.size(); } public: std::vector prefix_; }; // Equality stack prefix classes. template inline bool operator==(const ReplaceStackPrefix &x, const ReplaceStackPrefix &y) { if (x.prefix_.size() != y.prefix_.size()) return false; for (size_t i = 0; i < x.prefix_.size(); ++i) { if (x.prefix_[i].fst_id != y.prefix_[i].fst_id || x.prefix_[i].nextstate != y.prefix_[i].nextstate) { return false; } } return true; } // Hash function for stack prefix to prefix id. template class ReplaceStackPrefixHash { public: size_t operator()(const ReplaceStackPrefix &prefix) const { size_t sum = 0; for (const auto &pair : prefix.prefix_) { static constexpr size_t prime = 7863; sum += pair.fst_id + pair.nextstate * prime; } return sum; } }; // Replace state tables. // A two-level state table for replace. Warning: calls CountStates to compute // the number of states of each component FST. template class VectorHashReplaceStateTable { public: using Label = typename Arc::Label; using StateId = typename Arc::StateId; using PrefixId = P; using StateTuple = ReplaceStateTuple; using StateTable = VectorHashStateTable, ReplaceRootSelector, ReplaceFstStateFingerprint, ReplaceFingerprint>; using StackPrefix = ReplaceStackPrefix; using StackPrefixTable = CompactHashBiTable>; VectorHashReplaceStateTable( const std::vector *>> &fst_list, Label root) : root_size_(0) { size_array_.push_back(0); for (const auto &fst_pair : fst_list) { if (fst_pair.first == root) { root_size_ = CountStates(*(fst_pair.second)); size_array_.push_back(size_array_.back()); } else { size_array_.push_back(size_array_.back() + CountStates(*(fst_pair.second))); } } state_table_.reset( new StateTable(new ReplaceRootSelector, new ReplaceFstStateFingerprint, new ReplaceFingerprint(&size_array_), root_size_, root_size_ + size_array_.back())); } VectorHashReplaceStateTable( const VectorHashReplaceStateTable &table) : root_size_(table.root_size_), size_array_(table.size_array_), prefix_table_(table.prefix_table_) { state_table_.reset( new StateTable(new ReplaceRootSelector, new ReplaceFstStateFingerprint, new ReplaceFingerprint(&size_array_), root_size_, root_size_ + size_array_.back())); } StateId FindState(const StateTuple &tuple) { return state_table_->FindState(tuple); } const StateTuple &Tuple(StateId id) const { return state_table_->Tuple(id); } PrefixId FindPrefixId(const StackPrefix &prefix) { return prefix_table_.FindId(prefix); } const StackPrefix& GetStackPrefix(PrefixId id) const { return prefix_table_.FindEntry(id); } private: StateId root_size_; std::vector size_array_; std::unique_ptr state_table_; StackPrefixTable prefix_table_; }; // Default replace state table. template class DefaultReplaceStateTable : public CompactHashStateTable, ReplaceHash> { public: using Label = typename Arc::Label; using StateId = typename Arc::StateId; using PrefixId = P; using StateTuple = ReplaceStateTuple; using StateTable = CompactHashStateTable>; using StackPrefix = ReplaceStackPrefix; using StackPrefixTable = CompactHashBiTable>; using StateTable::FindState; using StateTable::Tuple; DefaultReplaceStateTable( const std::vector *>> &, Label) {} DefaultReplaceStateTable(const DefaultReplaceStateTable &table) : StateTable(), prefix_table_(table.prefix_table_) {} PrefixId FindPrefixId(const StackPrefix &prefix) { return prefix_table_.FindId(prefix); } const StackPrefix &GetStackPrefix(PrefixId id) const { return prefix_table_.FindEntry(id); } private: StackPrefixTable prefix_table_; }; // By default ReplaceFst will copy the input label of the replace arc. // The call_label_type and return_label_type options specify how to manage // the labels of the call arc and the return arc of the replace FST template , class CacheStore = DefaultCacheStore> struct ReplaceFstOptions : CacheImplOptions { using Label = typename Arc::Label; // Index of root rule for expansion. Label root; // How to label call arc. ReplaceLabelType call_label_type = REPLACE_LABEL_INPUT; // How to label return arc. ReplaceLabelType return_label_type = REPLACE_LABEL_NEITHER; // Specifies output label to put on call arc; if kNoLabel, use existing label // on call arc. Otherwise, use this field as the output label. Label call_output_label = kNoLabel; // Specifies label to put on return arc. Label return_label = 0; // Take ownership of input FSTs? bool take_ownership = false; // Pointer to optional pre-constructed state table. StateTable *state_table = nullptr; explicit ReplaceFstOptions(const CacheImplOptions &opts, Label root = kNoLabel) : CacheImplOptions(opts), root(root) {} explicit ReplaceFstOptions(const CacheOptions &opts, Label root = kNoLabel) : CacheImplOptions(opts), root(root) {} // FIXME(kbg): There are too many constructors here. Come up with a consistent // position for call_output_label (probably the very end) so that it is // possible to express all the remaining constructors with a single // default-argument constructor. Also move clients off of the "backwards // compatibility" constructor, for good. explicit ReplaceFstOptions(Label root) : root(root) {} explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, Label return_label) : root(root), call_label_type(call_label_type), return_label_type(return_label_type), return_label(return_label) {} explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, Label call_output_label, Label return_label) : root(root), call_label_type(call_label_type), return_label_type(return_label_type), call_output_label(call_output_label), return_label(return_label) {} explicit ReplaceFstOptions(const ReplaceUtilOptions &opts) : ReplaceFstOptions(opts.root, opts.call_label_type, opts.return_label_type, opts.return_label) {} ReplaceFstOptions() : root(kNoLabel) {} // For backwards compatibility. ReplaceFstOptions(int64 root, bool epsilon_replace_arc) : root(root), call_label_type(epsilon_replace_arc ? REPLACE_LABEL_NEITHER : REPLACE_LABEL_INPUT), call_output_label(epsilon_replace_arc ? 0 : kNoLabel) {} }; // Forward declaration. template class ReplaceFstMatcher; template using FstList = std::vector *>>; // Returns true if label type on arc results in epsilon input label. inline bool EpsilonOnInput(ReplaceLabelType label_type) { return label_type == REPLACE_LABEL_NEITHER || label_type == REPLACE_LABEL_OUTPUT; } // Returns true if label type on arc results in epsilon input label. inline bool EpsilonOnOutput(ReplaceLabelType label_type) { return label_type == REPLACE_LABEL_NEITHER || label_type == REPLACE_LABEL_INPUT; } // Returns true if for either the call or return arc ilabel != olabel. template bool ReplaceTransducer(ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, Label call_output_label) { return call_label_type == REPLACE_LABEL_INPUT || call_label_type == REPLACE_LABEL_OUTPUT || (call_label_type == REPLACE_LABEL_BOTH && call_output_label != kNoLabel) || return_label_type == REPLACE_LABEL_INPUT || return_label_type == REPLACE_LABEL_OUTPUT; } template uint64 ReplaceFstProperties(typename Arc::Label root_label, const FstList &fst_list, ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, typename Arc::Label call_output_label, bool *sorted_and_non_empty) { using Label = typename Arc::Label; std::vector inprops; bool all_ilabel_sorted = true; bool all_olabel_sorted = true; bool all_non_empty = true; // All nonterminals are negative? bool all_negative = true; // All nonterminals are positive and form a dense range containing 1? bool dense_range = true; Label root_fst_idx = 0; for (Label i = 0; i < fst_list.size(); ++i) { const auto label = fst_list[i].first; if (label >= 0) all_negative = false; if (label > fst_list.size() || label <= 0) dense_range = false; if (label == root_label) root_fst_idx = i; const auto *fst = fst_list[i].second; if (fst->Start() == kNoStateId) all_non_empty = false; if (!fst->Properties(kILabelSorted, false)) all_ilabel_sorted = false; if (!fst->Properties(kOLabelSorted, false)) all_olabel_sorted = false; inprops.push_back(fst->Properties(kCopyProperties, false)); } const auto props = ReplaceProperties( inprops, root_fst_idx, EpsilonOnInput(call_label_type), EpsilonOnInput(return_label_type), EpsilonOnOutput(call_label_type), EpsilonOnOutput(return_label_type), ReplaceTransducer(call_label_type, return_label_type, call_output_label), all_non_empty, all_ilabel_sorted, all_olabel_sorted, all_negative || dense_range); const bool sorted = props & (kILabelSorted | kOLabelSorted); *sorted_and_non_empty = all_non_empty && sorted; return props; } namespace internal { // The replace implementation class supports a dynamic expansion of a recursive // transition network represented as label/FST pairs with dynamic replacable // arcs. template class ReplaceFstImpl : public CacheBaseImpl { public: using Label = typename Arc::Label; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; using State = typename CacheStore::State; using CacheImpl = CacheBaseImpl; using PrefixId = typename StateTable::PrefixId; using StateTuple = ReplaceStateTuple; using StackPrefix = ReplaceStackPrefix; using NonTerminalHash = std::unordered_map; using FstImpl::SetType; using FstImpl::SetProperties; using FstImpl::WriteHeader; using FstImpl::SetInputSymbols; using FstImpl::SetOutputSymbols; using FstImpl::InputSymbols; using FstImpl::OutputSymbols; using CacheImpl::PushArc; using CacheImpl::HasArcs; using CacheImpl::HasFinal; using CacheImpl::HasStart; using CacheImpl::SetArcs; using CacheImpl::SetFinal; using CacheImpl::SetStart; friend class ReplaceFstMatcher; ReplaceFstImpl(const FstList &fst_list, const ReplaceFstOptions &opts) : CacheImpl(opts), call_label_type_(opts.call_label_type), return_label_type_(opts.return_label_type), call_output_label_(opts.call_output_label), return_label_(opts.return_label), state_table_(opts.state_table ? opts.state_table : new StateTable(fst_list, opts.root)) { SetType("replace"); // If the label is epsilon, then all replace label options are equivalent, // so we set the label types to NEITHER for simplicity. if (call_output_label_ == 0) call_label_type_ = REPLACE_LABEL_NEITHER; if (return_label_ == 0) return_label_type_ = REPLACE_LABEL_NEITHER; if (!fst_list.empty()) { SetInputSymbols(fst_list[0].second->InputSymbols()); SetOutputSymbols(fst_list[0].second->OutputSymbols()); } fst_array_.push_back(nullptr); for (Label i = 0; i < fst_list.size(); ++i) { const auto label = fst_list[i].first; const auto *fst = fst_list[i].second; nonterminal_hash_[label] = fst_array_.size(); nonterminal_set_.insert(label); fst_array_.emplace_back(opts.take_ownership ? fst : fst->Copy()); if (i) { if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) { FSTERROR() << "ReplaceFstImpl: Input symbols of FST " << i << " do not match input symbols of base FST (0th FST)"; SetProperties(kError, kError); } if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) { FSTERROR() << "ReplaceFstImpl: Output symbols of FST " << i << " do not match output symbols of base FST (0th FST)"; SetProperties(kError, kError); } } } const auto nonterminal = nonterminal_hash_[opts.root]; if ((nonterminal == 0) && (fst_array_.size() > 1)) { FSTERROR() << "ReplaceFstImpl: No FST corresponding to root label " << opts.root << " in the input tuple vector"; SetProperties(kError, kError); } root_ = (nonterminal > 0) ? nonterminal : 1; bool all_non_empty_and_sorted = false; SetProperties(ReplaceFstProperties(opts.root, fst_list, call_label_type_, return_label_type_, call_output_label_, &all_non_empty_and_sorted)); // Enables optional caching as long as sorted and all non-empty. always_cache_ = !all_non_empty_and_sorted; VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = " << (always_cache_ ? "true" : "false"); } ReplaceFstImpl(const ReplaceFstImpl &impl) : CacheImpl(impl), call_label_type_(impl.call_label_type_), return_label_type_(impl.return_label_type_), call_output_label_(impl.call_output_label_), return_label_(impl.return_label_), always_cache_(impl.always_cache_), state_table_(new StateTable(*(impl.state_table_))), nonterminal_set_(impl.nonterminal_set_), nonterminal_hash_(impl.nonterminal_hash_), root_(impl.root_) { SetType("replace"); SetProperties(impl.Properties(), kCopyProperties); SetInputSymbols(impl.InputSymbols()); SetOutputSymbols(impl.OutputSymbols()); fst_array_.reserve(impl.fst_array_.size()); fst_array_.emplace_back(nullptr); for (Label i = 1; i < impl.fst_array_.size(); ++i) { fst_array_.emplace_back(impl.fst_array_[i]->Copy(true)); } } // Computes the dependency graph of the replace class and returns // true if the dependencies are cyclic. Cyclic dependencies will result // in an un-expandable FST. bool CyclicDependencies() const { const ReplaceUtilOptions opts(root_); ReplaceUtil replace_util(fst_array_, nonterminal_hash_, opts); return replace_util.CyclicDependencies(); } StateId Start() { if (!HasStart()) { if (fst_array_.size() == 1) { SetStart(kNoStateId); return kNoStateId; } else { const auto fst_start = fst_array_[root_]->Start(); if (fst_start == kNoStateId) return kNoStateId; const auto prefix = GetPrefixId(StackPrefix()); const auto start = state_table_->FindState(StateTuple(prefix, root_, fst_start)); SetStart(start); return start; } } else { return CacheImpl::Start(); } } Weight Final(StateId s) { if (HasFinal(s)) return CacheImpl::Final(s); const auto &tuple = state_table_->Tuple(s); auto weight = Weight::Zero(); if (tuple.prefix_id == 0) { const auto fst_state = tuple.fst_state; weight = fst_array_[tuple.fst_id]->Final(fst_state); } if (always_cache_ || HasArcs(s)) SetFinal(s, weight); return weight; } size_t NumArcs(StateId s) { if (HasArcs(s)) { return CacheImpl::NumArcs(s); } else if (always_cache_) { // If always caching, expands and caches state. Expand(s); return CacheImpl::NumArcs(s); } else { // Otherwise computes the number of arcs without expanding. const auto tuple = state_table_->Tuple(s); if (tuple.fst_state == kNoStateId) return 0; auto num_arcs = fst_array_[tuple.fst_id]->NumArcs(tuple.fst_state); if (ComputeFinalArc(tuple, nullptr)) ++num_arcs; return num_arcs; } } // Returns whether a given label is a non-terminal. bool IsNonTerminal(Label label) const { if (label < *nonterminal_set_.begin() || label > *nonterminal_set_.rbegin()) { return false; } else { return nonterminal_hash_.count(label); } // TODO(allauzen): be smarter and take advantage of all_dense or // all_negative. Also use this in ComputeArc. This would require changes to // Replace so that recursing into an empty FST lead to a non co-accessible // state instead of deleting the arc as done currently. The current use // correct, since labels are sorted if all_non_empty is true. } size_t NumInputEpsilons(StateId s) { if (HasArcs(s)) { return CacheImpl::NumInputEpsilons(s); } else if (always_cache_ || !Properties(kILabelSorted)) { // If always caching or if the number of input epsilons is too expensive // to compute without caching (i.e., not ilabel-sorted), then expands and // caches state. Expand(s); return CacheImpl::NumInputEpsilons(s); } else { // Otherwise, computes the number of input epsilons without caching. const auto tuple = state_table_->Tuple(s); if (tuple.fst_state == kNoStateId) return 0; size_t num = 0; if (!EpsilonOnInput(call_label_type_)) { // If EpsilonOnInput(c) is false, all input epsilon arcs // are also input epsilons arcs in the underlying machine. num = fst_array_[tuple.fst_id]->NumInputEpsilons(tuple.fst_state); } else { // Otherwise, one need to consider that all non-terminal arcs // in the underlying machine also become input epsilon arc. ArcIterator> aiter(*fst_array_[tuple.fst_id], tuple.fst_state); for (; !aiter.Done() && ((aiter.Value().ilabel == 0) || IsNonTerminal(aiter.Value().olabel)); aiter.Next()) { ++num; } } if (EpsilonOnInput(return_label_type_) && ComputeFinalArc(tuple, nullptr)) { ++num; } return num; } } size_t NumOutputEpsilons(StateId s) { if (HasArcs(s)) { return CacheImpl::NumOutputEpsilons(s); } else if (always_cache_ || !Properties(kOLabelSorted)) { // If always caching or if the number of output epsilons is too expensive // to compute without caching (i.e., not olabel-sorted), then expands and // caches state. Expand(s); return CacheImpl::NumOutputEpsilons(s); } else { // Otherwise, computes the number of output epsilons without caching. const auto tuple = state_table_->Tuple(s); if (tuple.fst_state == kNoStateId) return 0; size_t num = 0; if (!EpsilonOnOutput(call_label_type_)) { // If EpsilonOnOutput(c) is false, all output epsilon arcs are also // output epsilons arcs in the underlying machine. num = fst_array_[tuple.fst_id]->NumOutputEpsilons(tuple.fst_state); } else { // Otherwise, one need to consider that all non-terminal arcs in the // underlying machine also become output epsilon arc. ArcIterator> aiter(*fst_array_[tuple.fst_id], tuple.fst_state); for (; !aiter.Done() && ((aiter.Value().olabel == 0) || IsNonTerminal(aiter.Value().olabel)); aiter.Next()) { ++num; } } if (EpsilonOnOutput(return_label_type_) && ComputeFinalArc(tuple, nullptr)) { ++num; } return num; } } uint64 Properties() const override { return Properties(kFstProperties); } // Sets error if found, and returns other FST impl properties. uint64 Properties(uint64 mask) const override { if (mask & kError) { for (Label i = 1; i < fst_array_.size(); ++i) { if (fst_array_[i]->Properties(kError, false)) { SetProperties(kError, kError); } } } return FstImpl::Properties(mask); } // Returns the base arc iterator, and if arcs have not been computed yet, // extends and recurses for new arcs. void InitArcIterator(StateId s, ArcIteratorData *data) { if (!HasArcs(s)) Expand(s); CacheImpl::InitArcIterator(s, data); // TODO(allauzen): Set behaviour of generic iterator. // Warning: ArcIterator>::InitCache() relies on current // behaviour. } // Extends current state (walk arcs one level deep). void Expand(StateId s) { const auto tuple = state_table_->Tuple(s); if (tuple.fst_state == kNoStateId) { // Local FST is empty. SetArcs(s); return; } ArcIterator> aiter(*fst_array_[tuple.fst_id], tuple.fst_state); Arc arc; // Creates a final arc when needed. if (ComputeFinalArc(tuple, &arc)) PushArc(s, arc); // Expands all arcs leaving the state. for (; !aiter.Done(); aiter.Next()) { if (ComputeArc(tuple, aiter.Value(), &arc)) PushArc(s, arc); } SetArcs(s); } void Expand(StateId s, const StateTuple &tuple, const ArcIteratorData &data) { if (tuple.fst_state == kNoStateId) { // Local FST is empty. SetArcs(s); return; } ArcIterator> aiter(data); Arc arc; // Creates a final arc when needed. if (ComputeFinalArc(tuple, &arc)) AddArc(s, arc); // Expands all arcs leaving the state. for (; !aiter.Done(); aiter.Next()) { if (ComputeArc(tuple, aiter.Value(), &arc)) AddArc(s, arc); } SetArcs(s); } // If acpp is null, only returns true if a final arcp is required, but does // not actually compute it. bool ComputeFinalArc(const StateTuple &tuple, Arc *arcp, uint32 flags = kArcValueFlags) { const auto fst_state = tuple.fst_state; if (fst_state == kNoStateId) return false; // If state is final, pops the stack. if (fst_array_[tuple.fst_id]->Final(fst_state) != Weight::Zero() && tuple.prefix_id) { if (arcp) { arcp->ilabel = (EpsilonOnInput(return_label_type_)) ? 0 : return_label_; arcp->olabel = (EpsilonOnOutput(return_label_type_)) ? 0 : return_label_; if (flags & kArcNextStateValue) { const auto &stack = state_table_->GetStackPrefix(tuple.prefix_id); const auto prefix_id = PopPrefix(stack); const auto &top = stack.Top(); arcp->nextstate = state_table_->FindState( StateTuple(prefix_id, top.fst_id, top.nextstate)); } if (flags & kArcWeightValue) { arcp->weight = fst_array_[tuple.fst_id]->Final(fst_state); } } return true; } else { return false; } } // Computes an arc in the FST corresponding to one in the underlying machine. // Returns false if the underlying arc corresponds to no arc in the resulting // FST. bool ComputeArc(const StateTuple &tuple, const Arc &arc, Arc *arcp, uint32 flags = kArcValueFlags) { if (!EpsilonOnInput(call_label_type_) && (flags == (flags & (kArcILabelValue | kArcWeightValue)))) { *arcp = arc; return true; } if (arc.olabel == 0 || arc.olabel < *nonterminal_set_.begin() || arc.olabel > *nonterminal_set_.rbegin()) { // Expands local FST. const auto nextstate = flags & kArcNextStateValue ? state_table_->FindState( StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate)) : kNoStateId; *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate); } else { // Checks for non-terminal. const auto it = nonterminal_hash_.find(arc.olabel); if (it != nonterminal_hash_.end()) { // Recurses into non-terminal. const auto nonterminal = it->second; const auto nt_prefix = PushPrefix(state_table_->GetStackPrefix(tuple.prefix_id), tuple.fst_id, arc.nextstate); // If the start state is valid, replace; othewise, the arc is implicitly // deleted. const auto nt_start = fst_array_[nonterminal]->Start(); if (nt_start != kNoStateId) { const auto nt_nextstate = flags & kArcNextStateValue ? state_table_->FindState(StateTuple( nt_prefix, nonterminal, nt_start)) : kNoStateId; const auto ilabel = (EpsilonOnInput(call_label_type_)) ? 0 : arc.ilabel; const auto olabel = (EpsilonOnOutput(call_label_type_)) ? 0 : ((call_output_label_ == kNoLabel) ? arc.olabel : call_output_label_); *arcp = Arc(ilabel, olabel, arc.weight, nt_nextstate); } else { return false; } } else { const auto nextstate = flags & kArcNextStateValue ? state_table_->FindState( StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate)) : kNoStateId; *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate); } } return true; } // Returns the arc iterator flags supported by this FST. uint32 ArcIteratorFlags() const { uint32 flags = kArcValueFlags; if (!always_cache_) flags |= kArcNoCache; return flags; } StateTable *GetStateTable() const { return state_table_.get(); } const Fst *GetFst(Label fst_id) const { return fst_array_[fst_id].get(); } Label GetFstId(Label nonterminal) const { const auto it = nonterminal_hash_.find(nonterminal); if (it == nonterminal_hash_.end()) { FSTERROR() << "ReplaceFstImpl::GetFstId: Nonterminal not found: " << nonterminal; } return it->second; } // Returns true if label type on call arc results in epsilon input label. bool EpsilonOnCallInput() { return EpsilonOnInput(call_label_type_); } private: // The unique index into stack prefix table. PrefixId GetPrefixId(const StackPrefix &prefix) { return state_table_->FindPrefixId(prefix); } // The prefix ID after a stack pop. PrefixId PopPrefix(StackPrefix prefix) { prefix.Pop(); return GetPrefixId(prefix); } // The prefix ID after a stack push. PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) { prefix.Push(fst_id, nextstate); return GetPrefixId(prefix); } // Runtime options ReplaceLabelType call_label_type_; // How to label call arc. ReplaceLabelType return_label_type_; // How to label return arc. int64 call_output_label_; // Specifies output label to put on call arc int64 return_label_; // Specifies label to put on return arc. bool always_cache_; // Disable optional caching of arc iterator? // State table. std::unique_ptr state_table_; // Replace components. std::set