// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. #include namespace fst { namespace script { REGISTER_FST_WEIGHT(StdArc::Weight); REGISTER_FST_WEIGHT(LogArc::Weight); REGISTER_FST_WEIGHT(Log64Arc::Weight); WeightClass::WeightClass(const string &weight_type, const string &weight_str) { WeightClassRegister *reg = WeightClassRegister::GetRegister(); StrToWeightImplBaseT stw = reg->GetEntry(weight_type); if (!stw) { FSTERROR() << "Unknown weight type: " << weight_type; impl_.reset(); return; } impl_.reset(stw(weight_str, "WeightClass", 0)); } WeightClass WeightClass::Zero(const string &weight_type) { return WeightClass(weight_type, __ZERO__); } WeightClass WeightClass::One(const string &weight_type) { return WeightClass(weight_type, __ONE__); } WeightClass WeightClass::NoWeight(const string &weight_type) { return WeightClass(weight_type, __NOWEIGHT__); } bool WeightClass::WeightTypesMatch(const WeightClass &other, const string &op_name) const { if (Type() != other.Type()) { FSTERROR() << "Weights with non-matching types passed to " << op_name << ": " << Type() << " and " << other.Type(); return false; } return true; } bool operator==(const WeightClass &lhs, const WeightClass &rhs) { const auto *lhs_impl = lhs.GetImpl(); const auto *rhs_impl = rhs.GetImpl(); if (!(lhs_impl && rhs_impl && lhs.WeightTypesMatch(rhs, "operator=="))) { return false; } return *lhs_impl == *rhs_impl; } bool operator!=(const WeightClass &lhs, const WeightClass &rhs) { return !(lhs == rhs); } WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs) { const auto *rhs_impl = rhs.GetImpl(); if (!(lhs.GetImpl() && rhs_impl && lhs.WeightTypesMatch(rhs, "Plus"))) { return WeightClass(); } WeightClass result(lhs); result.GetImpl()->PlusEq(*rhs_impl); return result; } WeightClass Times(const WeightClass &lhs, const WeightClass &rhs) { const auto *rhs_impl = rhs.GetImpl(); if (!(lhs.GetImpl() && rhs_impl && lhs.WeightTypesMatch(rhs, "Plus"))) { return WeightClass(); } WeightClass result(lhs); result.GetImpl()->TimesEq(*rhs_impl); return result; } WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs) { const auto *rhs_impl = rhs.GetImpl(); if (!(lhs.GetImpl() && rhs_impl && lhs.WeightTypesMatch(rhs, "Divide"))) { return WeightClass(); } WeightClass result(lhs); result.GetImpl()->DivideEq(*rhs_impl); return result; } WeightClass Power(const WeightClass &weight, size_t n) { if (!weight.GetImpl()) return WeightClass(); WeightClass result(weight); result.GetImpl()->PowerEq(n); return result; } std::ostream &operator<<(std::ostream &ostrm, const WeightClass &weight) { weight.impl_->Print(&ostrm); return ostrm; } } // namespace script } // namespace fst