Blame view

tools/openfst-1.6.7/src/include/fst/expanded-fst.h 5.39 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
  // See www.openfst.org for extensive documentation on this weighted
  // finite-state transducer library.
  //
  // Generic FST augmented with a state-count interface: class definitions.
  
  #ifndef FST_EXPANDED_FST_H_
  #define FST_EXPANDED_FST_H_
  
  #include <sys/types.h>
  #include <istream>
  #include <string>
  
  #include <fst/log.h>
  #include <fstream>
  
  #include <fst/fst.h>
  
  
  namespace fst {
  
  // Abstract interface for FSTs that can report their number of states, on top
  // of everything inherited from Fst<A>.
  template <class A>
  class ExpandedFst : public Fst<A> {
   public:
    using Arc = A;
    using StateId = typename Arc::StateId;

    // Returns the number of states of this FST.
    virtual StateId NumStates() const = 0;

    // Returns a copy of this ExpandedFst. The meaning of `safe` is documented
    // at Fst<>::Copy().
    ExpandedFst<Arc> *Copy(bool safe = false) const override = 0;

    // Reads an ExpandedFst from the input stream `strm`.
    //
    // When `opts.header` is set, that header is used; otherwise a header is
    // read from the stream first. Returns nullptr (after logging an error where
    // applicable) if the header cannot be read, if the header properties do not
    // include kExpanded, if no reader is registered for the header's FST type,
    // or if the registered reader itself returns null.
    static ExpandedFst<Arc> *Read(std::istream &strm,
                                  const FstReadOptions &opts) {
      FstReadOptions read_opts(opts);
      FstHeader header;
      if (!read_opts.header) {
        // No header supplied by the caller: parse one from the stream and pass
        // it on to the concrete reader through the copied options.
        if (!header.Read(strm, opts.source)) return nullptr;
        read_opts.header = &header;
      } else {
        header = *opts.header;
      }

      const auto expanded_bits = header.Properties() & kExpanded;
      if (!expanded_bits) {
        LOG(ERROR) << "ExpandedFst::Read: Not an ExpandedFst: "
                   << read_opts.source;
        return nullptr;
      }

      // Look up the reader registered for this FST type and arc type.
      const auto read_fn =
          FstRegister<Arc>::GetRegister()->GetReader(header.FstType());
      if (!read_fn) {
        LOG(ERROR) << "ExpandedFst::Read: Unknown FST type \""
                   << header.FstType() << "\" (arc type = \"" << Arc::Type()
                   << "\"): " << read_opts.source;
        return nullptr;
      }

      auto *result = read_fn(strm, read_opts);
      // The kExpanded check above is what justifies the downcast.
      return result ? static_cast<ExpandedFst<Arc> *>(result) : nullptr;
    }

    // Reads an ExpandedFst from the file `filename`, opened in binary mode.
    // An empty filename reads from standard input instead. Returns nullptr on
    // error, including when the file cannot be opened.
    static ExpandedFst<Arc> *Read(const string &filename) {
      if (filename.empty()) {
        return Read(std::cin, FstReadOptions("standard input"));
      }
      std::ifstream strm(filename, std::ios_base::in | std::ios_base::binary);
      if (!strm) {
        LOG(ERROR) << "ExpandedFst::Read: Can't open file: " << filename;
        return nullptr;
      }
      return Read(strm, FstReadOptions(filename));
    }
  };
  
  namespace internal {
  
  //  ExpandedFst<A> case - abstract methods.
  // Overload taking an ExpandedFst<Arc>: returns the result of calling the
  // member function Final(s) on `fst`.
  template <class Arc>
  inline typename Arc::Weight Final(const ExpandedFst<Arc> &fst,
                                    typename Arc::StateId s) {
    // Plain forwarding call; the work is done by the FST's own Final().
    return fst.Final(s);
  }
  
  // Overload taking an ExpandedFst<Arc>: returns fst.NumArcs(s) converted to
  // ssize_t.
  template <class Arc>
  inline ssize_t NumArcs(const ExpandedFst<Arc> &fst, typename Arc::StateId s) {
    return static_cast<ssize_t>(fst.NumArcs(s));
  }
  
  // Overload taking an ExpandedFst<Arc>: returns fst.NumInputEpsilons(s)
  // converted to ssize_t.
  template <class Arc>
  inline ssize_t NumInputEpsilons(const ExpandedFst<Arc> &fst,
                                  typename Arc::StateId s) {
    return static_cast<ssize_t>(fst.NumInputEpsilons(s));
  }
  
  // Overload taking an ExpandedFst<Arc>: returns fst.NumOutputEpsilons(s)
  // converted to ssize_t.
  template <class Arc>
  inline ssize_t NumOutputEpsilons(const ExpandedFst<Arc> &fst,
                                   typename Arc::StateId s) {
    return static_cast<ssize_t>(fst.NumOutputEpsilons(s));
  }
  
  }  // namespace internal
  
  // Convenience alias: ExpandedFst instantiated with StdArc as its arc type.
  using StdExpandedFst = ExpandedFst<StdArc>;
  
  // Helper class template that attaches an ExpandedFst interface to an
  // implementation class `Impl`, handling reference counting. The Fst interface
  // methods are delegated to ImplToFst; this class adds NumStates(), which
  // forwards to the implementation.
  template <class Impl, class FST = ExpandedFst<typename Impl::Arc>>
  class ImplToExpandedFst : public ImplToFst<Impl, FST> {
   public:
    using Arc = typename FST::Arc;
    using StateId = typename Arc::StateId;
    using Weight = typename Arc::Weight;

    using ImplToFst<Impl, FST>::operator=;

    // Returns the state count reported by the underlying implementation.
    StateId NumStates() const override {
      return GetImpl()->NumStates();
    }

   protected:
    using ImplToFst<Impl, FST>::GetImpl;

    // Wraps the given implementation; construction is forwarded to ImplToFst.
    explicit ImplToExpandedFst(std::shared_ptr<Impl> impl)
        : ImplToFst<Impl, FST>(impl) {}

    // Copy construction, forwarded to ImplToFst.
    ImplToExpandedFst(const ImplToExpandedFst<Impl, FST> &fst)
        : ImplToFst<Impl, FST>(fst) {}

    // Copy construction with the `safe` flag, forwarded to ImplToFst.
    ImplToExpandedFst(const ImplToExpandedFst<Impl, FST> &fst, bool safe)
        : ImplToFst<Impl, FST>(fst, safe) {}

    // Reads an FST implementation from a stream by delegating to Impl::Read().
    static Impl *Read(std::istream &strm, const FstReadOptions &opts) {
      return Impl::Read(strm, opts);
    }

    // Reads an FST implementation from the file `filename`, opened in binary
    // mode. An empty filename reads from standard input instead. Returns
    // nullptr (after logging an error) when the file cannot be opened;
    // otherwise returns whatever Impl::Read() returns.
    static Impl *Read(const string &filename) {
      if (filename.empty()) {
        return Impl::Read(std::cin, FstReadOptions("standard input"));
      }
      std::ifstream strm(filename, std::ios_base::in | std::ios_base::binary);
      if (!strm) {
        LOG(ERROR) << "ExpandedFst::Read: Can't open file: " << filename;
        return nullptr;
      }
      return Impl::Read(strm, FstReadOptions(filename));
    }
  };
  
  // Function to return the number of states in an FST, counting them
  // if necessary.
  template <class Arc>
  typename Arc::StateId CountStates(const Fst<Arc> &fst) {
    if (fst.Properties(kExpanded, false)) {
      const auto *efst = static_cast<const ExpandedFst<Arc> *>(&fst);
      return efst->NumStates();
    } else {
      typename Arc::StateId nstates = 0;
      for (StateIterator<Fst<Arc>> siter(fst); !siter.Done(); siter.Next()) {
        ++nstates;
      }
      return nstates;
    }
  }
  
  // Function to return the number of arcs in an FST.
  template <class Arc>
  typename Arc::StateId CountArcs(const Fst<Arc> &fst) {
    size_t narcs = 0;
    for (StateIterator<Fst<Arc>> siter(fst); !siter.Done(); siter.Next()) {
      narcs += fst.NumArcs(siter.Value());
    }
    return narcs;
  }
  
  }  // namespace fst
  
  #endif  // FST_EXPANDED_FST_H_