Blame view
src/decoder/lattice-simple-decoder.h
14.1 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 |
// decoder/lattice-simple-decoder.h // Copyright 2009-2012 Microsoft Corporation // 2012-2014 Johns Hopkins University (Author: Daniel Povey) // 2014 Guoguo Chen // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_DECODER_LATTICE_SIMPLE_DECODER_H_ #define KALDI_DECODER_LATTICE_SIMPLE_DECODER_H_ #include "util/stl-utils.h" #include "fst/fstlib.h" #include "itf/decodable-itf.h" #include "fstext/fstext-lib.h" #include "lat/determinize-lattice-pruned.h" #include "lat/kaldi-lattice.h" #include <algorithm> namespace kaldi { struct LatticeSimpleDecoderConfig { BaseFloat beam; BaseFloat lattice_beam; int32 prune_interval; bool determinize_lattice; // not inspected by this class... used in // command-line program. bool prune_lattice; BaseFloat beam_ratio; BaseFloat prune_scale; // Note: we don't make this configurable on the command line, // it's not a very important parameter. It affects the // algorithm that prunes the tokens as we go. fst::DeterminizeLatticePhonePrunedOptions det_opts; LatticeSimpleDecoderConfig(): beam(16.0), lattice_beam(10.0), prune_interval(25), determinize_lattice(true), beam_ratio(0.9), prune_scale(0.1) { } void Register(OptionsItf *opts) { det_opts.Register(opts); opts->Register("beam", &beam, "Decoding beam."); opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam"); opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " "which to prune tokens"); opts->Register("determinize-lattice", &determinize_lattice, "If true, " "determinize the lattice (in a special sense, keeping only " "best pdf-sequence for each word-sequence)."); } void Check() const { KALDI_ASSERT(beam > 0.0 && lattice_beam > 0.0 && prune_interval > 0); } }; /** Simplest possible decoder, included largely for didactic purposes and as a means to debug more highly optimized decoders. See \ref decoders_simple for more information. */ class LatticeSimpleDecoder { public: typedef fst::StdArc Arc; typedef Arc::Label Label; typedef Arc::StateId StateId; typedef Arc::Weight Weight; // instantiate this class onece for each thing you have to decode. LatticeSimpleDecoder(const fst::Fst<fst::StdArc> &fst, const LatticeSimpleDecoderConfig &config): fst_(fst), config_(config), num_toks_(0) { config.Check(); } ~LatticeSimpleDecoder() { ClearActiveTokens(); } const LatticeSimpleDecoderConfig &GetOptions() const { return config_; } // Returns true if any kind of traceback is available (not necessarily from // a final state). bool Decode(DecodableInterface *decodable); /// says whether a final-state was active on the last frame. If it was not, the /// lattice (or traceback) will end with states that are not final-states. bool ReachedFinal() const { return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity(); } /// InitDecoding initializes the decoding, and should only be used if you /// intend to call AdvanceDecoding(). If you call Decode(), you don't need /// to call this. You can call InitDecoding if you have already decoded an /// utterance and want to start with a new utterance. void InitDecoding(); /// This function may be optionally called after AdvanceDecoding(), when you /// do not plan to decode any further. It does an extra pruning step that /// will help to prune the lattices output by GetLattice and (particularly) /// GetRawLattice more accurately, particularly toward the end of the /// utterance. It does this by using the final-probs in pruning (if any /// final-state survived); it also does a final pruning step that visits all /// states (the pruning that is done during decoding may fail to prune states /// that are within kPruningScale = 0.1 outside of the beam). If you call /// this, you cannot call AdvanceDecoding again (it will fail), and you /// cannot call GetLattice() and related functions with use_final_probs = /// false. /// Used to be called PruneActiveTokensFinal(). void FinalizeDecoding(); /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives /// more information. It returns the difference between the best (final-cost /// plus cost) of any token on the final frame, and the best cost of any token /// on the final frame. If it is infinity it means no final-states were /// present on the final frame. It will usually be nonnegative. If it not /// too positive (e.g. < 5 is my first guess, but this is not tested) you can /// take it as a good indication that we reached the final-state with /// reasonable likelihood. BaseFloat FinalRelativeCost() const; // Outputs an FST corresponding to the single best path // through the lattice. Returns true if result is nonempty // (using the return status is deprecated, it will become void). // If "use_final_probs" is true AND we reached the final-state // of the graph then it will include those as final-probs, else // it will treat all final-probs as one. bool GetBestPath(Lattice *lat, bool use_final_probs = true) const; // Outputs an FST corresponding to the raw, state-level // tracebacks. Returns true if result is nonempty // (using the return status is deprecated, it will become void). // If "use_final_probs" is true AND we reached the final-state // of the graph then it will include those as final-probs, else // it will treat all final-probs as one. bool GetRawLattice(Lattice *lat, bool use_final_probs = true) const; // This function is now deprecated, since now we do determinization from // outside the LatticeTrackingDecoder class. // Outputs an FST corresponding to the lattice-determinized // lattice (one path per word sequence). [will become deprecated, // users should determinize themselves.] bool GetLattice(CompactLattice *clat, bool use_final_probs = true) const; inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } private: struct Token; // ForwardLinks are the links from a token to a token on the next frame. // or sometimes on the current frame (for input-epsilon links). struct ForwardLink { Token *next_tok; // the next token [or NULL if represents final-state] Label ilabel; // ilabel on link. Label olabel; // olabel on link. BaseFloat graph_cost; // graph cost of traversing link (contains LM, etc.) BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing link ForwardLink *next; // next in singly-linked list of forward links from a // token. ForwardLink(Token *next_tok, Label ilabel, Label olabel, BaseFloat graph_cost, BaseFloat acoustic_cost, ForwardLink *next): next_tok(next_tok), ilabel(ilabel), olabel(olabel), graph_cost(graph_cost), acoustic_cost(acoustic_cost), next(next) { } }; // Token is what's resident in a particular state at a particular time. // In this decoder a Token actually contains *forward* links. // When first created, a Token just has the (total) cost. We add forward // links from it when we process the next frame. struct Token { BaseFloat tot_cost; // would equal weight.Value()... cost up to this point. BaseFloat extra_cost; // >= 0. After calling PruneForwardLinks, this equals // the minimum difference between the cost of the best path this is on, // and the cost of the absolute best path, under the assumption // that any of the currently active states at the decoding front may // eventually succeed (e.g. if you were to take the currently active states // one by one and compute this difference, and then take the minimum). ForwardLink *links; // Head of singly linked list of ForwardLinks Token *next; // Next in list of tokens for this frame. Token(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLink *links, Token *next): tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next) { } Token() {} void DeleteForwardLinks() { ForwardLink *l = links, *m; while (l != NULL) { m = l->next; delete l; l = m; } links = NULL; } }; // head and tail of per-frame list of Tokens (list is in topological order), // and something saying whether we ever pruned it using PruneForwardLinks. struct TokenList { Token *toks; bool must_prune_forward_links; bool must_prune_tokens; TokenList(): toks(NULL), must_prune_forward_links(true), must_prune_tokens(true) { } }; // FindOrAddToken either locates a token in cur_toks_, or if necessary inserts a new, // empty token (i.e. with no forward links) for the current frame. [note: it's // inserted if necessary into cur_toks_ and also into the singly linked list // of tokens active on this frame (whose head is at active_toks_[frame]). // // Returns the Token pointer. Sets "changed" (if non-NULL) to true // if the token was newly created or the cost changed. inline Token *FindOrAddToken(StateId state, int32 frame_plus_one, BaseFloat tot_cost, bool emitting, bool *changed); // delta is the amount by which the extra_costs must // change before it sets "extra_costs_changed" to true. If delta is larger, // we'll tend to go back less far toward the beginning of the file. void PruneForwardLinks(int32 frame, bool *extra_costs_changed, bool *links_pruned, BaseFloat delta); // PruneForwardLinksFinal is a version of PruneForwardLinks that we call // on the final frame. If there are final tokens active, it uses the final-probs // for pruning, otherwise it treats all tokens as final. void PruneForwardLinksFinal(); // Prune away any tokens on this frame that have no forward links. [we don't do // this in PruneForwardLinks because it would give us a problem with dangling // pointers]. void PruneTokensForFrame(int32 frame); // Go backwards through still-alive tokens, pruning them if the // forward+backward cost is more than lat_beam away from the best path. It's // possible to prove that this is "correct" in the sense that we won't lose // anything outside of lat_beam, regardless of what happens in the future. // delta controls when it considers a cost to have changed enough to continue // going backward and propagating the change. larger delta -> will recurse // less far. void PruneActiveTokens(BaseFloat delta); void ProcessEmitting(DecodableInterface *decodable); void ProcessNonemitting(); void ClearActiveTokens(); // a cleanup routine, at utt end/begin // This function computes the final-costs for tokens active on the final // frame. It outputs to final-costs, if non-NULL, a map from the Token* // pointer to the final-prob of the corresponding state, or zero for all states if // none were final. It outputs to final_relative_cost, if non-NULL, the // difference between the best forward-cost including the final-prob cost, and // the best forward-cost without including the final-prob cost (this will // usually be positive), or infinity if there were no final-probs. It outputs // to final_best_cost, if non-NULL, the lowest for any token t active on the // final frame, of t + final-cost[t], where final-cost[t] is the final-cost // in the graph of the state corresponding to token t, or zero if there // were no final-probs active on the final frame. // You cannot call this after FinalizeDecoding() has been called; in that // case you should get the answer from class-member variables. void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs, BaseFloat *final_relative_cost, BaseFloat *final_best_cost) const; // PruneCurrentTokens deletes the tokens from the "toks" map, but not // from the active_toks_ list, which could cause dangling forward pointers // (will delete it during regular pruning operation). void PruneCurrentTokens(BaseFloat beam, unordered_map<StateId, Token*> *toks); unordered_map<StateId, Token*> cur_toks_; unordered_map<StateId, Token*> prev_toks_; std::vector<TokenList> active_toks_; // Lists of tokens, indexed by // frame_plus_one const fst::Fst<fst::StdArc> &fst_; LatticeSimpleDecoderConfig config_; int32 num_toks_; // current total #toks allocated... bool warned_; /// decoding_finalized_ is true if someone called FinalizeDecoding(). [note, /// calling this is optional]. If true, it's forbidden to decode more. Also, /// if this is set, then the output of ComputeFinalCosts() is in the next /// three variables. The reason we need to do this is that after /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some /// of the tokens on the last frame are freed, so we free the list from /// cur_toks_ to avoid having dangling pointers hanging around. bool decoding_finalized_; /// For the meaning of the next 3 variables, see the comment for /// decoding_finalized_ above., and ComputeFinalCosts(). unordered_map<Token*, BaseFloat> final_costs_; BaseFloat final_relative_cost_; BaseFloat final_best_cost_; }; } // end namespace kaldi. #endif |