Blame view
src/fstext/prune-special-inl.h
6.01 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
// fstext/prune-special-inl.h // Copyright 2014 Johns Hopkins University (Author: Daniel Povey) // Guoguo Chen // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_FSTEXT_PRUNE_SPECIAL_INL_H_ #define KALDI_FSTEXT_PRUNE_SPECIAL_INL_H_ // Do not include this file directly. It is included by prune-special.h #include "fstext/prune-special.h" #include "base/kaldi-error.h" namespace fst { /// This class is used to implement the function PruneSpecial. template<class Arc> class PruneSpecialClass { public: typedef typename Arc::StateId InputStateId; typedef typename Arc::StateId OutputStateId; typedef typename Arc::Weight Weight; typedef typename Arc::Label Label; PruneSpecialClass(const Fst<Arc> &ifst, VectorFst<Arc> *ofst, Weight beam, size_t max_states): ifst_(ifst), ofst_(ofst), beam_(beam), max_states_(max_states), best_weight_(Weight::Zero()) { KALDI_ASSERT(beam != Weight::One()); KALDI_ASSERT(queue_.size() == 0); ofst_->DeleteStates(); // make sure it's empty. if (ifst_.Start() == kNoStateId) return; ofst_->SetStart(ProcessState(ifst_.Start(), Weight::One())); while (!queue_.empty()) { Task task = queue_.top(); queue_.pop(); if (Done(task)) break; else ProcessTask(task); } Connect(ofst); if (beam_ != Weight::One()) Prune(ofst, beam_); } struct Task { InputStateId istate; OutputStateId ostate; // could be looked up; this is for speed. size_t position; // arc position, or -1 if final-prob. Weight weight; Task(InputStateId istate, OutputStateId ostate, size_t position, Weight weight): istate(istate), ostate(ostate), position(position), weight(weight) { } bool operator < (const Task &other) const { return Compare(weight, other.weight) < 0; } }; bool Done(const Task &task) { if (beam_ != Weight::One() && best_weight_ != Weight::Zero() && Compare(task.weight, Times(best_weight_, beam_)) < 0) return true; if (max_states_ > 0 && static_cast<size_t>(ofst_->NumStates()) > max_states_) return true; return false; } // This function assumes "state" has not been seen before, so we need to // create a new output-state for it and add tasks. It returns the // output-state id. "weight" is the best cost from the start-state to this // state. inline OutputStateId ProcessState(InputStateId istate, const Weight &weight) { OutputStateId ostate = ofst_->AddState(); state_map_[istate] = ostate; for (ArcIterator<Fst<Arc> > aiter(ifst_, istate); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); Task new_task(istate, ostate, aiter.Position(), Times(weight, arc.weight)); KALDI_ASSERT(Compare(arc.weight, Weight::One()) != 1); queue_.push(new_task); } Weight final = ifst_.Final(istate); if (final != Weight::Zero()) { Task final_task(istate, ostate, static_cast<size_t>(-1), Times(weight, final)); KALDI_ASSERT(Compare(final, Weight::One()) != 1); queue_.push(final_task); } return ostate; } // Returns the output-state id corresponding to "istate". This assumes we are // processing a task corresponding to an arc to "istate", and the cost from // the start-state to this state is "weight". Since we process tasks in // order, if this is the first time we see this istate, then this is the best // cost from the start-state to this state, and it can be used in setting the // priority costs in ProcessState(). inline OutputStateId GetOutputStateId(InputStateId istate, const Weight &weight) { typedef typename unordered_map<InputStateId, OutputStateId>::iterator IterType; IterType iter = state_map_.find(istate); if (iter == state_map_.end()) return ProcessState(istate, weight); else return iter->second; } void ProcessTask(const Task &task) { if (task.position == static_cast<size_t>(-1)) { ofst_->SetFinal(task.ostate, ifst_.Final(task.istate)); if (best_weight_ == Weight::Zero()) best_weight_ = task.weight; // best-path cost through FST, used for // beam-pruning. } else { ArcIterator<Fst<Arc> > aiter(ifst_, task.istate); aiter.Seek(task.position); // if we spend most of our time here, we may // need to store the arc in the Task. const Arc &arc = aiter.Value(); InputStateId next_istate = arc.nextstate; OutputStateId next_ostate = GetOutputStateId(next_istate, task.weight); Arc oarc(arc.ilabel, arc.olabel, arc.weight, next_ostate); ofst_->AddArc(task.ostate, oarc); } } private: const Fst<Arc> &ifst_; VectorFst<Arc> *ofst_; Weight beam_; size_t max_states_; unordered_map<InputStateId, OutputStateId> state_map_; std::priority_queue<Task> queue_; Weight best_weight_; // if not Zero(), then we have now processed a successful path // through ifst_, and this is the weight. }; template<class Arc> void PruneSpecial(const Fst<Arc> &ifst, VectorFst<Arc> *ofst, typename Arc::Weight beam, size_t max_states) { PruneSpecialClass<Arc> c(ifst, ofst, beam, max_states); } } #endif |