Blame view
tools/openfst-1.6.7/src/script/weight-class.cc
2.88 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. #include <fst/script/weight-class.h> namespace fst { namespace script { REGISTER_FST_WEIGHT(StdArc::Weight); REGISTER_FST_WEIGHT(LogArc::Weight); REGISTER_FST_WEIGHT(Log64Arc::Weight); WeightClass::WeightClass(const string &weight_type, const string &weight_str) { WeightClassRegister *reg = WeightClassRegister::GetRegister(); StrToWeightImplBaseT stw = reg->GetEntry(weight_type); if (!stw) { FSTERROR() << "Unknown weight type: " << weight_type; impl_.reset(); return; } impl_.reset(stw(weight_str, "WeightClass", 0)); } WeightClass WeightClass::Zero(const string &weight_type) { return WeightClass(weight_type, __ZERO__); } WeightClass WeightClass::One(const string &weight_type) { return WeightClass(weight_type, __ONE__); } WeightClass WeightClass::NoWeight(const string &weight_type) { return WeightClass(weight_type, __NOWEIGHT__); } bool WeightClass::WeightTypesMatch(const WeightClass &other, const string &op_name) const { if (Type() != other.Type()) { FSTERROR() << "Weights with non-matching types passed to " << op_name << ": " << Type() << " and " << other.Type(); return false; } return true; } bool operator==(const WeightClass &lhs, const WeightClass &rhs) { const auto *lhs_impl = lhs.GetImpl(); const auto *rhs_impl = rhs.GetImpl(); if (!(lhs_impl && rhs_impl && lhs.WeightTypesMatch(rhs, "operator=="))) { return false; } return *lhs_impl == *rhs_impl; } bool operator!=(const WeightClass &lhs, const WeightClass &rhs) { return !(lhs == rhs); } WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs) { const auto *rhs_impl = rhs.GetImpl(); if (!(lhs.GetImpl() && rhs_impl && lhs.WeightTypesMatch(rhs, "Plus"))) { return WeightClass(); } WeightClass result(lhs); result.GetImpl()->PlusEq(*rhs_impl); return result; } WeightClass Times(const WeightClass &lhs, const WeightClass &rhs) { const auto *rhs_impl = rhs.GetImpl(); if (!(lhs.GetImpl() && rhs_impl && lhs.WeightTypesMatch(rhs, "Plus"))) { return WeightClass(); } WeightClass result(lhs); result.GetImpl()->TimesEq(*rhs_impl); return result; } WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs) { const auto *rhs_impl = rhs.GetImpl(); if (!(lhs.GetImpl() && rhs_impl && lhs.WeightTypesMatch(rhs, "Divide"))) { return WeightClass(); } WeightClass result(lhs); result.GetImpl()->DivideEq(*rhs_impl); return result; } WeightClass Power(const WeightClass &weight, size_t n) { if (!weight.GetImpl()) return WeightClass(); WeightClass result(weight); result.GetImpl()->PowerEq(n); return result; } std::ostream &operator<<(std::ostream &ostrm, const WeightClass &weight) { weight.impl_->Print(&ostrm); return ostrm; } } // namespace script } // namespace fst |