Blame view

tools/openfst-1.6.7/src/script/weight-class.cc 2.88 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
  // See www.openfst.org for extensive documentation on this weighted
  // finite-state transducer library.
  
  #include <fst/script/weight-class.h>
  
  namespace fst {
  namespace script {
  
  // Registers the weight types of StdArc, LogArc and Log64Arc with the
  // script-level weight machinery. The constructor below looks weight types up
  // by name via WeightClassRegister::GetEntry(); presumably this macro is what
  // adds the corresponding entries (macro defined in the header — confirm).
  // Any weight type not registered somewhere will be reported as
  // "Unknown weight type" by that constructor.
  REGISTER_FST_WEIGHT(StdArc::Weight);
  REGISTER_FST_WEIGHT(LogArc::Weight);
  REGISTER_FST_WEIGHT(Log64Arc::Weight);
  
  // Builds a weight of the named type by parsing `weight_str`.
  //
  // The string-to-weight converter for `weight_type` is fetched from the
  // WeightClassRegister singleton. If no converter is registered for that
  // type, an error is reported via FSTERROR() and the object is left without
  // an implementation (impl_ is null), which the free functions in this file
  // treat as an invalid weight.
  WeightClass::WeightClass(const string &weight_type, const string &weight_str) {
    const StrToWeightImplBaseT converter =
        WeightClassRegister::GetRegister()->GetEntry(weight_type);
    if (converter) {
      // "WeightClass" and 0 are the source label and line number forwarded to
      // the converter (presumably for its own error messages).
      impl_.reset(converter(weight_str, "WeightClass", 0));
      return;
    }
    FSTERROR() << "Unknown weight type: " << weight_type;
    impl_.reset();
  }
  
  // Returns the Zero() element of the named weight type.
  // Construction goes through the string constructor with the __ZERO__
  // sentinel, so an unregistered type yields an invalid (impl-less) weight.
  WeightClass WeightClass::Zero(const string &weight_type) {
    WeightClass zero(weight_type, __ZERO__);
    return zero;
  }
  
  // Returns the One() element of the named weight type.
  // Construction goes through the string constructor with the __ONE__
  // sentinel, so an unregistered type yields an invalid (impl-less) weight.
  WeightClass WeightClass::One(const string &weight_type) {
    WeightClass one(weight_type, __ONE__);
    return one;
  }
  
  // Returns the NoWeight() value of the named weight type.
  // Construction goes through the string constructor with the __NOWEIGHT__
  // sentinel, so an unregistered type yields an invalid (impl-less) weight.
  WeightClass WeightClass::NoWeight(const string &weight_type) {
    WeightClass no_weight(weight_type, __NOWEIGHT__);
    return no_weight;
  }
  
  // Checks that this weight and `other` report the same Type().
  //
  // Returns true when they match. Otherwise it reports an FSTERROR() naming
  // `op_name` (the operation that received the mismatched operands) and both
  // type names, then returns false.
  bool WeightClass::WeightTypesMatch(const WeightClass &other,
                                     const string &op_name) const {
    if (Type() == other.Type()) return true;
    FSTERROR() << "Weights with non-matching types passed to " << op_name
               << ": " << Type() << " and " << other.Type();
    return false;
  }
  
  // Equality of two script-level weights.
  //
  // Returns false when either side has no implementation. It also returns
  // false when the weight types differ; in that case an error is reported
  // through WeightTypesMatch. Otherwise the comparison is delegated to the
  // underlying implementations.
  bool operator==(const WeightClass &lhs, const WeightClass &rhs) {
    const auto *left = lhs.GetImpl();
    const auto *right = rhs.GetImpl();
    if (left == nullptr || right == nullptr) return false;
    if (!lhs.WeightTypesMatch(rhs, "operator==")) return false;
    return *left == *right;
  }
  
  // Inequality is exactly the negation of operator==.
  // Consequently, invalid or type-mismatched operands compare as "not equal",
  // and a type mismatch still triggers the error report from operator==.
  bool operator!=(const WeightClass &lhs, const WeightClass &rhs) {
    const bool equal = (lhs == rhs);
    return !equal;
  }
  
  // Semiring addition of two script-level weights.
  //
  // Returns a default-constructed WeightClass when either operand has no
  // implementation, or when the operand types differ (the mismatch is
  // reported via WeightTypesMatch). Otherwise it returns a copy of `lhs` with
  // `rhs` added in via PlusEq. Neither argument is modified.
  WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs) {
    const auto *addend = rhs.GetImpl();
    const bool operands_ok = lhs.GetImpl() != nullptr && addend != nullptr &&
                             lhs.WeightTypesMatch(rhs, "Plus");
    if (operands_ok) {
      WeightClass sum(lhs);
      sum.GetImpl()->PlusEq(*addend);
      return sum;
    }
    return WeightClass();
  }
  
  // Semiring multiplication of two script-level weights.
  //
  // Returns a default-constructed WeightClass when either operand has no
  // implementation, or when the operand types differ (the mismatch is
  // reported via WeightTypesMatch under the name "Times"). Otherwise it
  // returns a copy of `lhs` multiplied by `rhs` via TimesEq. Neither argument
  // is modified.
  WeightClass Times(const WeightClass &lhs, const WeightClass &rhs) {
    const auto *rhs_impl = rhs.GetImpl();
    // The op name only affects the error message; it must name this operation
    // (it previously said "Plus", a copy-paste slip from Plus()).
    if (!(lhs.GetImpl() && rhs_impl && lhs.WeightTypesMatch(rhs, "Times"))) {
      return WeightClass();
    }
    WeightClass result(lhs);
    result.GetImpl()->TimesEq(*rhs_impl);
    return result;
  }
  
  // Semiring division of `lhs` by `rhs`.
  //
  // Returns a default-constructed WeightClass when either operand has no
  // implementation, or when the operand types differ (the mismatch is
  // reported via WeightTypesMatch). Otherwise it returns a copy of `lhs`
  // divided by `rhs` via DivideEq. Neither argument is modified.
  WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs) {
    const auto *divisor = rhs.GetImpl();
    const bool operands_ok = lhs.GetImpl() != nullptr && divisor != nullptr &&
                             lhs.WeightTypesMatch(rhs, "Divide");
    if (operands_ok) {
      WeightClass quotient(lhs);
      quotient.GetImpl()->DivideEq(*divisor);
      return quotient;
    }
    return WeightClass();
  }
  
  // Raises `weight` to the `n`-th power via the implementation's PowerEq.
  // Returns a default-constructed WeightClass when `weight` has no
  // implementation. The argument itself is not modified.
  WeightClass Power(const WeightClass &weight, size_t n) {
    if (weight.GetImpl()) {
      WeightClass powered(weight);
      powered.GetImpl()->PowerEq(n);
      return powered;
    }
    return WeightClass();
  }
  
  // Writes the textual form of `weight` to `ostrm` and returns the stream.
  //
  // A WeightClass may lack an implementation. The string constructor resets
  // impl_ when the weight type is unknown, and the error value WeightClass()
  // returned by the arithmetic helpers is assumed to carry a null impl_ as
  // well. Such a weight writes nothing, rather than dereferencing a null
  // pointer, and the stream is returned as-is.
  std::ostream &operator<<(std::ostream &ostrm, const WeightClass &weight) {
    if (weight.impl_) weight.impl_->Print(&ostrm);
    return ostrm;
  }
  
  }  // namespace script
  }  // namespace fst