// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Functions and classes for various FST state queues with a unified interface. #ifndef FST_QUEUE_H_ #define FST_QUEUE_H_ #include #include #include #include #include #include #include #include #include #include namespace fst { // The Queue interface is: // // template // class Queue { // public: // using StateId = S; // // // Constructor: may need args (e.g., FST, comparator) for some queues. // Queue(...) override; // // // Returns the head of the queue. // StateId Head() const override; // // // Inserts a state. // void Enqueue(StateId s) override; // // // Removes the head of the queue. // void Dequeue() override; // // // Updates ordering of state s when weight changes, if necessary. // void Update(StateId s) override; // // // Is the queue empty? // bool Empty() const override; // // // Removes all states from the queue. // void Clear() override; // }; // State queue types. enum QueueType { TRIVIAL_QUEUE = 0, // Single state queue. FIFO_QUEUE = 1, // First-in, first-out queue. LIFO_QUEUE = 2, // Last-in, first-out queue. SHORTEST_FIRST_QUEUE = 3, // Shortest-first queue. TOP_ORDER_QUEUE = 4, // Topologically-ordered queue. STATE_ORDER_QUEUE = 5, // State ID-ordered queue. SCC_QUEUE = 6, // Component graph top-ordered meta-queue. AUTO_QUEUE = 7, // Auto-selected queue. OTHER_QUEUE = 8 }; // QueueBase, templated on the StateId, is a virtual base class shared by all // queues considered by AutoQueue. template class QueueBase { public: using StateId = S; virtual ~QueueBase() {} // Concrete implementation. explicit QueueBase(QueueType type) : queue_type_(type), error_(false) {} void SetError(bool error) { error_ = error; } bool Error() const { return error_; } QueueType Type() const { return queue_type_; } // Virtual interface. virtual StateId Head() const = 0; virtual void Enqueue(StateId) = 0; virtual void Dequeue() = 0; virtual void Update(StateId) = 0; virtual bool Empty() const = 0; virtual void Clear() = 0; private: QueueType queue_type_; bool error_; }; // Trivial queue discipline; one may enqueue at most one state at a time. It // can be used for strongly connected components with only one state and no // self-loops. template class TrivialQueue : public QueueBase { public: using StateId = S; TrivialQueue() : QueueBase(TRIVIAL_QUEUE), front_(kNoStateId) {} virtual ~TrivialQueue() = default; StateId Head() const final { return front_; } void Enqueue(StateId s) final { front_ = s; } void Dequeue() final { front_ = kNoStateId; } void Update(StateId) final {} bool Empty() const final { return front_ == kNoStateId; } void Clear() final { front_ = kNoStateId; } private: StateId front_; }; // First-in, first-out queue discipline. // // This is not a final class. template class FifoQueue : public QueueBase { public: using StateId = S; FifoQueue() : QueueBase(FIFO_QUEUE) {} virtual ~FifoQueue() = default; StateId Head() const override { return queue_.back(); } void Enqueue(StateId s) override { queue_.push_front(s); } void Dequeue() override { queue_.pop_back(); } void Update(StateId) override {} bool Empty() const override { return queue_.empty(); } void Clear() override { queue_.clear(); } private: std::deque queue_; }; // Last-in, first-out queue discipline. template class LifoQueue : public QueueBase { public: using StateId = S; LifoQueue() : QueueBase(LIFO_QUEUE) {} virtual ~LifoQueue() = default; StateId Head() const final { return queue_.front(); } void Enqueue(StateId s) final { queue_.push_front(s); } void Dequeue() final { queue_.pop_front(); } void Update(StateId) final {} bool Empty() const final { return queue_.empty(); } void Clear() final { queue_.clear(); } private: std::deque queue_; }; // Shortest-first queue discipline, templated on the StateId and as well as a // comparison functor used to compare two StateIds. If a (single) state's order // changes, it can be reordered in the queue with a call to Update(). If update // is false, call to Update() does not reorder the queue. // // This is not a final class. template class ShortestFirstQueue : public QueueBase { public: using StateId = S; explicit ShortestFirstQueue(Compare comp) : QueueBase(SHORTEST_FIRST_QUEUE), heap_(comp) {} virtual ~ShortestFirstQueue() = default; StateId Head() const override { return heap_.Top(); } void Enqueue(StateId s) override { if (update) { for (StateId i = key_.size(); i <= s; ++i) key_.push_back(kNoStateId); key_[s] = heap_.Insert(s); } else { heap_.Insert(s); } } void Dequeue() override { if (update) { key_[heap_.Pop()] = kNoStateId; } else { heap_.Pop(); } } void Update(StateId s) override { if (!update) return; if (s >= key_.size() || key_[s] == kNoStateId) { Enqueue(s); } else { heap_.Update(key_[s], s); } } bool Empty() const override { return heap_.Empty(); } void Clear() override { heap_.Clear(); if (update) key_.clear(); } const Compare &GetCompare() const { return heap_.GetCompare(); } private: Heap heap_; std::vector key_; }; namespace internal { // Given a vector that maps from states to weights, and a comparison functor // for weights, this class defines a comparison function object between states. template class StateWeightCompare { public: using Weight = typename Less::Weight; StateWeightCompare(const std::vector &weights, const Less &less) : weights_(weights), less_(less) {} bool operator()(const StateId s1, const StateId s2) const { return less_(weights_[s1], weights_[s2]); } private: // Borrowed references. const std::vector &weights_; const Less &less_; }; } // namespace internal // Shortest-first queue discipline, templated on the StateId and Weight, is // specialized to use the weight's natural order for the comparison function. template class NaturalShortestFirstQueue final : public ShortestFirstQueue< S, internal::StateWeightCompare>> { public: using StateId = S; using Compare = internal::StateWeightCompare>; explicit NaturalShortestFirstQueue(const std::vector &distance) : ShortestFirstQueue(Compare(distance, less_)) {} virtual ~NaturalShortestFirstQueue() = default; private: // This is non-static because the constructor for non-idempotent weights will // result in a an error. const NaturalLess less_{}; }; // Topological-order queue discipline, templated on the StateId. States are // ordered in the queue topologically. The FST must be acyclic. template class TopOrderQueue : public QueueBase { public: using StateId = S; // This constructor computes the topological order. It accepts an arc filter // to limit the transitions considered in that computation (e.g., only the // epsilon graph). template TopOrderQueue(const Fst &fst, ArcFilter filter) : QueueBase(TOP_ORDER_QUEUE), front_(0), back_(kNoStateId), order_(0), state_(0) { bool acyclic; TopOrderVisitor top_order_visitor(&order_, &acyclic); DfsVisit(fst, &top_order_visitor, filter); if (!acyclic) { FSTERROR() << "TopOrderQueue: FST is not acyclic"; QueueBase::SetError(true); } state_.resize(order_.size(), kNoStateId); } // This constructor is passed the pre-computed topological order. explicit TopOrderQueue(const std::vector &order) : QueueBase(TOP_ORDER_QUEUE), front_(0), back_(kNoStateId), order_(order), state_(order.size(), kNoStateId) {} virtual ~TopOrderQueue() = default; StateId Head() const final { return state_[front_]; } void Enqueue(StateId s) final { if (front_ > back_) { front_ = back_ = order_[s]; } else if (order_[s] > back_) { back_ = order_[s]; } else if (order_[s] < front_) { front_ = order_[s]; } state_[order_[s]] = s; } void Dequeue() final { state_[front_] = kNoStateId; while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_; } void Update(StateId) final {} bool Empty() const final { return front_ > back_; } void Clear() final { for (StateId s = front_; s <= back_; ++s) state_[s] = kNoStateId; back_ = kNoStateId; front_ = 0; } private: StateId front_; StateId back_; std::vector order_; std::vector state_; }; // State order queue discipline, templated on the StateId. States are ordered in // the queue by state ID. template class StateOrderQueue : public QueueBase { public: using StateId = S; StateOrderQueue() : QueueBase(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {} virtual ~StateOrderQueue() = default; StateId Head() const final { return front_; } void Enqueue(StateId s) final { if (front_ > back_) { front_ = back_ = s; } else if (s > back_) { back_ = s; } else if (s < front_) { front_ = s; } while (enqueued_.size() <= s) enqueued_.push_back(false); enqueued_[s] = true; } void Dequeue() final { enqueued_[front_] = false; while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_; } void Update(StateId) final {} bool Empty() const final { return front_ > back_; } void Clear() final { for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false; front_ = 0; back_ = kNoStateId; } private: StateId front_; StateId back_; std::vector enqueued_; }; // SCC topological-order meta-queue discipline, templated on the StateId and a // queue used inside each SCC. It visits the SCCs of an FST in topological // order. Its constructor is passed the queues to to use within an SCC. template class SccQueue : public QueueBase { public: using StateId = S; // Constructor takes a vector specifying the SCC number per state and a // vector giving the queue to use per SCC number. SccQueue(const std::vector &scc, std::vector> *queue) : QueueBase(SCC_QUEUE), queue_(queue), scc_(scc), front_(0), back_(kNoStateId) {} virtual ~SccQueue() = default; StateId Head() const final { while ((front_ <= back_) && (((*queue_)[front_] && (*queue_)[front_]->Empty()) || (((*queue_)[front_] == nullptr) && ((front_ >= trivial_queue_.size()) || (trivial_queue_[front_] == kNoStateId))))) { ++front_; } if ((*queue_)[front_]) { return (*queue_)[front_]->Head(); } else { return trivial_queue_[front_]; } } void Enqueue(StateId s) final { if (front_ > back_) { front_ = back_ = scc_[s]; } else if (scc_[s] > back_) { back_ = scc_[s]; } else if (scc_[s] < front_) { front_ = scc_[s]; } if ((*queue_)[scc_[s]]) { (*queue_)[scc_[s]]->Enqueue(s); } else { while (trivial_queue_.size() <= scc_[s]) { trivial_queue_.push_back(kNoStateId); } trivial_queue_[scc_[s]] = s; } } void Dequeue() final { if ((*queue_)[front_]) { (*queue_)[front_]->Dequeue(); } else if (front_ < trivial_queue_.size()) { trivial_queue_[front_] = kNoStateId; } } void Update(StateId s) final { if ((*queue_)[scc_[s]]) (*queue_)[scc_[s]]->Update(s); } bool Empty() const final { // Queues SCC number back_ is not empty unless back_ == front_. if (front_ < back_) { return false; } else if (front_ > back_) { return true; } else if ((*queue_)[front_]) { return (*queue_)[front_]->Empty(); } else { return (front_ >= trivial_queue_.size()) || (trivial_queue_[front_] == kNoStateId); } } void Clear() final { for (StateId i = front_; i <= back_; ++i) { if ((*queue_)[i]) { (*queue_)[i]->Clear(); } else if (i < trivial_queue_.size()) { trivial_queue_[i] = kNoStateId; } } front_ = 0; back_ = kNoStateId; } private: std::vector> *queue_; const std::vector &scc_; mutable StateId front_; StateId back_; std::vector trivial_queue_; }; // Automatic queue discipline. It selects a queue discipline for a given FST // based on its properties. template class AutoQueue : public QueueBase { public: using StateId = S; // This constructor takes a state distance vector that, if non-null and if // the Weight type has the path property, will entertain the shortest-first // queue using the natural order w.r.t to the distance. template AutoQueue(const Fst &fst, const std::vector *distance, ArcFilter filter) : QueueBase(AUTO_QUEUE) { using Weight = typename Arc::Weight; using Less = NaturalLess; using Compare = internal::StateWeightCompare; // First checks if the FST is known to have these properties. const auto props = fst.Properties(kAcyclic | kCyclic | kTopSorted | kUnweighted, false); if ((props & kTopSorted) || fst.Start() == kNoStateId) { queue_.reset(new StateOrderQueue()); VLOG(2) << "AutoQueue: using state-order discipline"; } else if (props & kAcyclic) { queue_.reset(new TopOrderQueue(fst, filter)); VLOG(2) << "AutoQueue: using top-order discipline"; } else if ((props & kUnweighted) && (Weight::Properties() & kIdempotent)) { queue_.reset(new LifoQueue()); VLOG(2) << "AutoQueue: using LIFO discipline"; } else { uint64 properties; // Decomposes into strongly-connected components. SccVisitor scc_visitor(&scc_, nullptr, nullptr, &properties); DfsVisit(fst, &scc_visitor, filter); auto nscc = *std::max_element(scc_.begin(), scc_.end()) + 1; std::vector queue_types(nscc); std::unique_ptr less; std::unique_ptr comp; if (distance && (Weight::Properties() & kPath) == kPath) { less.reset(new Less); comp.reset(new Compare(*distance, *less)); } // Finds the queue type to use per SCC. bool unweighted; bool all_trivial; SccQueueType(fst, scc_, &queue_types, filter, less.get(), &all_trivial, &unweighted); // If unweighted and semiring is idempotent, uses LIFO queue. if (unweighted) { queue_.reset(new LifoQueue()); VLOG(2) << "AutoQueue: using LIFO discipline"; return; } // If all the SCC are trivial, the FST is acyclic and the scc number gives // the topological order. if (all_trivial) { queue_.reset(new TopOrderQueue(scc_)); VLOG(2) << "AutoQueue: using top-order discipline"; return; } VLOG(2) << "AutoQueue: using SCC meta-discipline"; queues_.resize(nscc); for (StateId i = 0; i < nscc; ++i) { switch (queue_types[i]) { case TRIVIAL_QUEUE: queues_[i].reset(); VLOG(3) << "AutoQueue: SCC #" << i << ": using trivial discipline"; break; case SHORTEST_FIRST_QUEUE: queues_[i].reset( new ShortestFirstQueue(*comp)); VLOG(3) << "AutoQueue: SCC #" << i << ": using shortest-first discipline"; break; case LIFO_QUEUE: queues_[i].reset(new LifoQueue()); VLOG(3) << "AutoQueue: SCC #" << i << ": using LIFO discipline"; break; case FIFO_QUEUE: default: queues_[i].reset(new FifoQueue()); VLOG(3) << "AutoQueue: SCC #" << i << ": using FIFO discipine"; break; } } queue_.reset(new SccQueue>(scc_, &queues_)); } } virtual ~AutoQueue() = default; StateId Head() const final { return queue_->Head(); } void Enqueue(StateId s) final { queue_->Enqueue(s); } void Dequeue() final { queue_->Dequeue(); } void Update(StateId s) final { queue_->Update(s); } bool Empty() const final { return queue_->Empty(); } void Clear() final { queue_->Clear(); } private: template static void SccQueueType(const Fst &fst, const std::vector &scc, std::vector *queue_types, ArcFilter filter, Less *less, bool *all_trivial, bool *unweighted); std::unique_ptr> queue_; std::vector>> queues_; std::vector scc_; }; // Examines the states in an FST's strongly connected components and determines // which type of queue to use per SCC. Stores result as a vector of QueueTypes // which is assumed to have length equal to the number of SCCs. An arc filter // is used to limit the transitions considered (e.g., only the epsilon graph). // The argument all_trivial is set to true if every queue is the trivial queue. // The argument unweighted is set to true if the semiring is idempotent and all // the arc weights are equal to Zero() or One(). template template void AutoQueue::SccQueueType(const Fst &fst, const std::vector &scc, std::vector *queue_type, ArcFilter filter, Less *less, bool *all_trivial, bool *unweighted) { using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; *all_trivial = true; *unweighted = true; for (StateId i = 0; i < queue_type->size(); ++i) { (*queue_type)[i] = TRIVIAL_QUEUE; } for (StateIterator> sit(fst); !sit.Done(); sit.Next()) { const auto state = sit.Value(); for (ArcIterator> ait(fst, state); !ait.Done(); ait.Next()) { const auto &arc = ait.Value(); if (!filter(arc)) continue; if (scc[state] == scc[arc.nextstate]) { auto &type = (*queue_type)[scc[state]]; if (!less || ((*less)(arc.weight, Weight::One()))) { type = FIFO_QUEUE; } else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) { if (!(Weight::Properties() & kIdempotent) || (arc.weight != Weight::Zero() && arc.weight != Weight::One())) { type = SHORTEST_FIRST_QUEUE; } else { type = LIFO_QUEUE; } } if (type != TRIVIAL_QUEUE) *all_trivial = false; } if (!(Weight::Properties() & kIdempotent) || (arc.weight != Weight::Zero() && arc.weight != Weight::One())) { *unweighted = false; } } } } // An A* estimate is a function object that maps from a state ID to a an // estimate of the shortest distance to the final states. // A trivial A* estimate, yielding a queue which behaves the same in Dijkstra's // algorithm. template struct TrivialAStarEstimate { const Weight &operator()(StateId) const { return Weight::One(); } }; // A non-trivial A* estimate using a vector of the estimated future costs. template class NaturalAStarEstimate { public: NaturalAStarEstimate(const std::vector &beta) : beta_(beta) {} const Weight &operator()(StateId s) const { return beta_[s]; } private: const std::vector &beta_; }; // Given a vector that maps from states to weights representing the shortest // distance from the initial state, a comparison function object between // weights, and an estimate of the shortest distance to the final states, this // class defines a comparison function object between states. template class AStarWeightCompare { public: using StateId = S; using Weight = typename Less::Weight; AStarWeightCompare(const std::vector &weights, const Less &less, const Estimate &estimate) : weights_(weights), less_(less), estimate_(estimate) {} bool operator()(StateId s1, StateId s2) const { const auto w1 = Times(weights_[s1], estimate_(s1)); const auto w2 = Times(weights_[s2], estimate_(s2)); return less_(w1, w2); } const Estimate &GetEstimate() const { return estimate_; } private: const std::vector &weights_; const Less &less_; const Estimate &estimate_; }; // A* queue discipline templated on StateId, Weight, and Estimate. template class NaturalAStarQueue : public ShortestFirstQueue< S, AStarWeightCompare, Estimate>> { public: using StateId = S; using Compare = AStarWeightCompare, Estimate>; NaturalAStarQueue(const std::vector &distance, const Estimate &estimate) : ShortestFirstQueue( Compare(distance, less_, estimate)) {} ~NaturalAStarQueue() = default; private: // This is non-static because the constructor for non-idempotent weights will // result in a an error. const NaturalLess less_{}; }; // A state equivalence class is a function object that maps from a state ID to // an equivalence class (state) ID. The trivial equivalence class maps a state // ID to itself. template struct TrivialStateEquivClass { StateId operator()(StateId s) const { return s; } }; // Distance-based pruning queue discipline: Enqueues a state only when its // shortest distance (so far), as specified by distance, is less than (as // specified by comp) the shortest distance Times() the threshold to any state // in the same equivalence class, as specified by the functor class_func. The // underlying queue discipline is specified by queue. The ownership of queue is // given to this class. // // This is not a final class. template class PruneQueue : public QueueBase { public: using StateId = typename Queue::StateId; using Weight = typename Less::Weight; PruneQueue(const std::vector &distance, Queue *queue, const Less &less, const ClassFnc &class_fnc, Weight threshold) : QueueBase(OTHER_QUEUE), distance_(distance), queue_(queue), less_(less), class_fnc_(class_fnc), threshold_(std::move(threshold)) {} virtual ~PruneQueue() = default; StateId Head() const override { return queue_->Head(); } void Enqueue(StateId s) override { const auto c = class_fnc_(s); if (c >= class_distance_.size()) { class_distance_.resize(c + 1, Weight::Zero()); } if (less_(distance_[s], class_distance_[c])) { class_distance_[c] = distance_[s]; } // Enqueues only if below threshold limit. const auto limit = Times(class_distance_[c], threshold_); if (less_(distance_[s], limit)) queue_->Enqueue(s); } void Dequeue() override { queue_->Dequeue(); } void Update(StateId s) override { const auto c = class_fnc_(s); if (less_(distance_[s], class_distance_[c])) { class_distance_[c] = distance_[s]; } queue_->Update(s); } bool Empty() const override { return queue_->Empty(); } void Clear() override { queue_->Clear(); } private: const std::vector &distance_; // Shortest distance to state. std::unique_ptr queue_; const Less &less_; // Borrowed reference. const ClassFnc &class_fnc_; // Equivalence class functor. Weight threshold_; // Pruning weight threshold. std::vector class_distance_; // Shortest distance to class. }; // Pruning queue discipline (see above) using the weight's natural order for the // comparison function. The ownership of the queue argument is given to this // class. template class NaturalPruneQueue final : public PruneQueue, ClassFnc> { public: using StateId = typename Queue::StateId; NaturalPruneQueue(const std::vector &distance, Queue *queue, const ClassFnc &class_fnc, Weight threshold) : PruneQueue, ClassFnc>( distance, queue, NaturalLess(), class_fnc, threshold) {} virtual ~NaturalPruneQueue() = default; }; // Filter-based pruning queue discipline: enqueues a state only if allowed by // the filter, specified by the state filter functor argument. The underlying // queue discipline is specified by the queue argument. The ownership of the // queue is given to this class. template class FilterQueue : public QueueBase { public: using StateId = typename Queue::StateId; FilterQueue(Queue *queue, const Filter &filter) : QueueBase(OTHER_QUEUE), queue_(queue), filter_(filter) {} virtual ~FilterQueue() = default; StateId Head() const final { return queue_->Head(); } // Enqueues only if allowed by state filter. void Enqueue(StateId s) final { if (filter_(s)) queue_->Enqueue(s); } void Dequeue() final { queue_->Dequeue(); } void Update(StateId s) final {} bool Empty() const final { return queue_->Empty(); } void Clear() final { queue_->Clear(); } private: std::unique_ptr queue_; const Filter &filter_; }; } // namespace fst #endif // FST_QUEUE_H_