factor-inl.h 11.4 KB
// fstext/factor-inl.h

// Copyright 2009-2011  Microsoft Corporation

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_FSTEXT_FACTOR_INL_H_
#define KALDI_FSTEXT_FACTOR_INL_H_

#include "util/stl-utils.h"
// Do not include this file directly.  It is included by factor.h.

namespace fst {

// GetStateProperties takes in an FST and a number "max_state" which is the
// highest numbered state in the FST (this could be fst.NumStates()-1 for an
// ExpandedFst, or derived from some kind of traversal).  It outputs a vector
// numbered from 0..max_state, of type FstStateProperties which is a bitmask
// with information about the states.

// GetStateProperties has not been tested directly (only implicitly via
// testing Factor).
template<class Arc>
void GetStateProperties(const Fst<Arc> &fst,
                        typename Arc::StateId max_state,
                        vector<StatePropertiesType> *props) {
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;
  assert(props != NULL);
  props->clear();
  if (fst.Start() < 0) return;  // Empty fst.
  props->resize(max_state+1, 0);
  assert(fst.Start() <= max_state);
  (*props)[fst.Start()] |= kStateInitial;
  for (StateId s = 0; s <= max_state; s++) {
    StatePropertiesType &s_info = (*props)[s];
    for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (arc.ilabel != 0) s_info |= kStateIlabelsOut;
      if (arc.olabel != 0) s_info |= kStateOlabelsOut;
      StateId nexts = arc.nextstate;
      assert(nexts <= max_state);  // or input was invalid.
      StatePropertiesType &nexts_info = (*props)[nexts];
      if (s_info&kStateArcsOut) s_info |= kStateMultipleArcsOut;
      s_info |= kStateArcsOut;
      if (nexts_info&kStateArcsIn) nexts_info |= kStateMultipleArcsIn;
      nexts_info |= kStateArcsIn;
    }
    if (fst.Final(s) != Weight::Zero())  s_info |= kStateFinal;
  }
}



template<class Arc, class I>
void Factor(const Fst<Arc> &fst, MutableFst<Arc> *ofst,
               vector<vector<I> > *symbols_out) {
  KALDI_ASSERT_IS_INTEGER_TYPE(I);
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;
  assert(symbols_out != NULL);
  ofst->DeleteStates();
  if (fst.Start() < 0) return;  // empty FST.
  vector<StateId> order;
  DfsOrderVisitor<Arc> dfs_order_visitor(&order);
  DfsVisit(fst, &dfs_order_visitor);
  assert(order.size() > 0);
  StateId max_state = *(std::max_element(order.begin(), order.end()));
  vector<StatePropertiesType> state_properties;
  GetStateProperties(fst, max_state, &state_properties);

  vector<bool> remove(max_state+1);  // if true, will remove this state.

  // Now identify states that will be removed (made the middle of a chain).
  // The basic rule is that if the FstStateProperties equals
  // (kStateArcsIn|kStateArcsOut) or (kStateArcsIn|kStateArcsOut|kStateIlabelsOut),
  // then it is in the middle of a chain.  This eliminates state with
  // multiple input or output arcs, final states, and states with arcs out
  // that have olabels [we assume these are pushed to the left, so occur on the
  // 1st arc of a chain.

  for (StateId i = 0; i <= max_state; i++)
    remove[i] = (state_properties[i] == (kStateArcsIn|kStateArcsOut)
                 || state_properties[i] == (kStateArcsIn|kStateArcsOut|kStateIlabelsOut));
  vector<StateId> state_mapping(max_state+1, kNoStateId);

  typedef unordered_map<vector<I>, Label, kaldi::VectorHasher<I> > SymbolMapType;
  SymbolMapType symbol_mapping;
  Label symbol_counter = 0;
  {
    vector<I> eps;
    symbol_mapping[eps] = symbol_counter++;
  }
  vector<I> this_sym;  // a temporary used inside the loop.
  for (size_t i = 0; i < order.size(); i++) {
    StateId state = order[i];
    if (!remove[state]) {  // Process this state...
      StateId &new_state = state_mapping[state];
      if (new_state == kNoStateId) new_state = ofst->AddState();
      for (ArcIterator<Fst<Arc> > aiter(fst, state); !aiter.Done(); aiter.Next()) {
        Arc arc = aiter.Value();
        if (arc.ilabel == 0) this_sym.clear();
        else {
          this_sym.resize(1);
          this_sym[0] = arc.ilabel;
        }
        while (remove[arc.nextstate]) {
          ArcIterator<Fst<Arc> > aiter2(fst, arc.nextstate);
          assert(!aiter2.Done());
          const Arc &nextarc = aiter2.Value();
          arc.weight = Times(arc.weight, nextarc.weight);
          assert(nextarc.olabel == 0);
          if (nextarc.ilabel != 0) this_sym.push_back(nextarc.ilabel);
          assert(static_cast<Label>(static_cast<I>(nextarc.ilabel))
                 == nextarc.ilabel); // check within integer range.
          arc.nextstate = nextarc.nextstate;
        }
        StateId &new_nextstate = state_mapping[arc.nextstate];
        if (new_nextstate == kNoStateId) new_nextstate = ofst->AddState();
        arc.nextstate = new_nextstate;
        if (symbol_mapping.count(this_sym) != 0) arc.ilabel = symbol_mapping[this_sym];
        else arc.ilabel = symbol_mapping[this_sym] = symbol_counter++;
        ofst->AddArc(new_state, arc);
      }
      if (fst.Final(state) != Weight::Zero())
        ofst->SetFinal(new_state, fst.Final(state));
    }
  }
  ofst->SetStart(state_mapping[fst.Start()]);

  // Now output the symbol sequences.
  symbols_out->resize(symbol_counter);
  for (typename SymbolMapType::const_iterator iter = symbol_mapping.begin();
      iter != symbol_mapping.end(); ++iter) {
    (*symbols_out)[iter->second] = iter->first;
  }
}

template<class Arc>
void Factor(const Fst<Arc> &fst, MutableFst<Arc> *ofst1,
            MutableFst<Arc> *ofst2) {
  typedef typename Arc::Label Label;
  vector<vector<Label> > symbols;
  Factor(fst, ofst2, &symbols);
  CreateFactorFst(symbols, ofst1);
}

template<class Arc, class I>
void ExpandInputSequences(const vector<vector<I> > &sequences,
                          MutableFst<Arc> *fst) {
  KALDI_ASSERT_IS_INTEGER_TYPE(I);
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;
  fst->SetInputSymbols(NULL);
  size_t size = sequences.size();
  if (sequences.size() > 0) assert(sequences[0].size() == 0);  // should be eps.
  StateId num_states_at_start = fst->NumStates();
  for (StateId s = 0; s < num_states_at_start; s++) {
    StateId num_arcs = fst->NumArcs(s);
    for (StateId aidx = 0; aidx < num_arcs; aidx++) {
      ArcIterator<MutableFst<Arc> > aiter(*fst, s);
      aiter.Seek(aidx);
      Arc arc = aiter.Value();

      Label ilabel = arc.ilabel;
      Label dest_state = arc.nextstate;
      if (ilabel != 0) {  // non-eps [nothing to do if eps]...
        assert(ilabel < static_cast<Label>(size));
        size_t len = sequences[ilabel].size();
        if (len <= 1) {
          if (len == 0) arc.ilabel = 0;
          else arc.ilabel = sequences[ilabel][0];
          MutableArcIterator<MutableFst<Arc> > mut_aiter(fst, s);
          mut_aiter.Seek(aidx);
          mut_aiter.SetValue(arc);
        } else {  // len>=2.  Must create new states...
          StateId curstate = -1;  // keep compiler happy: this value never used.
          for (size_t n = 0; n < len; n++) {  // adding/modifying "len" arcs.
            StateId nextstate;
            if (n < len-1) {
              nextstate = fst->AddState();
              assert(nextstate >= num_states_at_start);
            } else nextstate = dest_state;  // going back to original arc's
            // destination.
            if (n == 0) {
              arc.ilabel = sequences[ilabel][0];
              arc.nextstate = nextstate;
              MutableArcIterator<MutableFst<Arc> > mut_aiter(fst, s);
              mut_aiter.Seek(aidx);
              mut_aiter.SetValue(arc);
            } else {
              arc.ilabel = sequences[ilabel][n];
              arc.olabel = 0;
              arc.weight = Weight::One();
              arc.nextstate = nextstate;
              fst->AddArc(curstate, arc);
            }
            curstate = nextstate;
          }
        }
      }
    }
  }
}


template<class Arc, class I>
class RemoveSomeInputSymbolsMapper {
public:
  Arc operator ()(const Arc &arc_in) {
    Arc ans = arc_in;
    if (to_remove_set_.count(ans.ilabel) != 0) ans.ilabel = 0;  // remove this symbol
    return ans;
  }
  MapFinalAction FinalAction() { return MAP_NO_SUPERFINAL; }
  MapSymbolsAction InputSymbolsAction() { return MAP_CLEAR_SYMBOLS; }
  MapSymbolsAction OutputSymbolsAction() { return MAP_COPY_SYMBOLS; }
  uint64 Properties(uint64 props) const {
    // remove the following as we don't know now if any of them are true.
    uint64 to_remove = kAcceptor|kNotAcceptor|kIDeterministic|kNonIDeterministic|
        kNoEpsilons|kNoIEpsilons|kILabelSorted|kNotILabelSorted;
    return props & ~to_remove;
  }
  RemoveSomeInputSymbolsMapper(const vector<I> &to_remove):
      to_remove_set_(to_remove) {
    KALDI_ASSERT_IS_INTEGER_TYPE(I);
         assert(to_remove_set_.count(0) == 0);  // makes no sense to remove epsilon.
       }
private:
  kaldi::ConstIntegerSet<I> to_remove_set_;
};


template<class Arc, class I>
void CreateFactorFst(const vector<vector<I> > &sequences,
                     MutableFst<Arc> *fst) {
  KALDI_ASSERT_IS_INTEGER_TYPE(I);
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;

  assert(fst != NULL);
  fst->DeleteStates();
  StateId loopstate = fst->AddState();
  assert(loopstate == 0);
  fst->SetStart(0);
  fst->SetFinal(0, Weight::One());
  if (sequences.size() != 0) assert(sequences[0].size() == 0);  // can't replace epsilon...

  for (Label olabel = 1; olabel < static_cast<Label>(sequences.size()); olabel++) {
    size_t len = sequences[olabel].size();
    if (len == 0) {
      Arc arc(0, olabel, Weight::One(), loopstate);
      fst->AddArc(loopstate, arc);
    } else {
      StateId curstate = loopstate;
      for (size_t i = 0; i < len; i++) {
        StateId nextstate = (i == len-1 ? loopstate : fst->AddState());
        Arc arc(sequences[olabel][i], (i == 0 ? olabel : 0), Weight::One(), nextstate);
        fst->AddArc(curstate, arc);
        curstate = nextstate;
      }
    }
  }
  fst->SetProperties(kOLabelSorted, kOLabelSorted);
}


template<class Arc, class I>
void CreateMapFst(const vector<I> &symbol_map,
                  MutableFst<Arc> *fst) {
  KALDI_ASSERT_IS_INTEGER_TYPE(I);
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;

  assert(fst != NULL);
  fst->DeleteStates();
  StateId loopstate = fst->AddState();
  assert(loopstate == 0);
  fst->SetStart(0);
  fst->SetFinal(0, Weight::One());
  assert(symbol_map.empty() || symbol_map[0] == 0);  // FST cannot map epsilon to something else.
  for (Label olabel = 1; olabel < static_cast<Label>(symbol_map.size()); olabel++) {
    Arc arc(symbol_map[olabel], olabel, Weight::One(), loopstate);
    fst->AddArc(loopstate, arc);
  }
}




} // end namespace fst.

#endif