// lat/sausages.cc // Copyright 2012 Johns Hopkins University (Author: Daniel Povey) // 2015 Guoguo Chen // 2019 Dogan Can // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include "lat/sausages.h" #include "lat/lattice-functions.h" namespace kaldi { // this is Figure 6 in the paper. void MinimumBayesRisk::MbrDecode() { for (size_t counter = 0; ; counter++) { NormalizeEps(&R_); AccStats(); // writes to gamma_ double delta_Q = 0.0; // change in objective function. one_best_times_.clear(); one_best_confidences_.clear(); // Caution: q in the line below is (q-1) in the algorithm // in the paper; both R_ and gamma_ are indexed by q-1. for (size_t q = 0; q < R_.size(); q++) { if (opts_.decode_mbr) { // This loop updates R_ [indexed same as gamma_]. // gamma_[i] is sorted in reverse order so most likely one is first. const std::vector > &this_gamma = gamma_[q]; double old_gamma = 0, new_gamma = this_gamma[0].second; int32 rq = R_[q], rhat = this_gamma[0].first; // rq: old word, rhat: new. for (size_t j = 0; j < this_gamma.size(); j++) if (this_gamma[j].first == rq) old_gamma = this_gamma[j].second; delta_Q += (old_gamma - new_gamma); // will be 0 or negative; a bound on // change in error. if (rq != rhat) KALDI_VLOG(2) << "Changing word " << rq << " to " << rhat; R_[q] = rhat; } // build the outputs (time, confidences), if (R_[q] != 0 || opts_.print_silence) { // see which 'item' from the sausage-bin should we select, // (not necessarily the 1st one when MBR decoding disabled) int32 s = 0; for (int32 j=0; j 1 && one_best_times_[i-2].second > one_best_times_[i-1].first) { // It's quite possible for this to happen, but it seems like it would // have a bad effect on the downstream processing, so we fix it here. // We resolve overlaps by redistributing the available time interval. BaseFloat prev_right = i > 2 ? one_best_times_[i-3].second : 0.0; BaseFloat left = std::max(prev_right, std::min(one_best_times_[i-2].first, one_best_times_[i-1].first)); BaseFloat right = std::max(one_best_times_[i-2].second, one_best_times_[i-1].second); BaseFloat first_dur = one_best_times_[i-2].second - one_best_times_[i-2].first; BaseFloat second_dur = one_best_times_[i-1].second - one_best_times_[i-1].first; BaseFloat mid = first_dur > 0 ? left + (right - left) * first_dur / (first_dur + second_dur) : left; one_best_times_[i-2].first = left; one_best_times_[i-2].second = one_best_times_[i-1].first = mid; one_best_times_[i-1].second = right; } BaseFloat confidence = 0.0; for (int32 j = 0; j < gamma_[q].size(); j++) { if (gamma_[q][j].first == R_[q]) { confidence = gamma_[q][j].second; break; } } one_best_confidences_.push_back(confidence); } } KALDI_VLOG(2) << "Iter = " << counter << ", delta-Q = " << delta_Q; if (delta_Q == 0) break; if (counter > 100) { KALDI_WARN << "Iterating too many times in MbrDecode; stopping."; break; } } if (!opts_.print_silence) RemoveEps(&R_); } struct Int32IsZero { bool operator() (int32 i) { return (i == 0); } }; // static void MinimumBayesRisk::RemoveEps(std::vector *vec) { Int32IsZero pred; vec->erase(std::remove_if (vec->begin(), vec->end(), pred), vec->end()); } // static void MinimumBayesRisk::NormalizeEps(std::vector *vec) { RemoveEps(vec); vec->resize(1 + vec->size() * 2); int32 s = vec->size(); for (int32 i = s/2 - 1; i >= 0; i--) { (*vec)[i*2 + 1] = (*vec)[i]; (*vec)[i*2 + 2] = 0; } (*vec)[0] = 0; } double MinimumBayesRisk::EditDistance(int32 N, int32 Q, Vector &alpha, Matrix &alpha_dash, Vector &alpha_dash_arc) { alpha(1) = 0.0; // = log(1). Line 5. alpha_dash(1, 0) = 0.0; // Line 5. for (int32 q = 1; q <= Q; q++) alpha_dash(1, q) = alpha_dash(1, q-1) + l(0, r(q)); // Line 7. for (int32 n = 2; n <= N; n++) { double alpha_n = kLogZeroDouble; for (size_t i = 0; i < pre_[n].size(); i++) { const Arc &arc = arcs_[pre_[n][i]]; alpha_n = LogAdd(alpha_n, alpha(arc.start_node) + arc.loglike); } alpha(n) = alpha_n; // Line 10. // Line 11 omitted: matrix was initialized to zero. for (size_t i = 0; i < pre_[n].size(); i++) { const Arc &arc = arcs_[pre_[n][i]]; int32 s_a = arc.start_node, w_a = arc.word; BaseFloat p_a = arc.loglike; for (int32 q = 0; q <= Q; q++) { if (q == 0) { alpha_dash_arc(q) = // line 15. alpha_dash(s_a, q) + l(w_a, 0, true); } else { // a1,a2,a3 are the 3 parts of min expression of line 17. int32 r_q = r(q); double a1 = alpha_dash(s_a, q-1) + l(w_a, r_q), a2 = alpha_dash(s_a, q) + l(w_a, 0, true), a3 = alpha_dash_arc(q-1) + l(0, r_q); alpha_dash_arc(q) = std::min(a1, std::min(a2, a3)); } // line 19: alpha_dash(n, q) += Exp(alpha(s_a) + p_a - alpha(n)) * alpha_dash_arc(q); } } } return alpha_dash(N, Q); // line 23. } // Figure 5 in the paper. void MinimumBayesRisk::AccStats() { using std::map; int32 N = static_cast(pre_.size()) - 1, Q = static_cast(R_.size()); Vector alpha(N+1); // index (1...N) Matrix alpha_dash(N+1, Q+1); // index (1...N, 0...Q) Vector alpha_dash_arc(Q+1); // index 0...Q Matrix beta_dash(N+1, Q+1); // index (1...N, 0...Q) Vector beta_dash_arc(Q+1); // index 0...Q std::vector b_arc(Q+1); // integer in {1,2,3}; index 1...Q std::vector > gamma(Q+1); // temp. form of gamma. // index 1...Q [word] -> occ. // The tau maps below are the sums over arcs with the same word label // of the tau_b and tau_e timing quantities mentioned in Appendix C of // the paper... we are using these to get averaged times for both the // the sausage bins and the 1-best output. std::vector > tau_b(Q+1), tau_e(Q+1); double Ltmp = EditDistance(N, Q, alpha, alpha_dash, alpha_dash_arc); if (L_ != 0 && Ltmp > L_) { // L_ != 0 is to rule out 1st iter. KALDI_WARN << "Edit distance increased: " << Ltmp << " > " << L_; } L_ = Ltmp; KALDI_VLOG(2) << "L = " << L_; // omit line 10: zero when initialized. beta_dash(N, Q) = 1.0; // Line 11. for (int32 n = N; n >= 2; n--) { for (size_t i = 0; i < pre_[n].size(); i++) { const Arc &arc = arcs_[pre_[n][i]]; int32 s_a = arc.start_node, w_a = arc.word; BaseFloat p_a = arc.loglike; alpha_dash_arc(0) = alpha_dash(s_a, 0) + l(w_a, 0, true); // line 14. for (int32 q = 1; q <= Q; q++) { // this loop == lines 15-18. int32 r_q = r(q); double a1 = alpha_dash(s_a, q-1) + l(w_a, r_q), a2 = alpha_dash(s_a, q) + l(w_a, 0, true), a3 = alpha_dash_arc(q-1) + l(0, r_q); if (a1 <= a2) { if (a1 <= a3) { b_arc[q] = 1; alpha_dash_arc(q) = a1; } else { b_arc[q] = 3; alpha_dash_arc(q) = a3; } } else { if (a2 <= a3) { b_arc[q] = 2; alpha_dash_arc(q) = a2; } else { b_arc[q] = 3; alpha_dash_arc(q) = a3; } } } beta_dash_arc.SetZero(); // line 19. for (int32 q = Q; q >= 1; q--) { // line 21: beta_dash_arc(q) += Exp(alpha(s_a) + p_a - alpha(n)) * beta_dash(n, q); switch (static_cast(b_arc[q])) { // lines 22 and 23: case 1: beta_dash(s_a, q-1) += beta_dash_arc(q); // next: gamma(q, w(a)) += beta_dash_arc(q) AddToMap(w_a, beta_dash_arc(q), &(gamma[q])); // next: accumulating times, see decl for tau_b,tau_e AddToMap(w_a, state_times_[s_a] * beta_dash_arc(q), &(tau_b[q])); AddToMap(w_a, state_times_[n] * beta_dash_arc(q), &(tau_e[q])); break; case 2: beta_dash(s_a, q) += beta_dash_arc(q); break; case 3: beta_dash_arc(q-1) += beta_dash_arc(q); // next: gamma(q, epsilon) += beta_dash_arc(q) AddToMap(0, beta_dash_arc(q), &(gamma[q])); // next: accumulating times, see decl for tau_b,tau_e // WARNING: there was an error in Appendix C. If we followed // the instructions there the next line would say state_times_[sa], but // it would be wrong. I will try to publish an erratum. AddToMap(0, state_times_[n] * beta_dash_arc(q), &(tau_b[q])); AddToMap(0, state_times_[n] * beta_dash_arc(q), &(tau_e[q])); break; default: KALDI_ERR << "Invalid b_arc value"; // error in code. } } beta_dash_arc(0) += Exp(alpha(s_a) + p_a - alpha(n)) * beta_dash(n, 0); beta_dash(s_a, 0) += beta_dash_arc(0); // line 26. } } beta_dash_arc.SetZero(); // line 29. for (int32 q = Q; q >= 1; q--) { beta_dash_arc(q) += beta_dash(1, q); beta_dash_arc(q-1) += beta_dash_arc(q); AddToMap(0, beta_dash_arc(q), &(gamma[q])); // the statements below are actually redundant because // state_times_[1] is zero. AddToMap(0, state_times_[1] * beta_dash_arc(q), &(tau_b[q])); AddToMap(0, state_times_[1] * beta_dash_arc(q), &(tau_e[q])); } for (int32 q = 1; q <= Q; q++) { // a check (line 35) double sum = 0.0; for (map::iterator iter = gamma[q].begin(); iter != gamma[q].end(); ++iter) sum += iter->second; if (fabs(sum - 1.0) > 0.1) KALDI_WARN << "sum of gamma[" << q << ",s] is " << sum; } // The next part is where we take gamma, and convert // to the class member gamma_, which is using a different // data structure and indexed from zero, not one. gamma_.clear(); gamma_.resize(Q); for (int32 q = 1; q <= Q; q++) { for (map::iterator iter = gamma[q].begin(); iter != gamma[q].end(); ++iter) gamma_[q-1].push_back( std::make_pair(iter->first, static_cast(iter->second))); // sort gamma_[q-1] from largest to smallest posterior. GammaCompare comp; std::sort(gamma_[q-1].begin(), gamma_[q-1].end(), comp); } // We do the same conversion for the state times tau_b and tau_e: // they get turned into the times_ data member, which has zero-based // indexing. times_.clear(); times_.resize(Q); sausage_times_.clear(); sausage_times_.resize(Q); for (int32 q = 1; q <= Q; q++) { double t_b = 0.0, t_e = 0.0; for (std::vector>::iterator iter = gamma_[q-1].begin(); iter != gamma_[q-1].end(); ++iter) { double w_b = tau_b[q][iter->first], w_e = tau_e[q][iter->first]; if (w_b > w_e) KALDI_WARN << "Times out of order"; // this is quite bad. times_[q-1].push_back( std::make_pair(static_cast(w_b / iter->second), static_cast(w_e / iter->second))); t_b += w_b; t_e += w_e; } sausage_times_[q-1].first = t_b; sausage_times_[q-1].second = t_e; if (sausage_times_[q-1].first > sausage_times_[q-1].second) KALDI_WARN << "Times out of order"; // this is quite bad. if (q > 1 && sausage_times_[q-2].second > sausage_times_[q-1].first) { // We previously had a warning here, but now we'll just set both // those values to their average. It's quite possible for this // condition to happen, but it seems like it would have a bad effect // on the downstream processing, so we fix it. sausage_times_[q-2].second = sausage_times_[q-1].first = 0.5 * (sausage_times_[q-2].second + sausage_times_[q-1].first); } } } void MinimumBayesRisk::PrepareLatticeAndInitStats(CompactLattice *clat) { KALDI_ASSERT(clat != NULL); CreateSuperFinal(clat); // Add super-final state to clat... this is // one of the requirements of the MBR algorithm, as mentioned in the // paper (i.e. just one final state). // Topologically sort the lattice, if not already sorted. kaldi::uint64 props = clat->Properties(fst::kFstProperties, false); if (!(props & fst::kTopSorted)) { if (fst::TopSort(clat) == false) KALDI_ERR << "Cycles detected in lattice."; } CompactLatticeStateTimes(*clat, &state_times_); // work out times of // the states in clat state_times_.push_back(0); // we'll convert to 1-based numbering. for (size_t i = state_times_.size()-1; i > 0; i--) state_times_[i] = state_times_[i-1]; // Now we convert the information in "clat" into a special internal // format (pre_, post_ and arcs_) which allows us to access the // arcs preceding any given state. // Note: in our internal format the states will be numbered from 1, // which involves adding 1 to the OpenFst states. int32 N = clat->NumStates(); pre_.resize(N+1); // Careful: "Arc" is a class-member struct, not an OpenFst type of arc as one // would normally assume. for (int32 n = 1; n <= N; n++) { for (fst::ArcIterator aiter(*clat, n-1); !aiter.Done(); aiter.Next()) { const CompactLatticeArc &carc = aiter.Value(); Arc arc; // in our local format. arc.word = carc.ilabel; // == carc.olabel arc.start_node = n; arc.end_node = carc.nextstate + 1; // convert to 1-based. arc.loglike = - (carc.weight.Weight().Value1() + carc.weight.Weight().Value2()); // loglike: sum graph/LM and acoustic cost, and negate to // convert to loglikes. We assume acoustic scaling is already done. pre_[arc.end_node].push_back(arcs_.size()); // record index of this arc. arcs_.push_back(arc); } } } MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in, MinimumBayesRiskOptions opts) : opts_(opts) { CompactLattice clat(clat_in); // copy. PrepareLatticeAndInitStats(&clat); // We don't need to look at clat.Start() or clat.Final(state): // we know clat.Start() == 0 since it's topologically sorted, // and clat.Final(state) is Zero() except for One() at the last- // numbered state, thanks to CreateSuperFinal and the topological // sorting. { // Now set R_ to one best in the FST. RemoveAlignmentsFromCompactLattice(&clat); // will be more efficient // in best-path if we do this. Lattice lat; ConvertLattice(clat, &lat); // convert from CompactLattice to Lattice. fst::VectorFst fst; ConvertLattice(lat, &fst); // convert from lattice to normal FST. fst::VectorFst fst_shortest_path; fst::ShortestPath(fst, &fst_shortest_path); // take shortest path of FST. std::vector alignment, words; fst::TropicalWeight weight; GetLinearSymbolSequence(fst_shortest_path, &alignment, &words, &weight); KALDI_ASSERT(alignment.empty()); // we removed the alignment. R_ = words; L_ = 0.0; // Set current edit-distance to 0 [just so we know // when we're on the 1st iter.] } MbrDecode(); } MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in, const std::vector &words, MinimumBayesRiskOptions opts) : opts_(opts) { CompactLattice clat(clat_in); // copy. PrepareLatticeAndInitStats(&clat); R_ = words; L_ = 0.0; MbrDecode(); } MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in, const std::vector &words, const std::vector > ×, MinimumBayesRiskOptions opts) : opts_(opts) { CompactLattice clat(clat_in); // copy. PrepareLatticeAndInitStats(&clat); R_ = words; sausage_times_ = times; L_ = 0.0; MbrDecode(); } } // namespace kaldi