// lat/phone-align-lattice.cc // Copyright 2012-2013 Microsoft Corporation // Johns Hopkins University (Author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include "lat/phone-align-lattice.h" #include "hmm/transition-model.h" #include "util/stl-utils.h" namespace kaldi { class LatticePhoneAligner { public: typedef CompactLatticeArc::StateId StateId; typedef CompactLatticeArc::Label Label; class ComputationState { /// The state of the computation in which, /// along a single path in the lattice, we work out the phone /// boundaries and output phone-aligned arcs. [These may or may not have /// words on them; the word symbols are not aligned with anything. public: /// Advance the computation state by adding the symbols and weights /// from this arc. Gets rid of the weight and puts it in "weight" which /// will be put on the output arc; this keeps the state-space small. void Advance(const CompactLatticeArc &arc, const PhoneAlignLatticeOptions &opts, LatticeWeight *weight) { const std::vector &string = arc.weight.String(); transition_ids_.insert(transition_ids_.end(), string.begin(), string.end()); if (arc.ilabel != 0 && !opts.replace_output_symbols) // note: arc.ilabel==arc.olabel (acceptor) word_labels_.push_back(arc.ilabel); *weight = Times(weight_, arc.weight.Weight()); weight_ = LatticeWeight::One(); } /// If it can output a whole phone, it will do so, will put it in arc_out, /// and return true; else it will return false. If it detects an error /// condition and *error = false, it will set *error to true and print /// a warning. In this case it will still output phone arcs, they will /// just be inaccurate. Of course once *error is set, something has gone /// wrong so don't trust the output too fully. /// Note: the "next_state" of the arc will not be set, you have to do that /// yourself. bool OutputPhoneArc(const TransitionModel &tmodel, const PhoneAlignLatticeOptions &opts, CompactLatticeArc *arc_out, bool *error); /// This will succeed (and output the arc) if we have >1 word in words_; /// the arc won't have any transition-ids on it. This is intended to fix /// a particular pathology where too many words were pending and we had /// blowup. bool OutputWordArc(const TransitionModel &tmodel, const PhoneAlignLatticeOptions &opts, CompactLatticeArc *arc_out, bool *error); bool IsEmpty() { return (transition_ids_.empty() && word_labels_.empty()); } /// FinalWeight() will return "weight" if both transition_ids /// and word_labels are empty, otherwise it will return /// Weight::Zero(). LatticeWeight FinalWeight() { return (IsEmpty() ? weight_ : LatticeWeight::Zero()); } /// This function may be called when you reach the end of /// the lattice and this structure hasn't voluntarily /// output words using "OutputArc". If IsEmpty() == false, /// then you can call this function and it will output /// an arc. The only /// non-error state in which this happens, is when a word /// (or silence) has ended, but we don't know that it's /// ended because we haven't seen the first transition-id /// from the next word. Otherwise (error state), the output /// will consist of partial words, and this will only /// happen for lattices that were somehow broken, i.e. /// had not reached the final state. void OutputArcForce(const TransitionModel &tmodel, const PhoneAlignLatticeOptions &opts, CompactLatticeArc *arc_out, bool *error); size_t Hash() const { VectorHasher vh; return vh(transition_ids_) + 90647 * vh(word_labels_); // 90647 is an arbitrary largish prime number. // We don't bother including the weight in the hash-- // we don't really expect duplicates with the same vectors // but different weights, and anyway, this is only an // efficiency issue. } // Just need an arbitrary complete order. bool operator == (const ComputationState &other) const { return (transition_ids_ == other.transition_ids_ && word_labels_ == other.word_labels_ && weight_ == other.weight_); } ComputationState(): weight_(LatticeWeight::One()) { } // initial state. ComputationState(const ComputationState &other): transition_ids_(other.transition_ids_), word_labels_(other.word_labels_), weight_(other.weight_) { } private: std::vector transition_ids_; std::vector word_labels_; LatticeWeight weight_; // contains two floats. }; struct Tuple { Tuple(StateId input_state, ComputationState comp_state): input_state(input_state), comp_state(comp_state) {} StateId input_state; ComputationState comp_state; }; struct TupleHash { size_t operator() (const Tuple &state) const { return state.input_state + 102763 * state.comp_state.Hash(); // 102763 is just an arbitrary prime number } }; struct TupleEqual { bool operator () (const Tuple &state1, const Tuple &state2) const { // treat this like operator == return (state1.input_state == state2.input_state && state1.comp_state == state2.comp_state); } }; typedef unordered_map MapType; StateId GetStateForTuple(const Tuple &tuple, bool add_to_queue) { MapType::iterator iter = map_.find(tuple); if (iter == map_.end()) { // not in map. StateId output_state = lat_out_->AddState(); map_[tuple] = output_state; if (add_to_queue) queue_.push_back(std::make_pair(tuple, output_state)); return output_state; } else { return iter->second; } } void ProcessFinal(Tuple tuple, StateId output_state) { // ProcessFinal is only called if the input_state has // final-prob of One(). [else it should be zero. This // is because we called CreateSuperFinal().] if (tuple.comp_state.IsEmpty()) { // computation state doesn't have // anything pending. std::vector empty_vec; CompactLatticeWeight cw(tuple.comp_state.FinalWeight(), empty_vec); lat_out_->SetFinal(output_state, Plus(lat_out_->Final(output_state), cw)); } else { // computation state has something pending, i.e. input or // output symbols that need to be flushed out. Note: OutputArc() would // have returned false or we wouldn't have been called, so we have to // force it out. CompactLatticeArc lat_arc; tuple.comp_state.OutputArcForce(tmodel_, opts_, &lat_arc, &error_); lat_arc.nextstate = GetStateForTuple(tuple, true); // true == add to queue. // The final-prob stuff will get called again from ProcessQueueElement(). // Note: because we did CreateSuperFinal(), this final-state on the input // lattice will have no output arcs (and unit final-prob), so there will be // no complications with processing the arcs from this state (there won't // be any). KALDI_ASSERT(output_state != lat_arc.nextstate); lat_out_->AddArc(output_state, lat_arc); } } void ProcessQueueElement() { KALDI_ASSERT(!queue_.empty()); Tuple tuple = queue_.back().first; StateId output_state = queue_.back().second; queue_.pop_back(); // First thing is-- we see whether the computation-state has something // pending that it wants to output. In this case we don't do // anything further. This is a chosen behavior similar to the // epsilon-sequencing rules encoded by the filters in // composition. CompactLatticeArc lat_arc; Tuple tuple2(tuple); // temp if (tuple.comp_state.OutputPhoneArc(tmodel_, opts_, &lat_arc, &error_) || tuple.comp_state.OutputWordArc(tmodel_, opts_, &lat_arc, &error_)) { // note: this function changes the tuple (when it returns true). lat_arc.nextstate = GetStateForTuple(tuple, true); // true == add to queue, // if not already present. KALDI_ASSERT(output_state != lat_arc.nextstate); lat_out_->AddArc(output_state, lat_arc); } else { // when there's nothing to output, we'll process arcs from the input-state. // note: it would in a sense be valid to do both (i.e. process the stuff // above, and also these), but this is a bit like the epsilon-sequencing // stuff in composition: we avoid duplicate arcs by doing it this way. if (lat_.Final(tuple.input_state) != CompactLatticeWeight::Zero()) { KALDI_ASSERT(lat_.Final(tuple.input_state) == CompactLatticeWeight::One()); // ... since we did CreateSuperFinal. ProcessFinal(tuple, output_state); } // Now process the arcs. Note: final-state shouldn't have any arcs. for(fst::ArcIterator aiter(lat_, tuple.input_state); !aiter.Done(); aiter.Next()) { const CompactLatticeArc &arc = aiter.Value(); Tuple next_tuple(tuple); LatticeWeight weight; next_tuple.comp_state.Advance(arc, opts_, &weight); next_tuple.input_state = arc.nextstate; StateId next_output_state = GetStateForTuple(next_tuple, true); // true == add to queue, // if not already present. // We add an epsilon arc here (as the input and output happens // separately)... the epsilons will get removed later. KALDI_ASSERT(next_output_state != output_state); lat_out_->AddArc(output_state, CompactLatticeArc(0, 0, CompactLatticeWeight(weight, std::vector()), next_output_state)); } } } LatticePhoneAligner(const CompactLattice &lat, const TransitionModel &tmodel, const PhoneAlignLatticeOptions &opts, CompactLattice *lat_out): lat_(lat), tmodel_(tmodel), opts_(opts), lat_out_(lat_out), error_(false) { fst::CreateSuperFinal(&lat_); // Creates a super-final state, so the // only final-probs are One(). } // Removes epsilons; also removes unreachable states... // not sure if these would exist if original was connected. // This also replaces the temporary symbols for the silence // and partial-words, with epsilons, if we wanted epsilons. void RemoveEpsilonsFromLattice() { RmEpsilon(lat_out_, true); // true = connect. } bool AlignLattice() { lat_out_->DeleteStates(); if (lat_.Start() == fst::kNoStateId) { KALDI_WARN << "Trying to word-align empty lattice."; return false; } ComputationState initial_comp_state; Tuple initial_tuple(lat_.Start(), initial_comp_state); StateId start_state = GetStateForTuple(initial_tuple, true); // True = add this to queue. lat_out_->SetStart(start_state); while (!queue_.empty()) ProcessQueueElement(); if (opts_.remove_epsilon) RemoveEpsilonsFromLattice(); return !error_; } CompactLattice lat_; const TransitionModel &tmodel_; const PhoneAlignLatticeOptions &opts_; CompactLattice *lat_out_; std::vector > queue_; MapType map_; // map from tuples to StateId. bool error_; }; bool LatticePhoneAligner::ComputationState::OutputPhoneArc( const TransitionModel &tmodel, const PhoneAlignLatticeOptions &opts, CompactLatticeArc *arc_out, bool *error) { if (transition_ids_.empty()) return false; int32 phone = tmodel.TransitionIdToPhone(transition_ids_[0]); // we assume the start of transition_ids_ is the start of the phone; // this is a precondition. size_t len = transition_ids_.size(), i; // Keep going till we reach a "final" transition-id; note, if // reorder==true, we have to go a bit further after this. for (i = 0; i < len; i++) { int32 tid = transition_ids_[i]; int32 this_phone = tmodel.TransitionIdToPhone(tid); if (this_phone != phone && ! *error) { // error condition: should have // reached final transition-id first. *error = true; KALDI_WARN << phone << " -> " << this_phone; KALDI_WARN << "Phone changed before final transition-id found " "[broken lattice or mismatched model or wrong --reorder option?]"; } if (tmodel.IsFinal(tid)) break; } if (i == len) return false; // fell off loop. i++; // go past the one for which IsFinal returned true. if (opts.reorder) // we have to consume the following self-loop transition-ids. while (i < len && tmodel.IsSelfLoop(transition_ids_[i])) i++; if (i == len) return false; // we don't know if it ends here... so can't output arc. // interpret i as the number of transition-ids to consume. std::vector tids_out(transition_ids_.begin(), transition_ids_.begin()+i); Label output_label = 0; if (!word_labels_.empty()) { output_label = word_labels_[0]; word_labels_.erase(word_labels_.begin(), word_labels_.begin()+1); } if (opts.replace_output_symbols) output_label = phone; *arc_out = CompactLatticeArc(output_label, output_label, CompactLatticeWeight(weight_, tids_out), fst::kNoStateId); transition_ids_.erase(transition_ids_.begin(), transition_ids_.begin()+i); weight_ = LatticeWeight::One(); // we just output the weight. return true; } bool LatticePhoneAligner::ComputationState::OutputWordArc( const TransitionModel &tmodel, const PhoneAlignLatticeOptions &opts, CompactLatticeArc *arc_out, bool *error) { // output a word but no phones. if (word_labels_.size() < 2) return false; int32 output_label = word_labels_[0]; word_labels_.erase(word_labels_.begin(), word_labels_.begin()+1); *arc_out = CompactLatticeArc(output_label, output_label, CompactLatticeWeight(weight_, std::vector()), fst::kNoStateId); weight_ = LatticeWeight::One(); // we just output the weight, so set it to one. return true; } void LatticePhoneAligner::ComputationState::OutputArcForce( const TransitionModel &tmodel, const PhoneAlignLatticeOptions &opts, CompactLatticeArc *arc_out, bool *error) { KALDI_ASSERT(!IsEmpty()); int32 phone = -1; // This value -1 will never be used, // although it might not be obvious from superficially checking // the code. IsEmpty() would be true if we had transition_ids_.empty() // and opts.replace_output_symbols, so we would already die by assertion; // in fact, this function would neve be called. if (!transition_ids_.empty()) { // Do some checking here. int32 tid = transition_ids_[0]; phone = tmodel.TransitionIdToPhone(tid); int32 num_final = 0; for (int32 i = 0; i < transition_ids_.size(); i++) { // A check. int32 this_tid = transition_ids_[i]; int32 this_phone = tmodel.TransitionIdToPhone(this_tid); bool is_final = tmodel.IsFinal(this_tid); // should be exactly one. if (is_final) num_final++; if (this_phone != phone && ! *error) { KALDI_WARN << "Mismatch in phone: error in lattice or mismatched transition model?"; *error = true; } } if (num_final != 1 && ! *error) { KALDI_WARN << "Problem phone-aligning lattice: saw " << num_final << " final-states in last phone in lattice (forced out?) " << "Producing partial lattice."; *error = true; } } Label output_label = 0; if (!word_labels_.empty()) { output_label = word_labels_[0]; word_labels_.erase(word_labels_.begin(), word_labels_.begin()+1); } if (opts.replace_output_symbols) output_label = phone; *arc_out = CompactLatticeArc(output_label, output_label, CompactLatticeWeight(weight_, transition_ids_), fst::kNoStateId); transition_ids_.clear(); weight_ = LatticeWeight::One(); // we just output the weight. } bool PhoneAlignLattice(const CompactLattice &lat, const TransitionModel &tmodel, const PhoneAlignLatticeOptions &opts, CompactLattice *lat_out) { LatticePhoneAligner aligner(lat, tmodel, opts, lat_out); return aligner.AlignLattice(); } } // namespace kaldi