Blame view

tools/openfst-1.6.7/include/fst/script/replace.h 2.57 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
  // See www.openfst.org for extensive documentation on this weighted
  // finite-state transducer library.
  
  #ifndef FST_SCRIPT_REPLACE_H_
  #define FST_SCRIPT_REPLACE_H_
  
  #include <tuple>
  #include <utility>
  #include <vector>
  
  #include <fst/replace.h>
  #include <fst/script/fst-class.h>
  
  namespace fst {
  namespace script {
  
  struct ReplaceOptions {
    const int64 root;                          // Root rule for expansion.
    const ReplaceLabelType call_label_type;    // How to label call arc.
    const ReplaceLabelType return_label_type;  // How to label return arc.
    const int64 return_label;                  // Specifies return arc label.
  
    explicit ReplaceOptions(int64 root,
        ReplaceLabelType call_label_type = REPLACE_LABEL_INPUT,
        ReplaceLabelType return_label_type = REPLACE_LABEL_NEITHER,
        int64 return_label = 0)
        : root(root),
          call_label_type(call_label_type),
          return_label_type(return_label_type),
          return_label(return_label) {}
  };
  
  // A scripting-level (label, FST) rule pair. The label is carried as int64 and
  // the FST as an arc-type-agnostic FstClass; Replace<Arc> converts both to
  // their arc-typed counterparts.
  using LabelFstClassPair = std::pair<int64, const FstClass *>;

  // Argument bundle consumed by the arc-typed Replace<Arc> below. It holds the
  // rule pairs, the output FST and the options. The first and last elements
  // are references, so the referenced vector and options must outlive the
  // tuple.
  using ReplaceArgs = std::tuple<const std::vector<LabelFstClassPair> &,
                                 MutableFstClass *, const ReplaceOptions &>;
  
  // Arc-typed implementation of the scripting-level Replace.
  //
  // It converts the untyped (label, FstClass*) rule pairs into typed
  // (Arc::Label, const Fst<Arc>*) pairs and builds a delayed ReplaceFst rooted
  // at opts.root. It then copies the expansion into the output FST.
  //
  // args: tuple of (rule pairs, output FST, options).
  //   std::get<0>: rule pairs. Each FstClass is assumed to wrap an Fst<Arc>;
  //                presumably the arc type is checked by the caller (TODO
  //                confirm).
  //   std::get<1>: output FST. It is overwritten with the expansion, or
  //                flagged with kError if the rules have cyclic dependencies.
  //   std::get<2>: root rule and call/return labeling options.
  template <class Arc>
  void Replace(ReplaceArgs *args) {
    using LabelFstPair = std::pair<typename Arc::Label, const Fst<Arc> *>;
    // Now that we know the arc type, we construct a vector of
    // std::pair<real label, real fst> that the real Replace will use.
    const auto &untyped_pairs = std::get<0>(*args);
    std::vector<LabelFstPair> typed_pairs;
    typed_pairs.reserve(untyped_pairs.size());
    for (const auto &untyped_pair : untyped_pairs) {
      typed_pairs.emplace_back(untyped_pair.first,  // Converts label.
                               untyped_pair.second->GetFst<Arc>());
    }
    MutableFst<Arc> *ofst = std::get<1>(*args)->GetMutableFst<Arc>();
    const auto &opts = std::get<2>(*args);
    ReplaceFstOptions<Arc> typed_opts(opts.root, opts.call_label_type,
                                      opts.return_label_type, opts.return_label);
    // Caching options to speed up the batch copy into ofst below. They must be
    // set BEFORE the ReplaceFst is constructed. The constructor reads the
    // options, so assigning them afterwards (as this code previously did) had
    // no effect on rfst.
    typed_opts.gc = true;
    typed_opts.gc_limit = 0;
    ReplaceFst<Arc> rfst(typed_pairs, typed_opts);
    // Checks for cyclic dependencies before attempting expansion.
    if (rfst.CyclicDependencies()) {
      FSTERROR() << "Replace: Cyclic dependencies detected; cannot expand";
      ofst->SetProperties(kError, kError);
      return;
    }
    *ofst = rfst;
  }
  
  // Scripting-level entry point. It expands the rule FSTs in `pairs`, starting
  // from opts.root, into `ofst`. Defined out of line; presumably it packs its
  // arguments into a ReplaceArgs and dispatches to Replace<Arc> for the FSTs'
  // arc type (TODO confirm against the .cc).
  void Replace(const std::vector<LabelFstClassPair> &pairs,
               MutableFstClass *ofst, const ReplaceOptions &opts);
  
  }  // namespace script
  }  // namespace fst
  
  #endif  // FST_SCRIPT_REPLACE_H_