// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Float weight set and associated semiring operation definitions. #ifndef FST_FLOAT_WEIGHT_H_ #define FST_FLOAT_WEIGHT_H_ #include #include #include #include #include #include #include #include #include #include namespace fst { // Numeric limits class. template class FloatLimits { public: static constexpr T PosInfinity() { return std::numeric_limits::infinity(); } static constexpr T NegInfinity() { return -PosInfinity(); } static constexpr T NumberBad() { return std::numeric_limits::quiet_NaN(); } }; // Weight class to be templated on floating-points types. template class FloatWeightTpl { public: using ValueType = T; FloatWeightTpl() {} FloatWeightTpl(T f) : value_(f) {} FloatWeightTpl(const FloatWeightTpl &weight) : value_(weight.value_) {} FloatWeightTpl &operator=(const FloatWeightTpl &weight) { value_ = weight.value_; return *this; } std::istream &Read(std::istream &strm) { return ReadType(strm, &value_); } std::ostream &Write(std::ostream &strm) const { return WriteType(strm, value_); } size_t Hash() const { size_t hash = 0; // Avoid using union, which would be undefined behavior. // Use memcpy, similar to bit_cast, but sizes may be different. // This should be optimized into a single move instruction by // any reasonable compiler. std::memcpy(&hash, &value_, std::min(sizeof(hash), sizeof(value_))); return hash; } const T &Value() const { return value_; } protected: void SetValue(const T &f) { value_ = f; } static constexpr const char *GetPrecisionString() { return sizeof(T) == 4 ? "" : sizeof(T) == 1 ? "8" : sizeof(T) == 2 ? "16" : sizeof(T) == 8 ? "64" : "unknown"; } private: T value_; }; // Single-precision float weight. using FloatWeight = FloatWeightTpl; template inline bool operator==(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { // Volatile qualifier thwarts over-aggressive compiler optimizations that // lead to problems esp. with NaturalLess(). volatile T v1 = w1.Value(); volatile T v2 = w2.Value(); return v1 == v2; } // These seemingly unnecessary overloads are actually needed to make // comparisons like FloatWeightTpl == float compile. If only the // templated version exists, the FloatWeightTpl(float) conversion // won't be found. inline bool operator==(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { return operator==(w1, w2); } inline bool operator==(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { return operator==(w1, w2); } template inline bool operator!=(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { return !(w1 == w2); } inline bool operator!=(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { return operator!=(w1, w2); } inline bool operator!=(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { return operator!=(w1, w2); } template inline bool ApproxEqual(const FloatWeightTpl &w1, const FloatWeightTpl &w2, float delta = kDelta) { return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta; } template inline std::ostream &operator<<(std::ostream &strm, const FloatWeightTpl &w) { if (w.Value() == FloatLimits::PosInfinity()) { return strm << "Infinity"; } else if (w.Value() == FloatLimits::NegInfinity()) { return strm << "-Infinity"; } else if (w.Value() != w.Value()) { // Fails for IEEE NaN. return strm << "BadNumber"; } else { return strm << w.Value(); } } template inline std::istream &operator>>(std::istream &strm, FloatWeightTpl &w) { string s; strm >> s; if (s == "Infinity") { w = FloatWeightTpl(FloatLimits::PosInfinity()); } else if (s == "-Infinity") { w = FloatWeightTpl(FloatLimits::NegInfinity()); } else { char *p; T f = strtod(s.c_str(), &p); if (p < s.c_str() + s.size()) { strm.clear(std::ios::badbit); } else { w = FloatWeightTpl(f); } } return strm; } // Tropical semiring: (min, +, inf, 0). template class TropicalWeightTpl : public FloatWeightTpl { public: using typename FloatWeightTpl::ValueType; using FloatWeightTpl::Value; using ReverseWeight = TropicalWeightTpl; using Limits = FloatLimits; constexpr TropicalWeightTpl() : FloatWeightTpl() {} constexpr TropicalWeightTpl(T f) : FloatWeightTpl(f) {} constexpr TropicalWeightTpl(const TropicalWeightTpl &weight) : FloatWeightTpl(weight) {} static const TropicalWeightTpl &Zero() { static const TropicalWeightTpl zero(Limits::PosInfinity()); return zero; } static const TropicalWeightTpl &One() { static const TropicalWeightTpl one(0.0F); return one; } static const TropicalWeightTpl &NoWeight() { static const TropicalWeightTpl no_weight(Limits::NumberBad()); return no_weight; } static const string &Type() { static const string *const type = new string(string("tropical") + FloatWeightTpl::GetPrecisionString()); return *type; } bool Member() const { // First part fails for IEEE NaN. return Value() == Value() && Value() != Limits::NegInfinity(); } TropicalWeightTpl Quantize(float delta = kDelta) const { if (!Member() || Value() == Limits::PosInfinity()) { return *this; } else { return TropicalWeightTpl(floor(Value() / delta + 0.5F) * delta); } } TropicalWeightTpl Reverse() const { return *this; } static constexpr uint64 Properties() { return kLeftSemiring | kRightSemiring | kCommutative | kPath | kIdempotent; } }; // Single precision tropical weight. using TropicalWeight = TropicalWeightTpl; template inline TropicalWeightTpl Plus(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { if (!w1.Member() || !w2.Member()) return TropicalWeightTpl::NoWeight(); return w1.Value() < w2.Value() ? w1 : w2; } // See comment at operator==(FloatWeightTpl, FloatWeightTpl) // for why these overloads are present. inline TropicalWeightTpl Plus(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { return Plus(w1, w2); } inline TropicalWeightTpl Plus(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { return Plus(w1, w2); } template inline TropicalWeightTpl Times(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { using Limits = FloatLimits; if (!w1.Member() || !w2.Member()) return TropicalWeightTpl::NoWeight(); const T f1 = w1.Value(); const T f2 = w2.Value(); if (f1 == Limits::PosInfinity()) { return w1; } else if (f2 == Limits::PosInfinity()) { return w2; } else { return TropicalWeightTpl(f1 + f2); } } inline TropicalWeightTpl Times(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { return Times(w1, w2); } inline TropicalWeightTpl Times(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { return Times(w1, w2); } template inline TropicalWeightTpl Divide(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2, DivideType typ = DIVIDE_ANY) { using Limits = FloatLimits; if (!w1.Member() || !w2.Member()) return TropicalWeightTpl::NoWeight(); const T f1 = w1.Value(); const T f2 = w2.Value(); if (f2 == Limits::PosInfinity()) { return Limits::NumberBad(); } else if (f1 == Limits::PosInfinity()) { return Limits::PosInfinity(); } else { return TropicalWeightTpl(f1 - f2); } } inline TropicalWeightTpl Divide(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } inline TropicalWeightTpl Divide(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } template inline TropicalWeightTpl Power(const TropicalWeightTpl &weight, V n) { if (n == 0) { return TropicalWeightTpl::One(); } else if (weight == TropicalWeightTpl::Zero()) { return TropicalWeightTpl::Zero(); } return TropicalWeightTpl(weight.Value() * n); } // Specializes the library-wide template to use the above implementation; rules // of function template instantiation require this be a full instantiation. template <> inline TropicalWeightTpl Power>( const TropicalWeightTpl &weight, size_t n) { return Power(weight, n); } template <> inline TropicalWeightTpl Power>( const TropicalWeightTpl &weight, size_t n) { return Power(weight, n); } // Log semiring: (log(e^-x + e^-y), +, inf, 0). template class LogWeightTpl : public FloatWeightTpl { public: using typename FloatWeightTpl::ValueType; using FloatWeightTpl::Value; using ReverseWeight = LogWeightTpl; using Limits = FloatLimits; constexpr LogWeightTpl() : FloatWeightTpl() {} constexpr LogWeightTpl(T f) : FloatWeightTpl(f) {} constexpr LogWeightTpl(const LogWeightTpl &weight) : FloatWeightTpl(weight) {} static const LogWeightTpl &Zero() { static const LogWeightTpl zero(Limits::PosInfinity()); return zero; } static const LogWeightTpl &One() { static const LogWeightTpl one(0.0F); return one; } static const LogWeightTpl &NoWeight() { static const LogWeightTpl no_weight(Limits::NumberBad()); return no_weight; } static const string &Type() { static const string *const type = new string(string("log") + FloatWeightTpl::GetPrecisionString()); return *type; } bool Member() const { // First part fails for IEEE NaN. return Value() == Value() && Value() != Limits::NegInfinity(); } LogWeightTpl Quantize(float delta = kDelta) const { if (!Member() || Value() == Limits::PosInfinity()) { return *this; } else { return LogWeightTpl(floor(Value() / delta + 0.5F) * delta); } } LogWeightTpl Reverse() const { return *this; } static constexpr uint64 Properties() { return kLeftSemiring | kRightSemiring | kCommutative; } }; // Single-precision log weight. using LogWeight = LogWeightTpl; // Double-precision log weight. using Log64Weight = LogWeightTpl; namespace internal { // -log(e^-x + e^-y) = x - LogPosExp(y - x), assuming x >= 0.0. inline double LogPosExp(double x) { DCHECK(!(x < 0)); // NB: NaN values are allowed. return log1p(exp(-x)); } // -log(e^-x - e^-y) = x - LogNegExp(y - x), assuming x > 0.0. inline double LogNegExp(double x) { DCHECK_GT(x, 0); return log1p(-exp(-x)); } // a +_log b = -log(e^-a + e^-b) = KahanLogSum(a, b, ...). // Kahan compensated summation provides an error bound that is // independent of the number of addends. Assumes b >= a; // c is the compensation. inline double KahanLogSum(double a, double b, double *c) { DCHECK_GE(b, a); double y = -LogPosExp(b - a) - *c; double t = a + y; *c = (t - a) - y; return t; } // a -_log b = -log(e^-a - e^-b) = KahanLogDiff(a, b, ...). // Kahan compensated summation provides an error bound that is // independent of the number of addends. Assumes b > a; // c is the compensation. inline double KahanLogDiff(double a, double b, double *c) { DCHECK_GT(b, a); double y = -LogNegExp(b - a) - *c; double t = a + y; *c = (t - a) - y; return t; } } // namespace internal template inline LogWeightTpl Plus(const LogWeightTpl &w1, const LogWeightTpl &w2) { using Limits = FloatLimits; const T f1 = w1.Value(); const T f2 = w2.Value(); if (f1 == Limits::PosInfinity()) { return w2; } else if (f2 == Limits::PosInfinity()) { return w1; } else if (f1 > f2) { return LogWeightTpl(f2 - internal::LogPosExp(f1 - f2)); } else { return LogWeightTpl(f1 - internal::LogPosExp(f2 - f1)); } } inline LogWeightTpl Plus(const LogWeightTpl &w1, const LogWeightTpl &w2) { return Plus(w1, w2); } inline LogWeightTpl Plus(const LogWeightTpl &w1, const LogWeightTpl &w2) { return Plus(w1, w2); } template inline LogWeightTpl Times(const LogWeightTpl &w1, const LogWeightTpl &w2) { using Limits = FloatLimits; if (!w1.Member() || !w2.Member()) return LogWeightTpl::NoWeight(); const T f1 = w1.Value(); const T f2 = w2.Value(); if (f1 == Limits::PosInfinity()) { return w1; } else if (f2 == Limits::PosInfinity()) { return w2; } else { return LogWeightTpl(f1 + f2); } } inline LogWeightTpl Times(const LogWeightTpl &w1, const LogWeightTpl &w2) { return Times(w1, w2); } inline LogWeightTpl Times(const LogWeightTpl &w1, const LogWeightTpl &w2) { return Times(w1, w2); } template inline LogWeightTpl Divide(const LogWeightTpl &w1, const LogWeightTpl &w2, DivideType typ = DIVIDE_ANY) { using Limits = FloatLimits; if (!w1.Member() || !w2.Member()) return LogWeightTpl::NoWeight(); const T f1 = w1.Value(); const T f2 = w2.Value(); if (f2 == Limits::PosInfinity()) { return Limits::NumberBad(); } else if (f1 == Limits::PosInfinity()) { return Limits::PosInfinity(); } else { return LogWeightTpl(f1 - f2); } } inline LogWeightTpl Divide(const LogWeightTpl &w1, const LogWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } inline LogWeightTpl Divide(const LogWeightTpl &w1, const LogWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } template inline LogWeightTpl Power(const LogWeightTpl &weight, V n) { if (n == 0) { return LogWeightTpl::One(); } else if (weight == LogWeightTpl::Zero()) { return LogWeightTpl::Zero(); } return LogWeightTpl(weight.Value() * n); } // Specializes the library-wide template to use the above implementation; rules // of function template instantiation require this be a full instantiation. template <> inline LogWeightTpl Power>( const LogWeightTpl &weight, size_t n) { return Power(weight, n); } template <> inline LogWeightTpl Power>( const LogWeightTpl &weight, size_t n) { return Power(weight, n); } // Specialization using the Kahan compensated summation. template class Adder> { public: using Weight = LogWeightTpl; explicit Adder(Weight w = Weight::Zero()) : sum_(w.Value()), c_(0.0) { } Weight Add(const Weight &w) { using Limits = FloatLimits; const T f = w.Value(); if (f == Limits::PosInfinity()) { return Sum(); } else if (sum_ == Limits::PosInfinity()) { sum_ = f; c_ = 0.0; } else if (f > sum_) { sum_ = internal::KahanLogSum(sum_, f, &c_); } else { sum_ = internal::KahanLogSum(f, sum_, &c_); } return Sum(); } Weight Sum() { return Weight(sum_); } void Reset(Weight w = Weight::Zero()) { sum_ = w.Value(); c_ = 0.0; } private: double sum_; double c_; // Kahan compensation. }; // MinMax semiring: (min, max, inf, -inf). template class MinMaxWeightTpl : public FloatWeightTpl { public: using typename FloatWeightTpl::ValueType; using FloatWeightTpl::Value; using ReverseWeight = MinMaxWeightTpl; using Limits = FloatLimits; MinMaxWeightTpl() : FloatWeightTpl() {} MinMaxWeightTpl(T f) : FloatWeightTpl(f) {} MinMaxWeightTpl(const MinMaxWeightTpl &weight) : FloatWeightTpl(weight) {} static const MinMaxWeightTpl &Zero() { static const MinMaxWeightTpl zero(Limits::PosInfinity()); return zero; } static const MinMaxWeightTpl &One() { static const MinMaxWeightTpl one(Limits::NegInfinity()); return one; } static const MinMaxWeightTpl &NoWeight() { static const MinMaxWeightTpl no_weight(Limits::NumberBad()); return no_weight; } static const string &Type() { static const string *const type = new string(string("minmax") + FloatWeightTpl::GetPrecisionString()); return *type; } // Fails for IEEE NaN. bool Member() const { return Value() == Value(); } MinMaxWeightTpl Quantize(float delta = kDelta) const { // If one of infinities, or a NaN. if (!Member() || Value() == Limits::NegInfinity() || Value() == Limits::PosInfinity()) { return *this; } else { return MinMaxWeightTpl(floor(Value() / delta + 0.5F) * delta); } } MinMaxWeightTpl Reverse() const { return *this; } static constexpr uint64 Properties() { return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath; } }; // Single-precision min-max weight. using MinMaxWeight = MinMaxWeightTpl; // Min. template inline MinMaxWeightTpl Plus(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { if (!w1.Member() || !w2.Member()) return MinMaxWeightTpl::NoWeight(); return w1.Value() < w2.Value() ? w1 : w2; } inline MinMaxWeightTpl Plus(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { return Plus(w1, w2); } inline MinMaxWeightTpl Plus(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { return Plus(w1, w2); } // Max. template inline MinMaxWeightTpl Times(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { if (!w1.Member() || !w2.Member()) return MinMaxWeightTpl::NoWeight(); return w1.Value() >= w2.Value() ? w1 : w2; } inline MinMaxWeightTpl Times(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { return Times(w1, w2); } inline MinMaxWeightTpl Times(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { return Times(w1, w2); } // Defined only for special cases. template inline MinMaxWeightTpl Divide(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2, DivideType typ = DIVIDE_ANY) { if (!w1.Member() || !w2.Member()) return MinMaxWeightTpl::NoWeight(); // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2. return w1.Value() >= w2.Value() ? w1 : FloatLimits::NumberBad(); } inline MinMaxWeightTpl Divide(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } inline MinMaxWeightTpl Divide(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } // Converts to tropical. template <> struct WeightConvert { TropicalWeight operator()(const LogWeight &w) const { return w.Value(); } }; template <> struct WeightConvert { TropicalWeight operator()(const Log64Weight &w) const { return w.Value(); } }; // Converts to log. template <> struct WeightConvert { LogWeight operator()(const TropicalWeight &w) const { return w.Value(); } }; template <> struct WeightConvert { LogWeight operator()(const Log64Weight &w) const { return w.Value(); } }; // Converts to log64. template <> struct WeightConvert { Log64Weight operator()(const TropicalWeight &w) const { return w.Value(); } }; template <> struct WeightConvert { Log64Weight operator()(const LogWeight &w) const { return w.Value(); } }; // This function object returns random integers chosen from [0, // num_random_weights). The boolean 'allow_zero' determines whether Zero() and // zero divisors should be returned in the random weight generation. This is // intended primary for testing. template class FloatWeightGenerate { public: explicit FloatWeightGenerate( bool allow_zero = true, const size_t num_random_weights = kNumRandomWeights) : allow_zero_(allow_zero), num_random_weights_(num_random_weights) {} Weight operator()() const { const int n = rand() % (num_random_weights_ + allow_zero_); // NOLINT if (allow_zero_ && n == num_random_weights_) return Weight::Zero(); return Weight(n); } private: // Permits Zero() and zero divisors. const bool allow_zero_; // Number of alternative random weights. const size_t num_random_weights_; }; template class WeightGenerate> : public FloatWeightGenerate> { public: using Weight = TropicalWeightTpl; using Generate = FloatWeightGenerate; explicit WeightGenerate(bool allow_zero = true, size_t num_random_weights = kNumRandomWeights) : Generate(allow_zero, num_random_weights) {} Weight operator()() const { return Weight(Generate::operator()()); } }; template class WeightGenerate> : public FloatWeightGenerate> { public: using Weight = LogWeightTpl; using Generate = FloatWeightGenerate; explicit WeightGenerate(bool allow_zero = true, size_t num_random_weights = kNumRandomWeights) : Generate(allow_zero, num_random_weights) {} Weight operator()() const { return Weight(Generate::operator()()); } }; // This function object returns random integers chosen from [0, // num_random_weights). The boolean 'allow_zero' determines whether Zero() and // zero divisors should be returned in the random weight generation. This is // intended primary for testing. template class WeightGenerate> { public: using Weight = MinMaxWeightTpl; explicit WeightGenerate(bool allow_zero = true, size_t num_random_weights = kNumRandomWeights) : allow_zero_(allow_zero), num_random_weights_(num_random_weights) {} Weight operator()() const { const int n = (rand() % // NOLINT (2 * num_random_weights_ + allow_zero_)) - num_random_weights_; if (allow_zero_ && n == num_random_weights_) { return Weight::Zero(); } else if (n == -num_random_weights_) { return Weight::One(); } else { return Weight(n); } } private: // Permits Zero() and zero divisors. const bool allow_zero_; // Number of alternative random weights. const size_t num_random_weights_; }; } // namespace fst #endif // FST_FLOAT_WEIGHT_H_