// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Functions implementing pruning. #ifndef FST_PRUNE_H_ #define FST_PRUNE_H_ #include #include #include #include #include #include #include namespace fst { namespace internal { template class PruneCompare { public: PruneCompare(const std::vector &idistance, const std::vector &fdistance) : idistance_(idistance), fdistance_(fdistance) {} bool operator()(const StateId x, const StateId y) const { const auto wx = Times(IDistance(x), FDistance(x)); const auto wy = Times(IDistance(y), FDistance(y)); return less_(wx, wy); } private: Weight IDistance(const StateId s) const { return s < idistance_.size() ? idistance_[s] : Weight::Zero(); } Weight FDistance(const StateId s) const { return s < fdistance_.size() ? fdistance_[s] : Weight::Zero(); } const std::vector &idistance_; const std::vector &fdistance_; NaturalLess less_; }; } // namespace internal template struct PruneOptions { using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; PruneOptions(const Weight &weight_threshold, StateId state_threshold, ArcFilter filter, std::vector *distance = nullptr, float delta = kDelta, bool threshold_initial = false) : weight_threshold(std::move(weight_threshold)), state_threshold(state_threshold), filter(std::move(filter)), distance(distance), delta(delta), threshold_initial(threshold_initial) {} // Pruning weight threshold. Weight weight_threshold; // Pruning state threshold. StateId state_threshold; // Arc filter. ArcFilter filter; // If non-zero, passes in pre-computed shortest distance to final states. const std::vector *distance; // Determines the degree of convergence required when computing shortest // distances. float delta; // Determines if the shortest path weight is left (true) or right // (false) multiplied by the threshold to get the limit for // keeping a state or arc (matters if the semiring is not // commutative). bool threshold_initial; }; // Pruning algorithm: this version modifies its input and it takes an options // class as an argument. After pruning the FST contains states and arcs that // belong to a successful path in the FST whose weight is no more than the // weight of the shortest path Times() the provided weight threshold. When the // state threshold is not kNoStateId, the output FST is further restricted to // have no more than the number of states in opts.state_threshold. Weights must // have the path property. The weight of any cycle needs to be bounded; i.e., // // Plus(weight, Weight::One()) == Weight::One() template ::type * = nullptr> void Prune(MutableFst *fst, const PruneOptions &opts) { using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; using StateHeap = Heap>; auto ns = fst->NumStates(); if (ns < 1) return; std::vector idistance(ns, Weight::Zero()); std::vector tmp; if (!opts.distance) { tmp.reserve(ns); ShortestDistance(*fst, &tmp, true, opts.delta); } const auto *fdistance = opts.distance ? opts.distance : &tmp; if ((opts.state_threshold == 0) || (fdistance->size() <= fst->Start()) || ((*fdistance)[fst->Start()] == Weight::Zero())) { fst->DeleteStates(); return; } internal::PruneCompare compare(idistance, *fdistance); StateHeap heap(compare); std::vector visited(ns, false); std::vector enqueued(ns, StateHeap::kNoKey); std::vector dead; dead.push_back(fst->AddState()); NaturalLess less; auto s = fst->Start(); const auto limit = opts.threshold_initial ? Times(opts.weight_threshold, (*fdistance)[s]) : Times((*fdistance)[s], opts.weight_threshold); StateId num_visited = 0; if (!less(limit, (*fdistance)[s])) { idistance[s] = Weight::One(); enqueued[s] = heap.Insert(s); ++num_visited; } while (!heap.Empty()) { s = heap.Top(); heap.Pop(); enqueued[s] = StateHeap::kNoKey; visited[s] = true; if (less(limit, Times(idistance[s], fst->Final(s)))) { fst->SetFinal(s, Weight::Zero()); } for (MutableArcIterator> aiter(fst, s); !aiter.Done(); aiter.Next()) { auto arc = aiter.Value(); // Copy intended. if (!opts.filter(arc)) continue; const auto weight = Times(Times(idistance[s], arc.weight), arc.nextstate < fdistance->size() ? (*fdistance)[arc.nextstate] : Weight::Zero()); if (less(limit, weight)) { arc.nextstate = dead[0]; aiter.SetValue(arc); continue; } if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) { idistance[arc.nextstate] = Times(idistance[s], arc.weight); } if (visited[arc.nextstate]) continue; if ((opts.state_threshold != kNoStateId) && (num_visited >= opts.state_threshold)) { continue; } if (enqueued[arc.nextstate] == StateHeap::kNoKey) { enqueued[arc.nextstate] = heap.Insert(arc.nextstate); ++num_visited; } else { heap.Update(enqueued[arc.nextstate], arc.nextstate); } } } for (StateId i = 0; i < visited.size(); ++i) { if (!visited[i]) dead.push_back(i); } fst->DeleteStates(dead); } template ::type * = nullptr> void Prune(MutableFst *fst, const PruneOptions &) { FSTERROR() << "Prune: Weight needs to have the path property: " << Arc::Weight::Type(); fst->SetProperties(kError, kError); } // Pruning algorithm: this version modifies its input and takes the // pruning threshold as an argument. It deletes states and arcs in the // FST that do not belong to a successful path whose weight is more // than the weight of the shortest path Times() the provided weight // threshold. When the state threshold is not kNoStateId, the output // FST is further restricted to have no more than the number of states // in opts.state_threshold. Weights must have the path property. The // weight of any cycle needs to be bounded; i.e., // // Plus(weight, Weight::One()) == Weight::One() template void Prune(MutableFst *fst, typename Arc::Weight weight_threshold, typename Arc::StateId state_threshold = kNoStateId, float delta = kDelta) { const PruneOptions> opts( weight_threshold, state_threshold, AnyArcFilter(), nullptr, delta); Prune(fst, opts); } // Pruning algorithm: this version writes the pruned input FST to an // output MutableFst and it takes an options class as an argument. The // output FST contains states and arcs that belong to a successful // path in the input FST whose weight is more than the weight of the // shortest path Times() the provided weight threshold. When the state // threshold is not kNoStateId, the output FST is further restricted // to have no more than the number of states in // opts.state_threshold. Weights have the path property. The weight // of any cycle needs to be bounded; i.e., // // Plus(weight, Weight::One()) == Weight::One() template ::value>::type * = nullptr> void Prune(const Fst &ifst, MutableFst *ofst, const PruneOptions &opts) { using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; using StateHeap = Heap>; ofst->DeleteStates(); ofst->SetInputSymbols(ifst.InputSymbols()); ofst->SetOutputSymbols(ifst.OutputSymbols()); if (ifst.Start() == kNoStateId) return; NaturalLess less; if (less(opts.weight_threshold, Weight::One()) || (opts.state_threshold == 0)) { return; } std::vector idistance; std::vector tmp; if (!opts.distance) ShortestDistance(ifst, &tmp, true, opts.delta); const auto *fdistance = opts.distance ? opts.distance : &tmp; if ((fdistance->size() <= ifst.Start()) || ((*fdistance)[ifst.Start()] == Weight::Zero())) { return; } internal::PruneCompare compare(idistance, *fdistance); StateHeap heap(compare); std::vector copy; std::vector enqueued; std::vector visited; auto s = ifst.Start(); const auto limit = opts.threshold_initial ? Times(opts.weight_threshold, (*fdistance)[s]) : Times((*fdistance)[s], opts.weight_threshold); while (copy.size() <= s) copy.push_back(kNoStateId); copy[s] = ofst->AddState(); ofst->SetStart(copy[s]); while (idistance.size() <= s) idistance.push_back(Weight::Zero()); idistance[s] = Weight::One(); while (enqueued.size() <= s) { enqueued.push_back(StateHeap::kNoKey); visited.push_back(false); } enqueued[s] = heap.Insert(s); while (!heap.Empty()) { s = heap.Top(); heap.Pop(); enqueued[s] = StateHeap::kNoKey; visited[s] = true; if (!less(limit, Times(idistance[s], ifst.Final(s)))) { ofst->SetFinal(copy[s], ifst.Final(s)); } for (ArcIterator> aiter(ifst, s); !aiter.Done(); aiter.Next()) { const auto &arc = aiter.Value(); if (!opts.filter(arc)) continue; const auto weight = Times(Times(idistance[s], arc.weight), arc.nextstate < fdistance->size() ? (*fdistance)[arc.nextstate] : Weight::Zero()); if (less(limit, weight)) continue; if ((opts.state_threshold != kNoStateId) && (ofst->NumStates() >= opts.state_threshold)) { continue; } while (idistance.size() <= arc.nextstate) { idistance.push_back(Weight::Zero()); } if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) { idistance[arc.nextstate] = Times(idistance[s], arc.weight); } while (copy.size() <= arc.nextstate) copy.push_back(kNoStateId); if (copy[arc.nextstate] == kNoStateId) { copy[arc.nextstate] = ofst->AddState(); } ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight, copy[arc.nextstate])); while (enqueued.size() <= arc.nextstate) { enqueued.push_back(StateHeap::kNoKey); visited.push_back(false); } if (visited[arc.nextstate]) continue; if (enqueued[arc.nextstate] == StateHeap::kNoKey) { enqueued[arc.nextstate] = heap.Insert(arc.nextstate); } else { heap.Update(enqueued[arc.nextstate], arc.nextstate); } } } } template ::value>::type * = nullptr> void Prune(const Fst &, MutableFst *ofst, const PruneOptions &) { FSTERROR() << "Prune: Weight needs to have the path property: " << Arc::Weight::Type(); ofst->SetProperties(kError, kError); } // Pruning algorithm: this version writes the pruned input FST to an // output MutableFst and simply takes the pruning threshold as an // argument. The output FST contains states and arcs that belong to a // successful path in the input FST whose weight is no more than the // weight of the shortest path Times() the provided weight // threshold. When the state threshold is not kNoStateId, the output // FST is further restricted to have no more than the number of states // in opts.state_threshold. Weights must have the path property. The // weight of any cycle needs to be bounded; i.e., // // Plus(weight, Weight::One()) = Weight::One(); template void Prune(const Fst &ifst, MutableFst *ofst, typename Arc::Weight weight_threshold, typename Arc::StateId state_threshold = kNoStateId, float delta = kDelta) { const PruneOptions> opts( weight_threshold, state_threshold, AnyArcFilter(), nullptr, delta); Prune(ifst, ofst, opts); } } // namespace fst #endif // FST_PRUNE_H_