// fstext/factor-inl.h // Copyright 2009-2011 Microsoft Corporation // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_FSTEXT_FACTOR_INL_H_ #define KALDI_FSTEXT_FACTOR_INL_H_ #include "util/stl-utils.h" // Do not include this file directly. It is included by factor.h. namespace fst { // GetStateProperties takes in an FST and a number "max_state" which is the // highest numbered state in the FST (this could be fst.NumStates()-1 for an // ExpandedFst, or derived from some kind of traversal). It outputs a vector // numbered from 0..max_state, of type FstStateProperties which is a bitmask // with information about the states. // GetStateProperties has not been tested directly (only implicitly via // testing Factor). template void GetStateProperties(const Fst &fst, typename Arc::StateId max_state, vector *props) { typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; assert(props != NULL); props->clear(); if (fst.Start() < 0) return; // Empty fst. props->resize(max_state+1, 0); assert(fst.Start() <= max_state); (*props)[fst.Start()] |= kStateInitial; for (StateId s = 0; s <= max_state; s++) { StatePropertiesType &s_info = (*props)[s]; for (ArcIterator > aiter(fst, s); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); if (arc.ilabel != 0) s_info |= kStateIlabelsOut; if (arc.olabel != 0) s_info |= kStateOlabelsOut; StateId nexts = arc.nextstate; assert(nexts <= max_state); // or input was invalid. StatePropertiesType &nexts_info = (*props)[nexts]; if (s_info&kStateArcsOut) s_info |= kStateMultipleArcsOut; s_info |= kStateArcsOut; if (nexts_info&kStateArcsIn) nexts_info |= kStateMultipleArcsIn; nexts_info |= kStateArcsIn; } if (fst.Final(s) != Weight::Zero()) s_info |= kStateFinal; } } template void Factor(const Fst &fst, MutableFst *ofst, vector > *symbols_out) { KALDI_ASSERT_IS_INTEGER_TYPE(I); typedef typename Arc::StateId StateId; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; assert(symbols_out != NULL); ofst->DeleteStates(); if (fst.Start() < 0) return; // empty FST. vector order; DfsOrderVisitor dfs_order_visitor(&order); DfsVisit(fst, &dfs_order_visitor); assert(order.size() > 0); StateId max_state = *(std::max_element(order.begin(), order.end())); vector state_properties; GetStateProperties(fst, max_state, &state_properties); vector remove(max_state+1); // if true, will remove this state. // Now identify states that will be removed (made the middle of a chain). // The basic rule is that if the FstStateProperties equals // (kStateArcsIn|kStateArcsOut) or (kStateArcsIn|kStateArcsOut|kStateIlabelsOut), // then it is in the middle of a chain. This eliminates state with // multiple input or output arcs, final states, and states with arcs out // that have olabels [we assume these are pushed to the left, so occur on the // 1st arc of a chain. for (StateId i = 0; i <= max_state; i++) remove[i] = (state_properties[i] == (kStateArcsIn|kStateArcsOut) || state_properties[i] == (kStateArcsIn|kStateArcsOut|kStateIlabelsOut)); vector state_mapping(max_state+1, kNoStateId); typedef unordered_map, Label, kaldi::VectorHasher > SymbolMapType; SymbolMapType symbol_mapping; Label symbol_counter = 0; { vector eps; symbol_mapping[eps] = symbol_counter++; } vector this_sym; // a temporary used inside the loop. for (size_t i = 0; i < order.size(); i++) { StateId state = order[i]; if (!remove[state]) { // Process this state... StateId &new_state = state_mapping[state]; if (new_state == kNoStateId) new_state = ofst->AddState(); for (ArcIterator > aiter(fst, state); !aiter.Done(); aiter.Next()) { Arc arc = aiter.Value(); if (arc.ilabel == 0) this_sym.clear(); else { this_sym.resize(1); this_sym[0] = arc.ilabel; } while (remove[arc.nextstate]) { ArcIterator > aiter2(fst, arc.nextstate); assert(!aiter2.Done()); const Arc &nextarc = aiter2.Value(); arc.weight = Times(arc.weight, nextarc.weight); assert(nextarc.olabel == 0); if (nextarc.ilabel != 0) this_sym.push_back(nextarc.ilabel); assert(static_cast