// nnet3/discriminative-supervision.cc // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey) // 2014-2015 Vimal Manohar // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include "nnet3/discriminative-supervision.h" #include "lat/lattice-functions.h" namespace kaldi { namespace discriminative { DiscriminativeSupervision::DiscriminativeSupervision( const DiscriminativeSupervision &other): weight(other.weight), num_sequences(other.num_sequences), frames_per_sequence(other.frames_per_sequence), num_ali(other.num_ali), den_lat(other.den_lat) { } void DiscriminativeSupervision::Swap(DiscriminativeSupervision *other) { std::swap(weight, other->weight); std::swap(num_sequences, other->num_sequences); std::swap(frames_per_sequence, other->frames_per_sequence); std::swap(num_ali, other->num_ali); std::swap(den_lat, other->den_lat); } bool DiscriminativeSupervision::operator == ( const DiscriminativeSupervision &other) const { return ( weight == other.weight && num_sequences == other.num_sequences && frames_per_sequence == other.frames_per_sequence && num_ali == other.num_ali && fst::Equal(den_lat, other.den_lat) ); } void DiscriminativeSupervision::Write(std::ostream &os, bool binary) const { WriteToken(os, binary, ""); WriteToken(os, binary, ""); WriteBasicType(os, binary, weight); WriteToken(os, binary, ""); WriteBasicType(os, binary, num_sequences); WriteToken(os, binary, ""); WriteBasicType(os, binary, frames_per_sequence); KALDI_ASSERT(frames_per_sequence > 0 && num_sequences > 0); WriteToken(os, binary, ""); WriteIntegerVector(os, binary, num_ali); WriteToken(os, binary, ""); if (!WriteLattice(os, binary, den_lat)) { // We can't return error status from this function so we // throw an exception. KALDI_ERR << "Error writing denominator lattice to stream"; } WriteToken(os, binary, ""); } void DiscriminativeSupervision::Read(std::istream &is, bool binary) { ExpectToken(is, binary, ""); ExpectToken(is, binary, ""); ReadBasicType(is, binary, &weight); ExpectToken(is, binary, ""); ReadBasicType(is, binary, &num_sequences); ExpectToken(is, binary, ""); ReadBasicType(is, binary, &frames_per_sequence); KALDI_ASSERT(frames_per_sequence > 0 && num_sequences > 0); ExpectToken(is, binary, ""); ReadIntegerVector(is, binary, &num_ali); ExpectToken(is, binary, ""); { Lattice *lat = NULL; if (!ReadLattice(is, binary, &lat) || lat == NULL) { // We can't return error status from this function so we // throw an exception. KALDI_ERR << "Error reading Lattice from stream"; } den_lat = *lat; delete lat; TopSort(&den_lat); } ExpectToken(is, binary, ""); } bool DiscriminativeSupervision::Initialize(const std::vector &num_ali, const Lattice &den_lat, BaseFloat weight) { if (num_ali.size() == 0) return false; if (den_lat.NumStates() == 0) return false; this->weight = weight; this->num_sequences = 1; this->frames_per_sequence = num_ali.size(); this->num_ali = num_ali; this->den_lat = den_lat; KALDI_ASSERT(TopSort(&(this->den_lat))); // Checks if num frames in alignment matches lattice Check(); return true; } void DiscriminativeSupervision::Check() const { int32 num_frames_subsampled = num_ali.size(); KALDI_ASSERT(num_frames_subsampled == num_sequences * frames_per_sequence); { std::vector state_times; int32 max_time = LatticeStateTimes(den_lat, &state_times); KALDI_ASSERT(max_time == num_frames_subsampled); } } DiscriminativeSupervisionSplitter::DiscriminativeSupervisionSplitter( const SplitDiscriminativeSupervisionOptions &config, const TransitionModel &tmodel, const DiscriminativeSupervision &supervision): config_(config), tmodel_(tmodel), supervision_(supervision) { if (supervision_.num_sequences != 1) { KALDI_WARN << "Splitting already-reattached sequence (only expected in " << "testing code)"; } KALDI_ASSERT(supervision_.num_sequences == 1); // For now, don't allow splitting already merged examples den_lat_ = supervision_.den_lat; PrepareLattice(&den_lat_, &den_lat_scores_); int32 num_states = den_lat_.NumStates(), num_frames = supervision_.frames_per_sequence * supervision_.num_sequences; KALDI_ASSERT(num_states > 0); int32 start_state = den_lat_.Start(); // Lattice should be top-sorted and connected, so start-state must be 0. KALDI_ASSERT(start_state == 0 && "Expecting start-state to be 0"); KALDI_ASSERT(num_states == den_lat_scores_.state_times.size()); KALDI_ASSERT(den_lat_scores_.state_times[start_state] == 0); KALDI_ASSERT(den_lat_scores_.state_times.back() == num_frames); } // Make sure that for any given pdf-id and any given frame, the den-lat has // only one transition-id mapping to that pdf-id, on the same frame. // It helps us to more completely minimize the lattice. Note: we // can't do this if the criterion is MPFE, because in that case the // objective function will be affected by the phone-identities being // different even if the pdf-ids are the same. void DiscriminativeSupervisionSplitter::CollapseTransitionIds( const std::vector &state_times, Lattice *lat) const { typedef Lattice::StateId StateId; typedef Lattice::Arc Arc; int32 num_frames = state_times.back(); // TODO: Check if this is always true StateId num_states = lat->NumStates(); std::vector > pdf_to_tid(num_frames); for (StateId s = 0; s < num_states; s++) { int32 t = state_times[s]; for (fst::MutableArcIterator aiter(lat, s); !aiter.Done(); aiter.Next()) { KALDI_ASSERT(t >= 0 && t < num_frames); Arc arc = aiter.Value(); KALDI_ASSERT(arc.ilabel != 0 && arc.ilabel == arc.olabel); int32 pdf = tmodel_.TransitionIdToPdf(arc.ilabel); if (pdf_to_tid[t].count(pdf) != 0) { arc.ilabel = arc.olabel = pdf_to_tid[t][pdf]; aiter.SetValue(arc); } else { pdf_to_tid[t][pdf] = arc.ilabel; } } } } void DiscriminativeSupervisionSplitter::LatticeInfo::Check() const { // Check if all the vectors are of size num_states KALDI_ASSERT(state_times.size() == alpha.size() && state_times.size() == beta.size()); // Check that the states are ordered in increasing order of state_times. // This must be true since the states are in breadth-first search order. KALDI_ASSERT(IsSorted(state_times)); } void DiscriminativeSupervisionSplitter::GetFrameRange(int32 begin_frame, int32 num_frames, bool normalize, DiscriminativeSupervision *out_supervision) const { int32 end_frame = begin_frame + num_frames; // Note: end_frame is not included in the range of frames that the // output supervision object covers; it's one past the end. KALDI_ASSERT(num_frames > 0 && begin_frame >= 0 && begin_frame + num_frames <= supervision_.num_sequences * supervision_.frames_per_sequence); CreateRangeLattice(den_lat_, den_lat_scores_, begin_frame, end_frame, normalize, &(out_supervision->den_lat)); out_supervision->num_ali.clear(); std::copy(supervision_.num_ali.begin() + begin_frame, supervision_.num_ali.begin() + end_frame, std::back_inserter(out_supervision->num_ali)); out_supervision->num_sequences = 1; out_supervision->weight = supervision_.weight; out_supervision->frames_per_sequence = num_frames; out_supervision->Check(); } void DiscriminativeSupervisionSplitter::CreateRangeLattice( const Lattice &in_lat, const LatticeInfo &scores, int32 begin_frame, int32 end_frame, bool normalize, Lattice *out_lat) const { typedef Lattice::StateId StateId; const std::vector &state_times = scores.state_times; // Some checks to ensure the lattice and scores are prepared properly KALDI_ASSERT(state_times.size() == in_lat.NumStates()); if (!in_lat.Properties(fst::kTopSorted, true)) KALDI_ERR << "Input lattice must be topologically sorted."; std::vector::const_iterator begin_iter = std::lower_bound(state_times.begin(), state_times.end(), begin_frame), end_iter = std::lower_bound(begin_iter, state_times.end(), end_frame); KALDI_ASSERT(*begin_iter == begin_frame && (begin_iter == state_times.begin() || begin_iter[-1] < begin_frame)); // even if end_frame == supervision_.num_frames, there should be a state with // that frame index. KALDI_ASSERT(end_iter[-1] < end_frame && (end_iter < state_times.end() || *end_iter == end_frame)); StateId begin_state = begin_iter - state_times.begin(), end_state = end_iter - state_times.begin(); KALDI_ASSERT(end_state > begin_state); out_lat->DeleteStates(); out_lat->ReserveStates(end_state - begin_state + 2); // Add special start state StateId start_state = out_lat->AddState(); out_lat->SetStart(start_state); for (StateId i = begin_state; i < end_state; i++) out_lat->AddState(); // Add the special final-state. StateId final_state = out_lat->AddState(); out_lat->SetFinal(final_state, LatticeWeight::One()); for (StateId state = begin_state; state < end_state; state++) { StateId output_state = state - begin_state + 1; if (state_times[state] == begin_frame) { // we'd like to make this an initial state, but OpenFst doesn't allow // multiple initial states. Instead we add an epsilon transition to it // from our actual initial state. The weight on this // transition is the forward probability of the said 'initial state' LatticeWeight weight = LatticeWeight::One(); weight.SetValue1((normalize ? scores.beta[0] : 0.0) - scores.alpha[state]); // Add negative of the forward log-probability to the graph cost score, // since the acoustic scores would be changed later. // Assuming that the lattice is scaled with appropriate acoustic // scale. // We additionally normalize using the total lattice score. Since the // same score is added as normalizer to all the paths in the lattice, // the relative probabilities of the paths in the lattice is not affected. // Note: Doing a forward-backward on this split must result in a total // score of 0 because of the normalization. out_lat->AddArc(start_state, LatticeArc(0, 0, weight, output_state)); } else { KALDI_ASSERT(scores.state_times[state] < end_frame); } for (fst::ArcIterator aiter(in_lat, state); !aiter.Done(); aiter.Next()) { const LatticeArc &arc = aiter.Value(); StateId nextstate = arc.nextstate; if (nextstate >= end_state) { // A transition to any state outside the range becomes a transition to // our special final-state. // The weight is just the negative of the backward log-probability + // the arc cost. We again normalize with the total lattice score. LatticeWeight weight; //KALDI_ASSERT(scores.beta[state] < 0); weight.SetValue1(arc.weight.Value1() - scores.beta[nextstate]); weight.SetValue2(arc.weight.Value2()); // Add negative of the backward log-probability to the LM score, since // the acoustic scores would be changed later. // Note: We don't normalize here because that is already done with the // initial cost. out_lat->AddArc(output_state, LatticeArc(arc.ilabel, arc.olabel, weight, final_state)); } else { StateId output_nextstate = nextstate - begin_state + 1; out_lat->AddArc(output_state, LatticeArc(arc.ilabel, arc.olabel, arc.weight, output_nextstate)); } } } // Get rid of the word labels and put the // transition-ids on both sides. fst::Project(out_lat, fst::PROJECT_INPUT); fst::RmEpsilon(out_lat); if (config_.collapse_transition_ids) CollapseTransitionIds(state_times, out_lat); if (config_.determinize) { if (!config_.minimize) { Lattice tmp_lat; fst::Determinize(*out_lat, &tmp_lat); std::swap(*out_lat, tmp_lat); } else { Lattice tmp_lat; fst::Reverse(*out_lat, &tmp_lat); fst::Determinize(tmp_lat, out_lat); fst::Reverse(*out_lat, &tmp_lat); fst::Determinize(tmp_lat, out_lat); fst::RmEpsilon(out_lat); } } fst::TopSort(out_lat); std::vector state_times_tmp; KALDI_ASSERT(LatticeStateTimes(*out_lat, &state_times_tmp) == end_frame - begin_frame); // Remove the acoustic scale that was previously added if (config_.acoustic_scale != 1.0) { fst::ScaleLattice(fst::AcousticLatticeScale( 1 / config_.acoustic_scale), out_lat); } } void DiscriminativeSupervisionSplitter::PrepareLattice( Lattice *lat, LatticeInfo *scores) const { // Scale the lattice to appropriate acoustic scale. It is important to // ensure this is equal to the acoustic scale used while training. This is // because, on splitting lattices, the initial and final costs are added // into the graph cost. KALDI_ASSERT(config_.acoustic_scale != 0.0); if (config_.acoustic_scale != 1.0) fst::ScaleLattice(fst::AcousticLatticeScale( config_.acoustic_scale), lat); LatticeStateTimes(*lat, &(scores->state_times)); int32 num_states = lat->NumStates(); std::vector > state_time_indexes(num_states); for (int32 s = 0; s < num_states; s++) { state_time_indexes[s] = std::make_pair(scores->state_times[s], s); } // Order the states based on the state times. This is stronger than just // topological sort. This is required by the lattice splitting code. std::sort(state_time_indexes.begin(), state_time_indexes.end()); std::vector state_order(num_states); for (int32 s = 0; s < num_states; s++) { state_order[state_time_indexes[s].second] = s; } fst::StateSort(lat, state_order); ComputeLatticeScores(*lat, scores); } void DiscriminativeSupervisionSplitter::ComputeLatticeScores(const Lattice &lat, LatticeInfo *scores) const { LatticeStateTimes(lat, &(scores->state_times)); ComputeLatticeAlphasAndBetas(lat, false, &(scores->alpha), &(scores->beta)); scores->Check(); // This check will fail if the lattice is not breadth-first search sorted } void MergeSupervision(const std::vector &input, DiscriminativeSupervision *output_supervision) { KALDI_ASSERT(!input.empty()); int32 num_inputs = input.size(); if (num_inputs == 1) { *output_supervision = *(input[0]); return; } *output_supervision = *(input[num_inputs-1]); for (int32 i = num_inputs - 2; i >= 0; i--) { const DiscriminativeSupervision &src = *(input[i]); KALDI_ASSERT(src.num_sequences == 1); if (output_supervision->weight == src.weight && output_supervision->frames_per_sequence == src.frames_per_sequence) { // Combine with current output // append src.den_lat to output_supervision->den_lat. fst::Concat(src.den_lat, &output_supervision->den_lat); output_supervision->num_ali.insert( output_supervision->num_ali.begin(), src.num_ali.begin(), src.num_ali.end()); output_supervision->num_sequences++; } else { KALDI_ERR << "Mismatch weight or frames_per_sequence between inputs"; } } DiscriminativeSupervision &out_sup = *output_supervision; fst::TopSort(&(out_sup.den_lat)); out_sup.Check(); } } // namespace discriminative } // namespace kaldi