// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Regression test for various FST algorithms. #ifndef FST_TEST_ALGO_TEST_H_ #define FST_TEST_ALGO_TEST_H_ #include #include #include "./rand-fst.h" DECLARE_int32(repeat); // defined in ./algo_test.cc namespace fst { // Mapper to change input and output label of every transition into // epsilons. template class EpsMapper { public: EpsMapper() {} A operator()(const A &arc) const { return A(0, 0, arc.weight, arc.nextstate); } uint64 Properties(uint64 props) const { props &= ~kNotAcceptor; props |= kAcceptor; props &= ~kNoIEpsilons & ~kNoOEpsilons & ~kNoEpsilons; props |= kIEpsilons | kOEpsilons | kEpsilons; props &= ~kNotILabelSorted & ~kNotOLabelSorted; props |= kILabelSorted | kOLabelSorted; return props; } MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; } }; // Generic - no lookahead. template void LookAheadCompose(const Fst &ifst1, const Fst &ifst2, MutableFst *ofst) { Compose(ifst1, ifst2, ofst); } // Specialized and epsilon olabel acyclic - lookahead. void LookAheadCompose(const Fst &ifst1, const Fst &ifst2, MutableFst *ofst) { std::vector order; bool acyclic; TopOrderVisitor visitor(&order, &acyclic); DfsVisit(ifst1, &visitor, OutputEpsilonArcFilter()); if (acyclic) { // no ifst1 output epsilon cycles? StdOLabelLookAheadFst lfst1(ifst1); StdVectorFst lfst2(ifst2); LabelLookAheadRelabeler::Relabel(&lfst2, lfst1, true); Compose(lfst1, lfst2, ofst); } else { Compose(ifst1, ifst2, ofst); } } // This class tests a variety of identities and properties that must // hold for various algorithms on weighted FSTs. template class WeightedTester { public: typedef typename Arc::Label Label; typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; WeightedTester(time_t seed, const Fst &zero_fst, const Fst &one_fst, const Fst &univ_fst, WeightGenerator *weight_generator) : seed_(seed), zero_fst_(zero_fst), one_fst_(one_fst), univ_fst_(univ_fst), weight_generator_(weight_generator) {} void Test(const Fst &T1, const Fst &T2, const Fst &T3) { TestRational(T1, T2, T3); TestMap(T1); TestCompose(T1, T2, T3); TestSort(T1); TestOptimize(T1); TestSearch(T1); } private: // Tests rational operations with identities void TestRational(const Fst &T1, const Fst &T2, const Fst &T3) { { VLOG(1) << "Check destructive and delayed union are equivalent."; VectorFst U1(T1); Union(&U1, T2); UnionFst U2(T1, T2); CHECK(Equiv(U1, U2)); } { VLOG(1) << "Check destructive and delayed concatenation are equivalent."; VectorFst C1(T1); Concat(&C1, T2); ConcatFst C2(T1, T2); CHECK(Equiv(C1, C2)); VectorFst C3(T2); Concat(T1, &C3); CHECK(Equiv(C3, C2)); } { VLOG(1) << "Check destructive and delayed closure* are equivalent."; VectorFst C1(T1); Closure(&C1, CLOSURE_STAR); ClosureFst C2(T1, CLOSURE_STAR); CHECK(Equiv(C1, C2)); } { VLOG(1) << "Check destructive and delayed closure+ are equivalent."; VectorFst C1(T1); Closure(&C1, CLOSURE_PLUS); ClosureFst C2(T1, CLOSURE_PLUS); CHECK(Equiv(C1, C2)); } { VLOG(1) << "Check union is associative (destructive)."; VectorFst U1(T1); Union(&U1, T2); Union(&U1, T3); VectorFst U3(T2); Union(&U3, T3); VectorFst U4(T1); Union(&U4, U3); CHECK(Equiv(U1, U4)); } { VLOG(1) << "Check union is associative (delayed)."; UnionFst U1(T1, T2); UnionFst U2(U1, T3); UnionFst U3(T2, T3); UnionFst U4(T1, U3); CHECK(Equiv(U2, U4)); } { VLOG(1) << "Check union is associative (destructive delayed)."; UnionFst U1(T1, T2); Union(&U1, T3); UnionFst U3(T2, T3); UnionFst U4(T1, U3); CHECK(Equiv(U1, U4)); } { VLOG(1) << "Check concatenation is associative (destructive)."; VectorFst C1(T1); Concat(&C1, T2); Concat(&C1, T3); VectorFst C3(T2); Concat(&C3, T3); VectorFst C4(T1); Concat(&C4, C3); CHECK(Equiv(C1, C4)); } { VLOG(1) << "Check concatenation is associative (delayed)."; ConcatFst C1(T1, T2); ConcatFst C2(C1, T3); ConcatFst C3(T2, T3); ConcatFst C4(T1, C3); CHECK(Equiv(C2, C4)); } { VLOG(1) << "Check concatenation is associative (destructive delayed)."; ConcatFst C1(T1, T2); Concat(&C1, T3); ConcatFst C3(T2, T3); ConcatFst C4(T1, C3); CHECK(Equiv(C1, C4)); } if (Weight::Properties() & kLeftSemiring) { VLOG(1) << "Check concatenation left distributes" << " over union (destructive)."; VectorFst U1(T1); Union(&U1, T2); VectorFst C1(T3); Concat(&C1, U1); VectorFst C2(T3); Concat(&C2, T1); VectorFst C3(T3); Concat(&C3, T2); VectorFst U2(C2); Union(&U2, C3); CHECK(Equiv(C1, U2)); } if (Weight::Properties() & kRightSemiring) { VLOG(1) << "Check concatenation right distributes" << " over union (destructive)."; VectorFst U1(T1); Union(&U1, T2); VectorFst C1(U1); Concat(&C1, T3); VectorFst C2(T1); Concat(&C2, T3); VectorFst C3(T2); Concat(&C3, T3); VectorFst U2(C2); Union(&U2, C3); CHECK(Equiv(C1, U2)); } if (Weight::Properties() & kLeftSemiring) { VLOG(1) << "Check concatenation left distributes over union (delayed)."; UnionFst U1(T1, T2); ConcatFst C1(T3, U1); ConcatFst C2(T3, T1); ConcatFst C3(T3, T2); UnionFst U2(C2, C3); CHECK(Equiv(C1, U2)); } if (Weight::Properties() & kRightSemiring) { VLOG(1) << "Check concatenation right distributes over union (delayed)."; UnionFst U1(T1, T2); ConcatFst C1(U1, T3); ConcatFst C2(T1, T3); ConcatFst C3(T2, T3); UnionFst U2(C2, C3); CHECK(Equiv(C1, U2)); } if (Weight::Properties() & kLeftSemiring) { VLOG(1) << "Check T T* == T+ (destructive)."; VectorFst S(T1); Closure(&S, CLOSURE_STAR); VectorFst C(T1); Concat(&C, S); VectorFst P(T1); Closure(&P, CLOSURE_PLUS); CHECK(Equiv(C, P)); } if (Weight::Properties() & kRightSemiring) { VLOG(1) << "Check T* T == T+ (destructive)."; VectorFst S(T1); Closure(&S, CLOSURE_STAR); VectorFst C(S); Concat(&C, T1); VectorFst P(T1); Closure(&P, CLOSURE_PLUS); CHECK(Equiv(C, P)); } if (Weight::Properties() & kLeftSemiring) { VLOG(1) << "Check T T* == T+ (delayed)."; ClosureFst S(T1, CLOSURE_STAR); ConcatFst C(T1, S); ClosureFst P(T1, CLOSURE_PLUS); CHECK(Equiv(C, P)); } if (Weight::Properties() & kRightSemiring) { VLOG(1) << "Check T* T == T+ (delayed)."; ClosureFst S(T1, CLOSURE_STAR); ConcatFst C(S, T1); ClosureFst P(T1, CLOSURE_PLUS); CHECK(Equiv(C, P)); } } // Tests map-based operations. void TestMap(const Fst &T) { { VLOG(1) << "Check destructive and delayed projection are equivalent."; VectorFst P1(T); Project(&P1, PROJECT_INPUT); ProjectFst P2(T, PROJECT_INPUT); CHECK(Equiv(P1, P2)); } { VLOG(1) << "Check destructive and delayed inversion are equivalent."; VectorFst I1(T); Invert(&I1); InvertFst I2(T); CHECK(Equiv(I1, I2)); } { VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (destructive)."; VectorFst P1(T); VectorFst I1(T); Project(&P1, PROJECT_INPUT); Invert(&I1); Project(&I1, PROJECT_OUTPUT); CHECK(Equiv(P1, I1)); } { VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (destructive)."; VectorFst P1(T); VectorFst I1(T); Project(&P1, PROJECT_OUTPUT); Invert(&I1); Project(&I1, PROJECT_INPUT); CHECK(Equiv(P1, I1)); } { VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (delayed)."; ProjectFst P1(T, PROJECT_INPUT); InvertFst I1(T); ProjectFst P2(I1, PROJECT_OUTPUT); CHECK(Equiv(P1, P2)); } { VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (delayed)."; ProjectFst P1(T, PROJECT_OUTPUT); InvertFst I1(T); ProjectFst P2(I1, PROJECT_INPUT); CHECK(Equiv(P1, P2)); } { VLOG(1) << "Check destructive relabeling"; static const int kNumLabels = 10; // set up relabeling pairs std::vector