Blame view
tools/openfst-1.6.7/src/include/fst/randequivalent.h
4.05 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Tests if two FSTS are equivalent by checking if random strings from one FST // are transduced the same by both FSTs. #ifndef FST_RANDEQUIVALENT_H_ #define FST_RANDEQUIVALENT_H_ #include <fst/log.h> #include <fst/arcsort.h> #include <fst/compose.h> #include <fst/project.h> #include <fst/randgen.h> #include <fst/shortest-distance.h> #include <fst/vector-fst.h> namespace fst { // Test if two FSTs are stochastically equivalent by randomly generating // random paths through the FSTs. // // For each randomly generated path, the algorithm computes for each // of the two FSTs the sum of the weights of all the successful paths // sharing the same input and output labels as the considered randomly // generated path and checks that these two values are within a user-specified // delta. Returns optional error value (when FLAGS_error_fatal = false). template <class Arc, class ArcSelector> bool RandEquivalent(const Fst<Arc> &fst1, const Fst<Arc> &fst2, int32 num_paths, float delta, const RandGenOptions<ArcSelector> &opts, bool *error = nullptr) { using Weight = typename Arc::Weight; if (error) *error = false; // Checks that the symbol table are compatible. if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols()) || !CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols())) { FSTERROR() << "RandEquivalent: Input/output symbol tables of 1st " << "argument do not match input/output symbol tables of 2nd " << "argument"; if (error) *error = true; return false; } static const ILabelCompare<Arc> icomp; static const OLabelCompare<Arc> ocomp; VectorFst<Arc> sfst1(fst1); VectorFst<Arc> sfst2(fst2); Connect(&sfst1); Connect(&sfst2); ArcSort(&sfst1, icomp); ArcSort(&sfst2, icomp); bool result = true; for (int32 n = 0; n < num_paths; ++n) { VectorFst<Arc> path; const auto &fst = rand() % 2 ? sfst1 : sfst2; // NOLINT RandGen(fst, &path, opts); VectorFst<Arc> ipath(path); VectorFst<Arc> opath(path); Project(&ipath, PROJECT_INPUT); Project(&opath, PROJECT_OUTPUT); VectorFst<Arc> cfst1, pfst1; Compose(ipath, sfst1, &cfst1); ArcSort(&cfst1, ocomp); Compose(cfst1, opath, &pfst1); // Gives up if there are epsilon cycles in a non-idempotent semiring. if (!(Weight::Properties() & kIdempotent) && pfst1.Properties(kCyclic, true)) { continue; } const auto sum1 = ShortestDistance(pfst1); VectorFst<Arc> cfst2; Compose(ipath, sfst2, &cfst2); ArcSort(&cfst2, ocomp); VectorFst<Arc> pfst2; Compose(cfst2, opath, &pfst2); // Gives up if there are epsilon cycles in a non-idempotent semiring. if (!(Weight::Properties() & kIdempotent) && pfst2.Properties(kCyclic, true)) { continue; } const auto sum2 = ShortestDistance(pfst2); if (!ApproxEqual(sum1, sum2, delta)) { VLOG(1) << "Sum1 = " << sum1; VLOG(1) << "Sum2 = " << sum2; result = false; break; } } if (fst1.Properties(kError, false) || fst2.Properties(kError, false)) { if (error) *error = true; return false; } return result; } // Tests if two FSTs are equivalent by randomly generating a nnum_paths paths // (no longer than the path_length) using a user-specified seed, optionally // indicating an error setting an optional error argument to true. template <class Arc> bool RandEquivalent(const Fst<Arc> &fst1, const Fst<Arc> &fst2, int32 num_paths, float delta = kDelta, time_t seed = time(nullptr), int32 max_length = std::numeric_limits<int32>::max(), bool *error = nullptr) { const UniformArcSelector<Arc> uniform_selector(seed); const RandGenOptions<UniformArcSelector<Arc>> opts(uniform_selector, max_length); return RandEquivalent(fst1, fst2, num_paths, delta, opts, error); } } // namespace fst #endif // FST_RANDEQUIVALENT_H_ |