Blame view
tools/openfst-1.6.7/src/include/fst/push.h
5.9 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Class to reweight/push an FST, and utility functions to weigh and reweight // an FST. #ifndef FST_PUSH_H_ #define FST_PUSH_H_ #include <vector> #include <fst/log.h> #include <fst/arc-map.h> #include <fst/factor-weight.h> #include <fst/fst.h> #include <fst/reweight.h> #include <fst/shortest-distance.h> namespace fst { // Computes the total weight (sum of the weights of all accepting paths) from // the output of ShortestDistance, using the shortest distance from the final // state when reverse is true and from the initial state otherwise. template <class Arc> typename Arc::Weight ComputeTotalWeight( const Fst<Arc> &fst, const std::vector<typename Arc::Weight> &distance, bool reverse) { if (reverse) { return fst.Start() < distance.size() ? distance[fst.Start()] : Arc::Weight::Zero(); } auto sum = Arc::Weight::Zero(); for (typename Arc::StateId s = 0; s < distance.size(); ++s) { sum = Plus(sum, Times(distance[s], fst.Final(s))); } return sum; } // Divides the weight of every accepting path by a fixed weight. This weight // is also divided at the final state if at_final is true and at the initial // state otherwise. 
//
// Parameters:
//   fst:      the machine to modify in place.
//   weight:   the weight to divide out; One() and Zero() leave fst unchanged.
//   at_final: true to right-divide the final weight of every state; false to
//             left-divide the start state's outgoing arc weights and its final
//             weight. If fst has no start state, the latter is a no-op.
template <class Arc>
void RemoveWeight(MutableFst<Arc> *fst, const typename Arc::Weight &weight,
                  bool at_final) {
  using Weight = typename Arc::Weight;
  // Nothing to do for One(); Zero() cannot be divided out.
  if ((weight == Weight::One()) || (weight == Weight::Zero())) return;
  if (at_final) {
    for (StateIterator<MutableFst<Arc>> siter(*fst); !siter.Done();
         siter.Next()) {
      const auto s = siter.Value();
      fst->SetFinal(s, Divide(fst->Final(s), weight, DIVIDE_RIGHT));
    }
  } else {
    const auto start = fst->Start();
    // Without a start state (e.g. an empty machine) there is nothing to
    // reweight, and iterating the arcs of a negative state ID is invalid.
    // This is reachable from label pushing with kPushRemoveCommonAffix only,
    // where the total weight is neither One() nor Zero() for an empty FST.
    if (start < 0) return;
    for (MutableArcIterator<MutableFst<Arc>> aiter(fst, start); !aiter.Done();
         aiter.Next()) {
      auto arc = aiter.Value();
      arc.weight = Divide(arc.weight, weight, DIVIDE_LEFT);
      aiter.SetValue(arc);
    }
    fst->SetFinal(start, Divide(fst->Final(start), weight, DIVIDE_LEFT));
  }
}

// Pushes the weights in FST in the direction defined by TYPE. If
// pushing towards the initial state, the sum of the weights of the
// outgoing transitions plus the final weight at a non-initial state is
// equal to One() in the resulting machine. If pushing towards the
// final state, the same property holds on the reverse machine.
//
// Weight needs to be left distributive when pushing towards the
// initial state and right distributive when pushing towards the final
// states.
template <class Arc> void Push(MutableFst<Arc> *fst, ReweightType type, float delta = kDelta, bool remove_total_weight = false) { using Weight = typename Arc::Weight; std::vector<Weight> distance; ShortestDistance(*fst, &distance, type == REWEIGHT_TO_INITIAL, delta); auto total_weight = Weight::One(); if (remove_total_weight) { total_weight = ComputeTotalWeight(*fst, distance, type == REWEIGHT_TO_INITIAL); } Reweight(fst, distance, type); if (remove_total_weight) { RemoveWeight(fst, total_weight, type == REWEIGHT_TO_FINAL); } } constexpr uint32 kPushWeights = 0x0001; constexpr uint32 kPushLabels = 0x0002; constexpr uint32 kPushRemoveTotalWeight = 0x0004; constexpr uint32 kPushRemoveCommonAffix = 0x0008; // Pushes the weights and/or labels of the input FST into the output // mutable FST by pushing weights and/or labels (as determined by the // ptype argument) towards the initial state or final states (as // determined by the rtype template parameter). The weight type must // be left distributive when pushing weights towards the initial state, and // right distribution when pushing weights towards the final states. template <class Arc, ReweightType rtype> void Push(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, uint32 ptype, float delta = kDelta) { using Label = typename Arc::Label; using Weight = typename Arc::Weight; if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) { *ofst = ifst; Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight); } else if (ptype & kPushLabels) { const auto gtype = rtype == REWEIGHT_TO_INITIAL ? 
GALLIC_LEFT : GALLIC_RIGHT; using GallicWeight = typename GallicArc<Arc, gtype>::Weight; std::vector<GallicWeight> gdistance; VectorFst<GallicArc<Arc, gtype>> gfst; ArcMap(ifst, &gfst, ToGallicMapper<Arc, gtype>()); if (ptype & kPushWeights) { ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); } else { ArcMapFst<Arc, Arc, RmWeightMapper<Arc>> uwfst(ifst, RmWeightMapper<Arc>()); ArcMapFst<Arc, GallicArc<Arc, gtype>, ToGallicMapper<Arc, gtype>> guwfst( uwfst, ToGallicMapper<Arc, gtype>()); ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); } auto total_weight = GallicWeight::One(); if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) { total_weight = ComputeTotalWeight(gfst, gdistance, rtype == REWEIGHT_TO_INITIAL); total_weight = GallicWeight( ptype & kPushRemoveCommonAffix ? total_weight.Value1() : StringWeight<Label, GallicStringType(gtype)>::One(), ptype & kPushRemoveTotalWeight ? total_weight.Value2() : Weight::One()); } Reweight(&gfst, gdistance, rtype); if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) { RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL); } FactorWeightFst<GallicArc<Arc, gtype>, GallicFactor<Label, Weight, gtype>> fwfst(gfst); ArcMap(fwfst, ofst, FromGallicMapper<Arc, gtype>()); ofst->SetOutputSymbols(ifst.OutputSymbols()); } else { LOG(WARNING) << "Push: pushing type is set to 0, so not pushing"; *ofst = ifst; } } } // namespace fst #endif // FST_PUSH_H_ |