weight-class.cc
2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
#include <fst/script/weight-class.h>
namespace fst {
namespace script {
// Register the weight types of the standard arc types (StdArc, LogArc,
// Log64Arc) so WeightClass can construct them by weight-type name.
// NOTE(review): assumes REGISTER_FST_WEIGHT adds a string-to-weight entry to
// WeightClassRegister keyed by the weight's type name (the macro is defined in
// the header) — confirm. Weight types not registered here or elsewhere hit the
// "Unknown weight type" error in the WeightClass constructor.
REGISTER_FST_WEIGHT(StdArc::Weight);
REGISTER_FST_WEIGHT(LogArc::Weight);
REGISTER_FST_WEIGHT(Log64Arc::Weight);
// Builds a weight of type `weight_type` by parsing `weight_str` with the
// string-to-weight converter registered for that type in
// WeightClassRegister. If no converter is registered for `weight_type`, an
// FSTERROR is emitted and the object is left with a null impl_.
WeightClass::WeightClass(const string &weight_type, const string &weight_str) {
  const auto converter =
      WeightClassRegister::GetRegister()->GetEntry(weight_type);
  if (converter) {
    // NOTE(review): the trailing "WeightClass", 0 arguments look like a
    // source label and line number for parser diagnostics — confirm against
    // StrToWeightImplBaseT.
    impl_.reset(converter(weight_str, "WeightClass", 0));
  } else {
    FSTERROR() << "Unknown weight type: " << weight_type;
    impl_.reset();
  }
}
// Factory helpers. Each builds a WeightClass of type `weight_type` from a
// special token string (__ZERO__, __ONE__, __NOWEIGHT__) via the
// (weight_type, weight_str) constructor. The registered parser for the type
// is presumably expected to map each token to the corresponding distinguished
// weight — confirm in the header.
WeightClass WeightClass::Zero(const string &weight_type) {
  return WeightClass{weight_type, __ZERO__};
}
WeightClass WeightClass::One(const string &weight_type) {
  return WeightClass{weight_type, __ONE__};
}
WeightClass WeightClass::NoWeight(const string &weight_type) {
  return WeightClass{weight_type, __NOWEIGHT__};
}
// Returns true when this weight and `other` report the same Type().
// On a mismatch, emits an FSTERROR naming `op_name` (the operation that was
// attempted) and both type names, then returns false.
bool WeightClass::WeightTypesMatch(const WeightClass &other,
                                   const string &op_name) const {
  const bool same_type = (Type() == other.Type());
  if (!same_type) {
    FSTERROR() << "Weights with non-matching types passed to " << op_name
               << ": " << Type() << " and " << other.Type();
  }
  return same_type;
}
// Equality holds only if both sides carry an implementation, their weight
// types match, and the underlying implementations compare equal. A missing
// implementation on either side yields false. A type mismatch also emits an
// FSTERROR via WeightTypesMatch.
bool operator==(const WeightClass &lhs, const WeightClass &rhs) {
  const auto *left = lhs.GetImpl();
  const auto *right = rhs.GetImpl();
  if (left == nullptr || right == nullptr) return false;
  if (!lhs.WeightTypesMatch(rhs, "operator==")) return false;
  return *left == *right;
}
// Logical negation of operator==. Note that this is therefore true whenever
// either operand lacks an implementation.
bool operator!=(const WeightClass &lhs, const WeightClass &rhs) {
  const bool equal = (lhs == rhs);
  return !equal;
}
// Returns a copy of lhs with rhs folded in via the implementation's PlusEq.
// If either operand has no implementation, or their weight types differ,
// returns a default-constructed WeightClass. A type mismatch also emits an
// FSTERROR naming "Plus".
WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs) {
  const auto *rhs_impl = rhs.GetImpl();
  const bool usable = lhs.GetImpl() != nullptr && rhs_impl != nullptr &&
                      lhs.WeightTypesMatch(rhs, "Plus");
  if (!usable) return WeightClass();
  // Accumulate into a copy of lhs rather than into lhs itself.
  WeightClass sum(lhs);
  sum.GetImpl()->PlusEq(*rhs_impl);
  return sum;
}
// Returns a copy of lhs combined with rhs via the implementation's TimesEq.
// If either operand has no implementation, or their weight types differ,
// returns a default-constructed WeightClass. A type mismatch also emits an
// FSTERROR naming "Times".
WeightClass Times(const WeightClass &lhs, const WeightClass &rhs) {
  const auto *rhs_impl = rhs.GetImpl();
  // The operation name is used only in WeightTypesMatch's mismatch diagnostic,
  // so it must name this operation ("Times"), not Plus.
  if (!(lhs.GetImpl() && rhs_impl && lhs.WeightTypesMatch(rhs, "Times"))) {
    return WeightClass();
  }
  WeightClass result(lhs);
  result.GetImpl()->TimesEq(*rhs_impl);
  return result;
}
// Returns a copy of lhs divided by rhs via the implementation's DivideEq.
// If either operand has no implementation, or their weight types differ,
// returns a default-constructed WeightClass. A type mismatch also emits an
// FSTERROR naming "Divide".
WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs) {
  const auto *divisor = rhs.GetImpl();
  if (lhs.GetImpl() == nullptr || divisor == nullptr ||
      !lhs.WeightTypesMatch(rhs, "Divide")) {
    return WeightClass();
  }
  WeightClass quotient(lhs);
  quotient.GetImpl()->DivideEq(*divisor);
  return quotient;
}
// Returns a copy of `weight` with the implementation's PowerEq(n) applied.
// If `weight` has no implementation, returns a default-constructed
// WeightClass instead.
WeightClass Power(const WeightClass &weight, size_t n) {
  if (weight.GetImpl() != nullptr) {
    WeightClass raised(weight);
    raised.GetImpl()->PowerEq(n);
    return raised;
  }
  return WeightClass();
}
std::ostream &operator<<(std::ostream &ostrm, const WeightClass &weight) {
weight.impl_->Print(&ostrm);
return ostrm;
}
} // namespace script
} // namespace fst