// push-lattice.cc 11.6 KB
// lat/push-lattice.cc

// Copyright 2009-2011  Saarland University (Author: Arnab Ghoshal)
//           2012-2013  Johns Hopkins University (Author: Daniel Povey);  Chao Weng;
//                      Bagher BabaAli
//                2014  Guoguo Chen

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.


#include "lat/push-lattice.h"
#include "hmm/transition-model.h"
#include "util/stl-utils.h"

namespace fst {


// CompactLatticePusher moves the strings (sequences of IntType, referred to
// below as transition-ids) stored on the arcs and final-probs of a compact
// lattice back towards the start state.  For every state after the start
// state it works out a "shift": the number of leading string elements that
// are shared by everything leaving that state.  Those elements are removed
// from the state's outgoing arcs and final-prob and appended to the arcs that
// enter the state.  Only the string part of the weights is modified.
template<class Weight, class IntType> class CompactLatticePusher {
 public:
  typedef CompactLatticeWeightTpl<Weight, IntType> CompactWeight;
  typedef ArcTpl<CompactWeight> CompactArc;
  typedef typename CompactArc::StateId StateId;

  // "clat" is modified in place by Push(); this object only stores the
  // pointer and does not take ownership of it.
  explicit CompactLatticePusher(MutableFst<CompactArc> *clat): clat_(clat) { }

  // Topologically sorts the lattice if it is not already marked as sorted,
  // then computes and applies the per-state shifts.  Returns false (after
  // logging a warning) if the topological sort fails, true otherwise.
  bool Push() {
    if (clat_->Properties(kTopSorted, true) == 0) {
      if (!TopSort(clat_)) {
        KALDI_WARN << "Topological sorting of state-level lattice failed "
            "(probably your lexicon has empty words or your LM has epsilon cycles; this "
            " is a bad idea.)";
        return false;
      }
    }
    ComputeShifts();
    ApplyShifts();
    return true;
  }

  // Gets the string of length [end - begin], starting at this
  // state and taking arc "arc_idx" (and thereafter an arbitrary sequence).
  // Note: here, arc_idx == -1 (i.e. -1 converted to size_t) means take an
  // arbitrary path; in that case a final-prob at "state", if present, is used
  // in preference to any arc.  The result is written to [begin, end).
  // Asserts if the path runs out before "end - begin" elements were found.
  static void GetString(const ExpandedFst<CompactArc> &clat,
                        StateId state,
                        size_t arc_idx,
                        typename std::vector<IntType>::iterator begin,
                        typename std::vector<IntType>::iterator end) {
    // Make the sentinel comparison explicit rather than relying on an
    // implicit int -> size_t conversion at each use.
    const bool arbitrary_path = (arc_idx == static_cast<size_t>(-1));
    // Check the range itself: a "len >= 0" test on an unsigned length can
    // never fail and so would not catch a reversed range.
    KALDI_ASSERT(begin <= end);
    size_t len = end - begin;
    if (len == 0) return;

    CompactWeight final = clat.Final(state);
    if (arbitrary_path && final != CompactWeight::Zero()) {
      const std::vector<IntType> &string = final.String();
      KALDI_ASSERT(string.size() >= len &&
                   "Either code error, or paths in lattice have inconsistent lengths");
      std::copy(string.begin(), string.begin() + len, begin);
      return;
    }

    // With an arbitrary path we simply take the first arc.
    ArcIterator<ExpandedFst<CompactArc> > aiter(clat, state);
    if (!arbitrary_path)
      aiter.Seek(arc_idx);
    KALDI_ASSERT(!aiter.Done() &&
                 "Either code error, or paths in lattice are inconsistent in length.");

    const CompactArc &arc = aiter.Value();
    const std::vector<IntType> &arc_string = arc.weight.String();
    size_t arc_len = arc_string.size();
    if (arc_len >= len) {
      std::copy(arc_string.begin(), arc_string.begin() + len, begin);
    } else {
      std::copy(arc_string.begin(), arc_string.end(), begin);
      // Recurse: the remainder comes from any path leaving the next state.
      GetString(clat, arc.nextstate, static_cast<size_t>(-1),
                begin + arc_len, end);
    }
  }

  // Reduces "*shift" for "state" until the first *shift string elements are
  // the same whichever arc (or final-prob "final") is taken from the state.
  // Does nothing if "shift" is NULL, if *shift is not positive, or if the
  // state has at most one [arc or final-prob].
  void CheckForConflict(const CompactWeight &final,
                        StateId state,
                        int32 *shift) {
    if (shift == NULL) return;
    // At input, "shift" has the maximum value that we could shift back assuming
    // there is no conflict between the values of the strings.  We need to check
    // if there is conflict, and if so, reduce the "shift".
    bool is_final = (final != CompactWeight::Zero());
    size_t num_arcs = clat_->NumArcs(state);
    if (num_arcs + (is_final ? 1 : 0) > 1 && *shift > 0) {
      // There is potential for conflict between string values, because >1
      // [arc or final-prob].  Find the longest shift up to and including the
      // current shift, that gives no conflict.

      std::vector<IntType> string(*shift), compare_string(*shift);
      size_t arc;
      if (is_final) {
        // *shift > 0 here, so the conversion to size_t is value-preserving.
        KALDI_ASSERT(final.String().size() >= static_cast<size_t>(*shift));
        std::copy(final.String().begin(), final.String().begin() + *shift,
                  string.begin());
        arc = 0;  // the reference is the final-prob: compare every arc to it.
      } else {
        // set "string" to string if we take 1st arc.
        GetString(*clat_, state, 0, string.begin(), string.end());
        arc = 1;  // the reference is arc 0: compare the remaining arcs to it.
      }
      for (; arc < num_arcs; arc++) { // for the other arcs..
        GetString(*clat_, state, arc,
                  compare_string.begin(), compare_string.end());
        std::pair<typename std::vector<IntType>::iterator,
            typename std::vector<IntType>::iterator> pr =
                std::mismatch(string.begin(), string.end(),
                              compare_string.begin());
        if (pr.first != string.end()) { // There was a mismatch.  Reduce the shift
          // to a value where they will match.
          *shift = pr.first - string.begin();
          string.resize(*shift);
          compare_string.resize(*shift);
        }
      }
    }
  }

  // Fills shift_vec_ with the shift for every state.  States are visited
  // from the last one down to (but excluding) the start state, so the shift
  // of each arc's destination is known before the arc's source is handled.
  void ComputeShifts() {
    StateId num_states = clat_->NumStates();
    // assign() rather than resize() so that no entry from an earlier call
    // survives; the start state and anything before it must stay at zero.
    shift_vec_.assign(num_states, 0);

    // The for loop will only work if StateId is signed, so assert this.
    KALDI_COMPILE_TIME_ASSERT(static_cast<StateId>(-1) < static_cast<StateId>(0));
    // We rely on the topological sorting, so clat_->Start() should be zero or
    // at least any preceding states should be non-accessible.  We leave the
    // shift at zero for the start state because we can't shift to before that.
    for (StateId state = num_states - 1; state > clat_->Start(); state--) {
      size_t num_arcs = clat_->NumArcs(state);
      CompactWeight final = clat_->Final(state);
      if (num_arcs == 0) {
        // we can shift back by the number of transition-ids on the
        // final-prob, if any.
        shift_vec_[state] = final.String().size();
      } else { // We have arcs ...
        // Upper bound on the shift: the shortest string available via the
        // final-prob or via any arc plus what its destination can give back.
        int32 shift = std::numeric_limits<int32>::max();
        bool is_final = (final != CompactWeight::Zero());
        if (is_final)
          shift = std::min(shift, static_cast<int32>(final.String().size()));
        for (ArcIterator<MutableFst<CompactArc> > aiter(*clat_, state);
             !aiter.Done(); aiter.Next()) {
          const CompactArc &arc (aiter.Value());
          shift = std::min(shift, shift_vec_[arc.nextstate] +
                           static_cast<int32>(arc.weight.String().size()));
        }
        // The bound only considers lengths; now make sure the contents agree.
        CheckForConflict(final, state, &shift);
        shift_vec_[state] = shift;
      }
    }
  }

  // Rewrites the string on every arc and final-prob according to shift_vec_.
  // States are processed in increasing order, so when an arc is rewritten its
  // destination state still holds its original strings.
  void ApplyShifts() {
    StateId num_states = clat_->NumStates();
    for (StateId state = 0; state < num_states; state++) {
      int32 shift = shift_vec_[state];
      std::vector<IntType> string;  // scratch buffer reused for every arc.
      for (MutableArcIterator<MutableFst<CompactArc> > aiter(clat_, state);
           !aiter.Done(); aiter.Next()) {
        CompactArc arc(aiter.Value());
        KALDI_ASSERT(arc.nextstate > state && "Cyclic lattice");

        string = arc.weight.String();
        size_t orig_len = string.size(), next_shift = shift_vec_[arc.nextstate];
        // extend "string" by next_shift.
        string.resize(string.size() + next_shift);
        // The next command sets the last "next_shift" elements of 'string' to
        // the string starting from arc.nextstate (taking an arbitrary path).
        GetString(*clat_, arc.nextstate, static_cast<size_t>(-1),
                  string.begin() + orig_len, string.end());
        // Remove the first "shift" elements of this string and set the
        // arc-weight string to this.
        arc.weight.SetString(std::vector<IntType>(string.begin() + shift,
                                                string.end()));
        aiter.SetValue(arc);
      }

      CompactWeight final = clat_->Final(state);
      if (final != CompactWeight::Zero()) {
        // Erase first "shift" elements of final-prob.
        final.SetString(std::vector<IntType>(final.String().begin() + shift,
                                           final.String().end()));
        clat_->SetFinal(state, final);
      }
    }
  }

 private:
  // The lattice being modified; set by the constructor, not owned.
  MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *clat_;

  // For each state s, shift_vec_[s] >= 0  is how much we will shift the
  // transition-ids back at this state.
  std::vector<int32> shift_vec_;
};

// Pushes the strings of the compact lattice "clat" by running a
// CompactLatticePusher over it; "clat" is modified in place.  The return
// value is whatever CompactLatticePusher::Push() returns.
template<class Weight, class IntType>
bool PushCompactLatticeStrings(
    MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *clat) {
  typedef CompactLatticePusher<Weight, IntType> Pusher;
  return Pusher(clat).Push();
}

// Pushes the Weight part of the weights in the compact lattice "clat" towards
// the start state; "clat" is modified in place.
//
// For every state s this computes weight_to_end[s]: the Plus() over the
// final-prob of s and, for each arc leaving s, Times(arc weight,
// weight_to_end[destination]).  Each arc s -> t is then multiplied by
// weight_to_end[t] / weight_to_end[s] and each final-prob of s is divided by
// weight_to_end[s].  The start state is treated as having weight_to_end equal
// to One(), so the weight that cannot be pushed any further stays there.
// Only SetWeight() is called on arc and final weights.
//
// Returns false (after logging a warning) if the lattice is not marked as
// topologically sorted and sorting it fails; returns true otherwise,
// including for an empty lattice.
template<class Weight, class IntType>
bool PushCompactLatticeWeights(
   MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *clat) {
  if (clat->Properties(kTopSorted, true) == 0) {
    if (!TopSort(clat)) {
      KALDI_WARN << "Topological sorting of state-level lattice failed "
          "(probably your lexicon has empty words or your LM has epsilon cycles; this "
          " is a bad idea.)";
      return false;
    }
  }
  typedef CompactLatticeWeightTpl<Weight, IntType> CompactWeight;
  typedef ArcTpl<CompactWeight> CompactArc;
  typedef typename CompactArc::StateId StateId;

  StateId num_states = clat->NumStates();
  if (num_states == 0) {
    KALDI_WARN << "Pushing weights of empty compact lattice";
    return true; // this is technically success because an empty
                 // lattice is already pushed.
  }
  std::vector<Weight> weight_to_end(num_states); // Note: LatticeWeight
                                                 // contains two floats.
  // Backward pass.  Going from the last state to the first means every arc's
  // destination has been handled before the arc's source (the assert below
  // checks that arcs only go forward).  Relies on StateId being signed.
  for (StateId s = num_states - 1; s >= 0; s--) {
    Weight this_weight_to_end = clat->Final(s).Weight();
    for (ArcIterator<MutableFst<CompactArc> > aiter(*clat, s);
         !aiter.Done(); aiter.Next()) {
      const CompactArc &arc = aiter.Value();
      KALDI_ASSERT(arc.nextstate > s && "Cyclic lattices not allowed.");
      this_weight_to_end = Plus(this_weight_to_end,
                                Times(arc.weight.Weight(),
                                      weight_to_end[arc.nextstate]));
    }
    if (this_weight_to_end == Weight::Zero()) {
      KALDI_WARN << "Lattice has non-coaccessible states.";
    }
    weight_to_end[s] = this_weight_to_end;
  }
  // We leave the "leftover weight" on the start state, which won't
  // necessarily end up summing to one.  Use the lattice's actual start state
  // rather than assuming it is state 0: in a topologically sorted lattice the
  // start state may be preceded by non-accessible states, and pinning index 0
  // would then normalize the real start state and drop the leftover weight.
  const StateId start = clat->Start();
  if (start >= 0 && start < num_states)
    weight_to_end[start] = Weight::One();

  // Forward pass: rescale arcs and final-probs.
  for (StateId s = 0; s < num_states; s++) {
    Weight this_weight_to_end = weight_to_end[s];
    if (this_weight_to_end == Weight::Zero())
      continue;  // non-coaccessible state: nothing sensible to divide by.
    for (MutableArcIterator<MutableFst<CompactArc> > aiter(clat, s);
         !aiter.Done(); aiter.Next()) {
      CompactArc arc = aiter.Value();
      Weight next_weight_to_end = weight_to_end[arc.nextstate];
      // Arcs into non-coaccessible states are left as they are.
      if (next_weight_to_end != Weight::Zero()) {
        arc.weight.SetWeight(Times(arc.weight.Weight(),
                                   Divide(next_weight_to_end,
                                          this_weight_to_end)));
        aiter.SetValue(arc);
      }
    }
    CompactWeight final_weight = clat->Final(s);
    if (final_weight != CompactWeight::Zero()) {
      final_weight.SetWeight(Divide(final_weight.Weight(), this_weight_to_end));
      clat->SetFinal(s, final_weight);
    }
  }

  return true;
}

// Instantiate for CompactLattice, i.e. with kaldi::LatticeWeight as the weight
// type and kaldi::int32 as the string element type.  The two function
// templates are defined in this file, so these explicit instantiations are
// what make the compiler emit their code here.
// NOTE(review): presumably lat/push-lattice.h only declares the templates, so
// other element/weight types would need their own instantiation -- confirm.
template
bool PushCompactLatticeStrings<kaldi::LatticeWeight, kaldi::int32>(
   MutableFst<kaldi::CompactLatticeArc> *clat);

template
bool PushCompactLatticeWeights<kaldi::LatticeWeight, kaldi::int32>(
   MutableFst<kaldi::CompactLatticeArc> *clat);

}  // namespace fst