// shortest-distance.h
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.

#ifndef FST_SCRIPT_SHORTEST_DISTANCE_H_
#define FST_SCRIPT_SHORTEST_DISTANCE_H_

#include <memory>
#include <tuple>
#include <vector>

#include <fst/queue.h>
#include <fst/shortest-distance.h>
#include <fst/script/fst-class.h>
#include <fst/script/prune.h>
#include <fst/script/script-impl.h>
#include <fst/script/weight-class.h>

namespace fst {
namespace script {

// Run-time selector for the arc filter used by the scripting-level
// ShortestDistance. The numeric values are the ones the compiler would assign
// implicitly; they are written out so the value printed in the "Unknown arc
// filter type" diagnostic can be read back against this list.
enum ArcFilterType {
  ANY_ARC_FILTER = 0,             // Dispatches to AnyArcFilter<Arc>.
  EPSILON_ARC_FILTER = 1,         // Dispatches to EpsilonArcFilter<Arc>.
  INPUT_EPSILON_ARC_FILTER = 2,   // Dispatches to InputEpsilonArcFilter<Arc>.
  OUTPUT_EPSILON_ARC_FILTER = 3   // Dispatches to OutputEpsilonArcFilter<Arc>.
};

struct ShortestDistanceOptions {
  const QueueType queue_type;
  const ArcFilterType arc_filter_type;
  const int64 source;
  const float delta;

  ShortestDistanceOptions(QueueType queue_type, ArcFilterType arc_filter_type,
                          int64 source, float delta)
      : queue_type(queue_type),
        arc_filter_type(arc_filter_type),
        source(source),
        delta(delta) {}
};

namespace internal {

// Code to implement switching on queue and arc filter types.

// Generic queue factory, used for every queue type that has no dedicated
// specialization. Such queues are built with no constructor arguments; the FST
// and distance vector parameters exist only so all QueueConstructor variants
// share one Construct() signature, and are ignored here.
//
// The queue is allocated with new; the caller is responsible for deleting it.
template <class Arc, class Queue, class ArcFilter>
struct QueueConstructor {
  using Weight = typename Arc::Weight;

  static Queue *Construct(const Fst<Arc> & /*fst*/,
                          const std::vector<Weight> * /*distance*/) {
    Queue *const queue = new Queue();
    return queue;
  }
};

// Specializations to support queues with different constructors.

// Factory specialization for AutoQueue, whose constructor is given the FST,
// the distance vector and an arc filter instance.
//
// The queue is allocated with new; the caller is responsible for deleting it.
template <class Arc, class ArcFilter>
struct QueueConstructor<Arc, AutoQueue<typename Arc::StateId>, ArcFilter> {
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  static AutoQueue<StateId> *Construct(const Fst<Arc> &fst,
                                       const std::vector<Weight> *distance) {
    using Result = AutoQueue<StateId>;
    // A default-constructed filter is passed to the queue.
    return new Result(fst, distance, ArcFilter());
  }
};

// Factory specialization for NaturalShortestFirstQueue, whose constructor is
// given the distance vector itself (not a pointer to it). The FST argument is
// unused. The distance pointer is dereferenced unconditionally, so it must
// not be null.
//
// The queue is allocated with new; the caller is responsible for deleting it.
template <class Arc, class ArcFilter>
struct QueueConstructor<
    Arc, NaturalShortestFirstQueue<typename Arc::StateId, typename Arc::Weight>,
    ArcFilter> {
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  static NaturalShortestFirstQueue<StateId, Weight> *Construct(
      const Fst<Arc> & /*fst*/, const std::vector<Weight> *distance) {
    using Result = NaturalShortestFirstQueue<StateId, Weight>;
    return new Result(*distance);
  }
};

// Factory specialization for TopOrderQueue, whose constructor is given the FST
// and an arc filter instance. The distance vector argument is unused.
//
// The queue is allocated with new; the caller is responsible for deleting it.
template <class Arc, class ArcFilter>
struct QueueConstructor<Arc, TopOrderQueue<typename Arc::StateId>, ArcFilter> {
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  static TopOrderQueue<StateId> *Construct(
      const Fst<Arc> &fst, const std::vector<Weight> * /*distance*/) {
    using Result = TopOrderQueue<StateId>;
    // A default-constructed filter is passed to the queue.
    return new Result(fst, ArcFilter());
  }
};

template <class Arc, class Queue, class ArcFilter>
void ShortestDistance(const Fst<Arc> &fst,
                      std::vector<typename Arc::Weight> *distance,
                      const ShortestDistanceOptions &opts) {
  std::unique_ptr<Queue> queue(
      QueueConstructor<Arc, Queue, ArcFilter>::Construct(fst, distance));
  const fst::ShortestDistanceOptions<Arc, Queue, ArcFilter> sopts(
      queue.get(), ArcFilter(), opts.source, opts.delta);
  ShortestDistance(fst, distance, sopts);
}

// Second dispatch stage: the queue type is already a template argument; this
// maps the run-time opts.arc_filter_type onto the matching arc filter class
// and forwards to the three-parameter overload above.
//
// On an unrecognized filter type an error is logged and *distance is replaced
// by a single Weight::NoWeight() element.
template <class Arc, class Queue>
void ShortestDistance(const Fst<Arc> &fst,
                      std::vector<typename Arc::Weight> *distance,
                      const ShortestDistanceOptions &opts) {
  switch (opts.arc_filter_type) {
    case ANY_ARC_FILTER:
      ShortestDistance<Arc, Queue, AnyArcFilter<Arc>>(fst, distance, opts);
      break;
    case EPSILON_ARC_FILTER:
      ShortestDistance<Arc, Queue, EpsilonArcFilter<Arc>>(fst, distance, opts);
      break;
    case INPUT_EPSILON_ARC_FILTER:
      ShortestDistance<Arc, Queue, InputEpsilonArcFilter<Arc>>(fst, distance,
                                                               opts);
      break;
    case OUTPUT_EPSILON_ARC_FILTER:
      ShortestDistance<Arc, Queue, OutputEpsilonArcFilter<Arc>>(fst, distance,
                                                                opts);
      break;
    default:
      FSTERROR() << "ShortestDistance: Unknown arc filter type: "
                 << opts.arc_filter_type;
      // Discard whatever was there and leave exactly one NoWeight entry.
      distance->assign(1, Arc::Weight::NoWeight());
      break;
  }
}

}  // namespace internal

// Argument pack for the options-based scripting call: the input FST, the
// output distance vector, and the scripting-level options, in that order.
using ShortestDistanceArgs1 =
    std::tuple<const FstClass &, std::vector<WeightClass> *,
               const ShortestDistanceOptions &>;

// Arc-typed body of the options-based scripting call. First dispatch stage:
// maps the run-time opts.queue_type onto a concrete queue class and hands off
// to internal::ShortestDistance, which then dispatches on the arc filter.
// Distances are computed into a vector of Arc::Weight and finally copied into
// the caller's WeightClass vector with internal::CopyWeights.
//
// On an unrecognized queue type an error is logged and the result is a single
// Weight::NoWeight() element.
template <class Arc>
void ShortestDistance(ShortestDistanceArgs1 *args) {
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  const FstClass &fst_class = std::get<0>(*args);
  std::vector<WeightClass> *const out_distance = std::get<1>(*args);
  const ShortestDistanceOptions &opts = std::get<2>(*args);
  const Fst<Arc> &fst = *fst_class.GetFst<Arc>();
  std::vector<Weight> typed_distance;
  switch (opts.queue_type) {
    case AUTO_QUEUE:
      internal::ShortestDistance<Arc, AutoQueue<StateId>>(fst, &typed_distance,
                                                          opts);
      break;
    case FIFO_QUEUE:
      internal::ShortestDistance<Arc, FifoQueue<StateId>>(fst, &typed_distance,
                                                          opts);
      break;
    case LIFO_QUEUE:
      internal::ShortestDistance<Arc, LifoQueue<StateId>>(fst, &typed_distance,
                                                          opts);
      break;
    case SHORTEST_FIRST_QUEUE:
      internal::ShortestDistance<Arc,
                                 NaturalShortestFirstQueue<StateId, Weight>>(
          fst, &typed_distance, opts);
      break;
    case STATE_ORDER_QUEUE:
      internal::ShortestDistance<Arc, StateOrderQueue<StateId>>(
          fst, &typed_distance, opts);
      break;
    case TOP_ORDER_QUEUE:
      internal::ShortestDistance<Arc, TopOrderQueue<StateId>>(
          fst, &typed_distance, opts);
      break;
    default:
      FSTERROR() << "ShortestDistance: Unknown queue type: " << opts.queue_type;
      // Leave exactly one NoWeight entry as the result.
      typed_distance.assign(1, Arc::Weight::NoWeight());
      break;
  }
  internal::CopyWeights(typed_distance, out_distance);
}

// Argument pack for the simple scripting call: the input FST, the output
// distance vector, a bool and a double. The last two line up with the
// `reverse` and `delta` parameters of the FstClass overload declared below.
using ShortestDistanceArgs2 =
    std::tuple<const FstClass &, std::vector<WeightClass> *, bool, double>;

// Arc-typed body of the simple scripting call. Unpacks the tuple, runs the
// four-argument ShortestDistance on the typed FST into a vector of
// Arc::Weight, then copies the result into the caller's WeightClass vector
// with internal::CopyWeights.
template <class Arc>
void ShortestDistance(ShortestDistanceArgs2 *args) {
  using Weight = typename Arc::Weight;
  const FstClass &fst_class = std::get<0>(*args);
  std::vector<WeightClass> *const out_distance = std::get<1>(*args);
  const Fst<Arc> &fst = *fst_class.GetFst<Arc>();
  std::vector<Weight> typed_distance;
  ShortestDistance(fst, &typed_distance, std::get<2>(*args),
                   std::get<3>(*args));
  internal::CopyWeights(typed_distance, out_distance);
}

// Scripting-level entry point taking explicit options (queue type, arc filter
// type, source and delta). Results are written to *distance. Declaration
// only; the definition is not in this header.
void ShortestDistance(const FstClass &fst, std::vector<WeightClass> *distance,
                      const ShortestDistanceOptions &opts);

// Scripting-level entry point taking only a `reverse` flag and `delta`, which
// default to false and fst::kShortestDelta. Results are written to *distance.
// Declaration only; the definition is not in this header.
void ShortestDistance(const FstClass &ifst, std::vector<WeightClass> *distance,
                      bool reverse = false,
                      double delta = fst::kShortestDelta);

}  // namespace script
}  // namespace fst

#endif  // FST_SCRIPT_SHORTEST_DISTANCE_H_