// weight-class.cc
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.

#include <fst/script/weight-class.h>

namespace fst {
namespace script {

// Register the weight types of the standard, log, and 64-bit log arcs.
// Presumably this makes them available by type name through
// WeightClassRegister, which the string-based WeightClass constructor below
// queries via GetEntry(); verify against the REGISTER_FST_WEIGHT definition.
REGISTER_FST_WEIGHT(StdArc::Weight);
REGISTER_FST_WEIGHT(LogArc::Weight);
REGISTER_FST_WEIGHT(Log64Arc::Weight);

// Builds a weight of the named type by parsing `weight_str` with the parser
// registered for `weight_type`. If no parser is registered for that type, an
// FSTERROR is reported and the object is left without an implementation.
WeightClass::WeightClass(const string &weight_type, const string &weight_str) {
  StrToWeightImplBaseT parser =
      WeightClassRegister::GetRegister()->GetEntry(weight_type);
  if (parser) {
    impl_.reset(parser(weight_str, "WeightClass", 0));
  } else {
    FSTERROR() << "Unknown weight type: " << weight_type;
    impl_.reset();
  }
}

// Returns the semiring zero of `weight_type`, built from the __ZERO__ string.
WeightClass WeightClass::Zero(const string &weight_type) {
  WeightClass zero(weight_type, __ZERO__);
  return zero;
}

// Returns the semiring one of `weight_type`, built from the __ONE__ string.
WeightClass WeightClass::One(const string &weight_type) {
  WeightClass one(weight_type, __ONE__);
  return one;
}

// Returns the "no weight" sentinel of `weight_type`, built from the
// __NOWEIGHT__ string.
WeightClass WeightClass::NoWeight(const string &weight_type) {
  WeightClass none(weight_type, __NOWEIGHT__);
  return none;
}

// Checks that this weight and `other` report the same Type(). On mismatch,
// emits an FSTERROR naming `op_name` and both types, then returns false.
bool WeightClass::WeightTypesMatch(const WeightClass &other,
                                   const string &op_name) const {
  const bool same_type = (Type() == other.Type());
  if (!same_type) {
    FSTERROR() << "Weights with non-matching types passed to " << op_name
               << ": " << Type() << " and " << other.Type();
  }
  return same_type;
}

// Two WeightClass values are equal only if both have an implementation, their
// types match (a mismatch is reported via WeightTypesMatch), and the
// underlying implementations compare equal.
bool operator==(const WeightClass &lhs, const WeightClass &rhs) {
  const auto *left = lhs.GetImpl();
  const auto *right = rhs.GetImpl();
  if (left == nullptr || right == nullptr) return false;
  if (!lhs.WeightTypesMatch(rhs, "operator==")) return false;
  return *left == *right;
}

// Logical negation of operator==, so it is also true when either side lacks
// an implementation or the types differ.
bool operator!=(const WeightClass &lhs, const WeightClass &rhs) {
  const bool equal = (lhs == rhs);
  return !equal;
}

// Semiring addition. Returns a default-constructed WeightClass if either
// operand has no implementation or the operand types differ.
WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs) {
  const auto *addend = rhs.GetImpl();
  const bool usable = lhs.GetImpl() != nullptr && addend != nullptr &&
                      lhs.WeightTypesMatch(rhs, "Plus");
  if (!usable) return WeightClass();
  // Work on a copy so that `lhs` is left untouched.
  WeightClass sum(lhs);
  sum.GetImpl()->PlusEq(*addend);
  return sum;
}

// Semiring multiplication. Returns a default-constructed WeightClass if either
// operand has no implementation or the operand types differ. A type mismatch
// is reported through WeightTypesMatch under the operation name "Times".
WeightClass Times(const WeightClass &lhs, const WeightClass &rhs) {
  const auto *rhs_impl = rhs.GetImpl();
  // Report mismatches as "Times". This was previously "Plus", a copy-paste
  // slip that made error messages name the wrong operation.
  if (!(lhs.GetImpl() && rhs_impl && lhs.WeightTypesMatch(rhs, "Times"))) {
    return WeightClass();
  }
  // Work on a copy so that `lhs` is left untouched.
  WeightClass result(lhs);
  result.GetImpl()->TimesEq(*rhs_impl);
  return result;
}

// Semiring division of `lhs` by `rhs`. Returns a default-constructed
// WeightClass if either operand has no implementation or the types differ.
WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs) {
  const auto *divisor = rhs.GetImpl();
  const bool usable = lhs.GetImpl() != nullptr && divisor != nullptr &&
                      lhs.WeightTypesMatch(rhs, "Divide");
  if (!usable) return WeightClass();
  // Work on a copy so that `lhs` is left untouched.
  WeightClass quotient(lhs);
  quotient.GetImpl()->DivideEq(*divisor);
  return quotient;
}

// Raises `weight` to the `n`-th power via the implementation's PowerEq.
// Returns a default-constructed WeightClass if `weight` has no implementation.
WeightClass Power(const WeightClass &weight, size_t n) {
  if (weight.GetImpl() == nullptr) return WeightClass();
  WeightClass raised(weight);
  raised.GetImpl()->PowerEq(n);
  return raised;
}

// Writes `weight` to `ostrm` through its implementation's Print().
//
// A WeightClass may have no implementation. That happens with a default
// construction, an unknown weight type passed to the constructor, or an error
// result from Plus/Times/Divide/Power. The original dereferenced the null
// implementation unconditionally and crashed. Now nothing is written and the
// stream's failbit is set, so callers can detect the failure.
std::ostream &operator<<(std::ostream &ostrm, const WeightClass &weight) {
  if (weight.impl_) {
    weight.impl_->Print(&ostrm);
  } else {
    ostrm.setstate(std::ios_base::failbit);
  }
  return ostrm;
}

}  // namespace script
}  // namespace fst