// fstext/prune-special-inl.h // Copyright 2014 Johns Hopkins University (Author: Daniel Povey) // Guoguo Chen // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_FSTEXT_PRUNE_SPECIAL_INL_H_ #define KALDI_FSTEXT_PRUNE_SPECIAL_INL_H_ // Do not include this file directly. It is included by prune-special.h #include "fstext/prune-special.h" #include "base/kaldi-error.h" namespace fst { /// This class is used to implement the function PruneSpecial. template class PruneSpecialClass { public: typedef typename Arc::StateId InputStateId; typedef typename Arc::StateId OutputStateId; typedef typename Arc::Weight Weight; typedef typename Arc::Label Label; PruneSpecialClass(const Fst &ifst, VectorFst *ofst, Weight beam, size_t max_states): ifst_(ifst), ofst_(ofst), beam_(beam), max_states_(max_states), best_weight_(Weight::Zero()) { KALDI_ASSERT(beam != Weight::One()); KALDI_ASSERT(queue_.size() == 0); ofst_->DeleteStates(); // make sure it's empty. if (ifst_.Start() == kNoStateId) return; ofst_->SetStart(ProcessState(ifst_.Start(), Weight::One())); while (!queue_.empty()) { Task task = queue_.top(); queue_.pop(); if (Done(task)) break; else ProcessTask(task); } Connect(ofst); if (beam_ != Weight::One()) Prune(ofst, beam_); } struct Task { InputStateId istate; OutputStateId ostate; // could be looked up; this is for speed. size_t position; // arc position, or -1 if final-prob. Weight weight; Task(InputStateId istate, OutputStateId ostate, size_t position, Weight weight): istate(istate), ostate(ostate), position(position), weight(weight) { } bool operator < (const Task &other) const { return Compare(weight, other.weight) < 0; } }; bool Done(const Task &task) { if (beam_ != Weight::One() && best_weight_ != Weight::Zero() && Compare(task.weight, Times(best_weight_, beam_)) < 0) return true; if (max_states_ > 0 && static_cast(ofst_->NumStates()) > max_states_) return true; return false; } // This function assumes "state" has not been seen before, so we need to // create a new output-state for it and add tasks. It returns the // output-state id. "weight" is the best cost from the start-state to this // state. inline OutputStateId ProcessState(InputStateId istate, const Weight &weight) { OutputStateId ostate = ofst_->AddState(); state_map_[istate] = ostate; for (ArcIterator > aiter(ifst_, istate); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); Task new_task(istate, ostate, aiter.Position(), Times(weight, arc.weight)); KALDI_ASSERT(Compare(arc.weight, Weight::One()) != 1); queue_.push(new_task); } Weight final = ifst_.Final(istate); if (final != Weight::Zero()) { Task final_task(istate, ostate, static_cast(-1), Times(weight, final)); KALDI_ASSERT(Compare(final, Weight::One()) != 1); queue_.push(final_task); } return ostate; } // Returns the output-state id corresponding to "istate". This assumes we are // processing a task corresponding to an arc to "istate", and the cost from // the start-state to this state is "weight". Since we process tasks in // order, if this is the first time we see this istate, then this is the best // cost from the start-state to this state, and it can be used in setting the // priority costs in ProcessState(). inline OutputStateId GetOutputStateId(InputStateId istate, const Weight &weight) { typedef typename unordered_map::iterator IterType; IterType iter = state_map_.find(istate); if (iter == state_map_.end()) return ProcessState(istate, weight); else return iter->second; } void ProcessTask(const Task &task) { if (task.position == static_cast(-1)) { ofst_->SetFinal(task.ostate, ifst_.Final(task.istate)); if (best_weight_ == Weight::Zero()) best_weight_ = task.weight; // best-path cost through FST, used for // beam-pruning. } else { ArcIterator > aiter(ifst_, task.istate); aiter.Seek(task.position); // if we spend most of our time here, we may // need to store the arc in the Task. const Arc &arc = aiter.Value(); InputStateId next_istate = arc.nextstate; OutputStateId next_ostate = GetOutputStateId(next_istate, task.weight); Arc oarc(arc.ilabel, arc.olabel, arc.weight, next_ostate); ofst_->AddArc(task.ostate, oarc); } } private: const Fst &ifst_; VectorFst *ofst_; Weight beam_; size_t max_states_; unordered_map state_map_; std::priority_queue queue_; Weight best_weight_; // if not Zero(), then we have now processed a successful path // through ifst_, and this is the weight. }; template void PruneSpecial(const Fst &ifst, VectorFst *ofst, typename Arc::Weight beam, size_t max_states) { PruneSpecialClass c(ifst, ofst, beam, max_states); } } #endif