// lat/sausages.h // Copyright 2012 Johns Hopkins University (Author: Daniel Povey) // 2015 Guoguo Chen // 2019 Dogan Can // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_LAT_SAUSAGES_H_ #define KALDI_LAT_SAUSAGES_H_ #include #include #include "base/kaldi-common.h" #include "util/common-utils.h" #include "fstext/fstext-lib.h" #include "lat/kaldi-lattice.h" namespace kaldi { /// The implementation of the Minimum Bayes Risk decoding method described in /// "Minimum Bayes Risk decoding and system combination based on a recursion for /// edit distance", Haihua Xu, Daniel Povey, Lidia Mangu and Jie Zhu, Computer /// Speech and Language, 2011 /// This is a slightly more principled way to do Minimum Bayes Risk (MBR) decoding /// than the standard "Confusion Network" method. Note: MBR decoding aims to /// minimize the expected word error rate, assuming the lattice encodes the /// true uncertainty about what was spoken; standard Viterbi decoding gives the /// most likely utterance, which corresponds to minimizing the expected sentence /// error rate. /// /// In addition to giving the MBR output, we also provide a way to get a /// "Confusion Network" or informally "sausage"-like structure. This is a /// linear sequence of bins, and in each bin, there is a distribution over /// words (or epsilon, meaning no word). This is useful for estimating /// confidence. Note: due to the way these sausages are made, typically there /// will be, between each bin representing a high-confidence word, a bin /// in which epsilon (no word) is the most likely word. Inside these bins /// is where we put possible insertions. struct MinimumBayesRiskOptions { /// Boolean configuration parameter: if true, we actually update the hypothesis /// to do MBR decoding (if false, our output is the MAP decoded output, but we /// output the stats too (i.e. the confidences)). bool decode_mbr; /// Boolean configuration parameter: if true, the 1-best path will 'keep' the bins, bool print_silence; MinimumBayesRiskOptions() : decode_mbr(true), print_silence(false) { } void Register(OptionsItf *opts) { opts->Register("decode-mbr", &decode_mbr, "If true, do Minimum Bayes Risk " "decoding (else, Maximum a Posteriori)"); opts->Register("print-silence", &print_silence, "Keep the inter-word '' " "bins in the 1-best output (ctm, can be a 'silence' or a 'deleted' word)"); } }; /// This class does the word-level Minimum Bayes Risk computation, and gives you /// either the 1-best MBR output together with the expected Bayes Risk, /// or a sausage-like structure. class MinimumBayesRisk { public: /// Initialize with compact lattice-- any acoustic scaling etc., is assumed /// to have been done already. /// This does the whole computation. You get the output with /// GetOneBest(), GetBayesRisk(), and GetSausageStats(). MinimumBayesRisk(const CompactLattice &clat, MinimumBayesRiskOptions opts = MinimumBayesRiskOptions()); // Uses the provided as instead of using the lattice best path. // Note that the default value of opts.decode_mbr is true. If you provide 1-best // hypothesis from MAP decoding, the output ctm from MBR decoding may be // mismatched with the provided ( would be used as the starting // point of optimization). MinimumBayesRisk(const CompactLattice &clat, const std::vector &words, MinimumBayesRiskOptions opts = MinimumBayesRiskOptions()); // Uses the provided as and of bins instead of using the lattice best path. // Note that the default value of opts.decode_mbr is true. If you provide 1-best // hypothesis from MAP decoding, the output ctm from MBR decoding may be // mismatched with the provided ( would be used as the starting // point of optimization). MinimumBayesRisk(const CompactLattice &clat, const std::vector &words, const std::vector > ×, MinimumBayesRiskOptions opts = MinimumBayesRiskOptions()); const std::vector &GetOneBest() const { // gets one-best (with no epsilons) return R_; } const std::vector > > GetTimes() const { return times_; // returns average (start,end) times for each word in each // bin. These are raw averages without any processing, i.e. time intervals // from different bins can overlap. } const std::vector > GetSausageTimes() const { return sausage_times_; // returns average (start,end) times for each bin. // This is typically the weighted average of the times in GetTimes() but can // be slightly different if the times for the bins overlap, in which case // the times returned by this method do not overlap unlike the times // returned by GetTimes(). } const std::vector > &GetOneBestTimes() const { return one_best_times_; // returns average (start,end) times for each word // corresponding to an entry in the one-best output. This is typically the // appropriate subset of the times in GetTimes() but can be slightly // different if the times for the one-best words overlap, in which case // the times returned by this method do not overlap unlike the times // returned by GetTimes(). } /// Outputs the confidences for the one-best transcript. const std::vector &GetOneBestConfidences() const { return one_best_confidences_; } /// Returns the expected WER over this sentence (assuming model correctness). BaseFloat GetBayesRisk() const { return L_; } const std::vector > > &GetSausageStats() const { return gamma_; } private: void PrepareLatticeAndInitStats(CompactLattice *clat); /// Minimum-Bayes-Risk Decode. Top-level algorithm. Figure 6 of the paper. void MbrDecode(); /// Without the 'penalize' argument this gives us the basic edit-distance /// function l(a,b), as in the paper. /// With the 'penalize' argument it can be interpreted as the edit distance /// plus the 'delta' from the paper, except that we make a kind of conceptual /// bug-fix and only apply the delta if the edit-distance was not already /// zero. This bug-fix was necessary in order to force all the stats to show /// up, that should show up, and applying the bug-fix makes the sausage stats /// significantly less sparse. inline double l(int32 a, int32 b, bool penalize = false) { if (a == b) return 0.0; else return (penalize ? 1.0 + delta() : 1.0); } /// returns r_q, in one-based indexing, as in the paper. inline int32 r(int32 q) { return R_[q-1]; } /// Figure 4 of the paper; called from AccStats (Fig. 5) double EditDistance(int32 N, int32 Q, Vector &alpha, Matrix &alpha_dash, Vector &alpha_dash_arc); /// Figure 5 of the paper. Outputs to gamma_ and L_. void AccStats(); /// Removes epsilons (symbol 0) from a vector static void RemoveEps(std::vector *vec); // Ensures that between each word in "vec" and at the beginning and end, is // epsilon (0). (But if no words in vec, just one epsilon) static void NormalizeEps(std::vector *vec); // delta() is a constant used in the algorithm, which penalizes // the use of certain epsilon transitions in the edit-distance which would cause // words not to show up in the accumulated edit-distance statistics. // There has been a conceptual bug-fix versus the way it was presented in // the paper: we now add delta only if the edit-distance was not already // zero. static inline BaseFloat delta() { return 1.0e-05; } /// Function used to increment map. static inline void AddToMap(int32 i, double d, std::map *gamma) { if (d == 0) return; std::pair pr(i, d); std::pair::iterator, bool> ret = gamma->insert(pr); if (!ret.second) // not inserted, so add to contents. ret.first->second += d; } struct Arc { int32 word; int32 start_node; int32 end_node; BaseFloat loglike; }; MinimumBayesRiskOptions opts_; /// Arcs in the topologically sorted acceptor form of the word-level lattice, /// with one final-state. Contains (word-symbol, log-likelihood on arc == /// negated cost). Indexed from zero. std::vector arcs_; /// For each node in the lattice, a list of arcs entering that node. Indexed /// from 1 (first node == 1). std::vector > pre_; std::vector state_times_; // time of each state in the word lattice, // indexed from 1 (same index as into pre_) std::vector R_; // current 1-best word sequence, normalized to have // epsilons between each word and at the beginning and end. R in paper... // caution: indexed from zero, not from 1 as in paper. double L_; // current averaged edit-distance between lattice and R_. // \hat{L} in paper. std::vector > > gamma_; // The stats we accumulate; these are pairs of (posterior, word-id), and note // that word-id may be epsilon. Caution: indexed from zero, not from 1 as in // paper. We sort in reverse order on the second member (posterior), so more // likely word is first. std::vector > > times_; // The average start and end times for words in each confusion-network bin. // This is like an average over arcs, of the tau_b and tau_e quantities in // Appendix C of the paper. Indexed from zero, like gamma_ and R_. std::vector > sausage_times_; // The average start and end times for each confusion-network bin. This // is like an average over words, of the tau_b and tau_e quantities in // Appendix C of the paper. Indexed from zero, like gamma_ and R_. std::vector > one_best_times_; // The average start and end times for words in the one best output. This // is like an average over the arcs, of the tau_b and tau_e quantities in // Appendix C of the paper. Indexed from zero, like gamma_ and R_. std::vector one_best_confidences_; // vector of confidences for the 1-best output (which could be // the MAP output if opts_.decode_mbr == false, or the MBR output otherwise). // Indexed by the same index as one_best_times_. struct GammaCompare{ // should be like operator <. But we want reverse order // on the 2nd element (posterior), so it'll be like operator // > that looks first at the posterior. bool operator () (const std::pair &a, const std::pair &b) const { if (a.second > b.second) return true; else if (a.second < b.second) return false; else return a.first > b.first; } }; }; } // namespace kaldi #endif // KALDI_LAT_SAUSAGES_H_