// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Classes and functions to generate random paths through an FST. #ifndef FST_RANDGEN_H_ #define FST_RANDGEN_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace fst { // The RandGenFst class is roughly similar to ArcMapFst in that it takes two // template parameters denoting the input and output arc types. However, it also // takes an additional template parameter which specifies a sampler object which // samples (with replacement) arcs from an FST state. The sampler in turn takes // a template parameter for a selector object which actually chooses the arc. // // Arc selector functors are used to select a random transition given an FST // state s, returning a number N such that 0 <= N <= NumArcs(s). If N is // NumArcs(s), then the final weight is selected; otherwise the N-th arc is // selected. It is assumed these are not applied to any state which is neither // final nor has any arcs leaving it. // Randomly selects a transition using the uniform distribution. This class is // not thread-safe. template class UniformArcSelector { public: using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; // Constructs a selector with a non-deterministic seed. UniformArcSelector() : rand_(std::random_device()()) {} // Constructs a selector with a given seed. explicit UniformArcSelector(uint64 seed) : rand_(seed) {} size_t operator()(const Fst &fst, StateId s) const { const auto n = fst.NumArcs(s) + (fst.Final(s) != Weight::Zero()); return static_cast( std::uniform_int_distribution<>(0, n - 1)(rand_)); } private: mutable std::mt19937_64 rand_; }; // Randomly selects a transition w.r.t. the weights treated as negative log // probabilities after normalizing for the total weight leaving the state. Zero // transitions are disregarded. It assumed that Arc::Weight::Value() accesses // the floating point representation of the weight. This class is not // thread-safe. template class LogProbArcSelector { public: using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; // Constructs a selector with a non-deterministic seed. LogProbArcSelector() : rand_(std::random_device()()) {} // Constructs a selector with a given seed. explicit LogProbArcSelector(uint64 seed) : rand_(seed) {} size_t operator()(const Fst &fst, StateId s) const { // Finds total weight leaving state. auto sum = Log64Weight::Zero(); ArcIterator> aiter(fst, s); for (; !aiter.Done(); aiter.Next()) { const auto &arc = aiter.Value(); sum = Plus(sum, to_log_weight_(arc.weight)); } sum = Plus(sum, to_log_weight_(fst.Final(s))); const double threshold = std::uniform_real_distribution<>(0, exp(-sum.Value()))(rand_); auto p = Log64Weight::Zero(); size_t n = 0; for (aiter.Reset(); !aiter.Done(); aiter.Next(), ++n) { p = Plus(p, to_log_weight_(aiter.Value().weight)); if (exp(-p.Value()) > threshold) return n; } return n; } private: mutable std::mt19937_64 rand_; WeightConvert to_log_weight_; }; // Useful alias when using StdArc. using StdArcSelector = LogProbArcSelector; // Same as LogProbArcSelector but use CacheLogAccumulator to cache the weight // accumulation computations. This class is not thread-safe. template class FastLogProbArcSelector : public LogProbArcSelector { public: using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; using LogProbArcSelector::operator(); // Constructs a selector with a non-deterministic seed. FastLogProbArcSelector() : seed_(std::random_device()()), rand_(seed_) {} // Constructs a selector with a given seed. explicit FastLogProbArcSelector(uint64 seed) : seed_(seed), rand_(seed_) {} size_t operator()(const Fst &fst, StateId s, CacheLogAccumulator *accumulator) const { accumulator->SetState(s); ArcIterator> aiter(fst, s); // Finds total weight leaving state. const double sum = to_log_weight_(accumulator->Sum(fst.Final(s), &aiter, 0, fst.NumArcs(s))) .Value(); const double r = -log(std::uniform_real_distribution<>(0, 1)(rand_)); Weight w = from_log_weight_(r + sum); aiter.Reset(); return accumulator->LowerBound(w, &aiter); } uint64 Seed() const { return seed_; } private: const uint64 seed_; mutable std::mt19937_64 rand_; WeightConvert to_log_weight_; WeightConvert from_log_weight_; }; // Random path state info maintained by RandGenFst and passed to samplers. template struct RandState { using StateId = typename Arc::StateId; StateId state_id; // Current input FST state. size_t nsamples; // Number of samples to be sampled at this state. size_t length; // Length of path to this random state. size_t select; // Previous sample arc selection. const RandState *parent; // Previous random state on this path. explicit RandState(StateId state_id, size_t nsamples = 0, size_t length = 0, size_t select = 0, const RandState *parent = nullptr) : state_id(state_id), nsamples(nsamples), length(length), select(select), parent(parent) {} RandState() : RandState(kNoStateId) {} }; // This class, given an arc selector, samples, with replacement, multiple random // transitions from an FST's state. This is a generic version with a // straightforward use of the arc selector. Specializations may be defined for // arc selectors for greater efficiency or special behavior. template class ArcSampler { public: using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; // The max_length argument may be interpreted (or ignored) by a selector as // it chooses. This generic version interprets this literally. ArcSampler(const Fst &fst, const Selector &selector, int32 max_length = std::numeric_limits::max()) : fst_(fst), selector_(selector), max_length_(max_length) {} // Allow updating FST argument; pass only if changed. ArcSampler(const ArcSampler &sampler, const Fst *fst = nullptr) : fst_(fst ? *fst : sampler.fst_), selector_(sampler.selector_), max_length_(sampler.max_length_) { Reset(); } // Samples a fixed number of samples from the given state. The length argument // specifies the length of the path to the state. Returns true if the samples // were collected. No samples may be collected if either there are no // transitions leaving the state and the state is non-final, or if the path // length has been exceeded. Iterator members are provided to read the samples // in the order in which they were collected. bool Sample(const RandState &rstate) { sample_map_.clear(); if ((fst_.NumArcs(rstate.state_id) == 0 && fst_.Final(rstate.state_id) == Weight::Zero()) || rstate.length == max_length_) { Reset(); return false; } for (size_t i = 0; i < rstate.nsamples; ++i) { ++sample_map_[selector_(fst_, rstate.state_id)]; } Reset(); return true; } // More samples? bool Done() const { return sample_iter_ == sample_map_.end(); } // Gets the next sample. void Next() { ++sample_iter_; } std::pair Value() const { return *sample_iter_; } void Reset() { sample_iter_ = sample_map_.begin(); } bool Error() const { return false; } private: const Fst &fst_; const Selector &selector_; const int32 max_length_; // Stores (N, K) as described for Value(). std::map sample_map_; std::map::const_iterator sample_iter_; ArcSampler &operator=(const ArcSampler &) = delete; }; // Samples one sample of num_to_sample dimensions from a multinomial // distribution parameterized by a vector of probabilities. The result // container should be pre-initialized (e.g., an empty map or a zeroed vector // sized the same as the vector of probabilities. // probs.size()). template void OneMultinomialSample(const std::vector &probs, size_t num_to_sample, Result *result, RNG *rng) { // Left-over probability mass. double norm = 0; for (double p : probs) norm += p; // Left-over number of samples needed. for (size_t i = 0; i < probs.size(); ++i) { size_t num_sampled = 0; if (probs[i] > 0) { std::binomial_distribution<> d(num_to_sample, probs[i] / norm); num_sampled = d(*rng); } if (num_sampled != 0) (*result)[i] = num_sampled; norm -= probs[i]; num_to_sample -= num_sampled; } } // Specialization for FastLogProbArcSelector. template class ArcSampler> { public: using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; using Accumulator = CacheLogAccumulator; using Selector = FastLogProbArcSelector; ArcSampler(const Fst &fst, const Selector &selector, int32 max_length = std::numeric_limits::max()) : fst_(fst), selector_(selector), max_length_(max_length), accumulator_(new Accumulator()) { accumulator_->Init(fst); rng_.seed(selector_.Seed()); } ArcSampler(const ArcSampler &sampler, const Fst *fst = nullptr) : fst_(fst ? *fst : sampler.fst_), selector_(sampler.selector_), max_length_(sampler.max_length_) { if (fst) { accumulator_.reset(new Accumulator()); accumulator_->Init(*fst); } else { // Shallow copy. accumulator_.reset(new Accumulator(*sampler.accumulator_)); } } bool Sample(const RandState &rstate) { sample_map_.clear(); if ((fst_.NumArcs(rstate.state_id) == 0 && fst_.Final(rstate.state_id) == Weight::Zero()) || rstate.length == max_length_) { Reset(); return false; } if (fst_.NumArcs(rstate.state_id) + 1 < rstate.nsamples) { MultinomialSample(rstate); Reset(); return true; } for (size_t i = 0; i < rstate.nsamples; ++i) { ++sample_map_[selector_(fst_, rstate.state_id, accumulator_.get())]; } Reset(); return true; } bool Done() const { return sample_iter_ == sample_map_.end(); } void Next() { ++sample_iter_; } std::pair Value() const { return *sample_iter_; } void Reset() { sample_iter_ = sample_map_.begin(); } bool Error() const { return accumulator_->Error(); } private: using RNG = std::mt19937; // Sample according to the multinomial distribution of rstate.nsamples draws // from p_. void MultinomialSample(const RandState &rstate) { p_.clear(); for (ArcIterator> aiter(fst_, rstate.state_id); !aiter.Done(); aiter.Next()) { p_.push_back(exp(-to_log_weight_(aiter.Value().weight).Value())); } if (fst_.Final(rstate.state_id) != Weight::Zero()) { p_.push_back(exp(-to_log_weight_(fst_.Final(rstate.state_id)).Value())); } if (rstate.nsamples < std::numeric_limits::max()) { OneMultinomialSample(p_, rstate.nsamples, &sample_map_, &rng_); } else { for (size_t i = 0; i < p_.size(); ++i) { sample_map_[i] = ceil(p_[i] * rstate.nsamples); } } } const Fst &fst_; const Selector &selector_; const int32 max_length_; // Stores (N, K) for Value(). std::map sample_map_; std::map::const_iterator sample_iter_; std::unique_ptr accumulator_; RNG rng_; // Random number generator. std::vector p_; // Multinomial parameters. WeightConvert to_log_weight_; }; // Options for random path generation with RandGenFst. The template argument is // a sampler, typically the class ArcSampler. Ownership of the sampler is taken // by RandGenFst. template struct RandGenFstOptions : public CacheOptions { Sampler *sampler; // How to sample transitions at a state. int32 npath; // Number of paths to generate. bool weighted; // Is the output tree weighted by path count, or // is it just an unweighted DAG? bool remove_total_weight; // Remove total weight when output is weighted. RandGenFstOptions(const CacheOptions &opts, Sampler *sampler, int32 npath = 1, bool weighted = true, bool remove_total_weight = false) : CacheOptions(opts), sampler(sampler), npath(npath), weighted(weighted), remove_total_weight(remove_total_weight) {} }; namespace internal { // Implementation of RandGenFst. template class RandGenFstImpl : public CacheImpl { public: using FstImpl::SetType; using FstImpl::SetProperties; using FstImpl::SetInputSymbols; using FstImpl::SetOutputSymbols; using CacheBaseImpl>::PushArc; using CacheBaseImpl>::HasArcs; using CacheBaseImpl>::HasFinal; using CacheBaseImpl>::HasStart; using CacheBaseImpl>::SetArcs; using CacheBaseImpl>::SetFinal; using CacheBaseImpl>::SetStart; using Label = typename FromArc::Label; using StateId = typename FromArc::StateId; using FromWeight = typename FromArc::Weight; using ToWeight = typename ToArc::Weight; RandGenFstImpl(const Fst &fst, const RandGenFstOptions &opts) : CacheImpl(opts), fst_(fst.Copy()), sampler_(opts.sampler), npath_(opts.npath), weighted_(opts.weighted), remove_total_weight_(opts.remove_total_weight), superfinal_(kNoLabel) { SetType("randgen"); SetProperties( RandGenProperties(fst.Properties(kFstProperties, false), weighted_), kCopyProperties); SetInputSymbols(fst.InputSymbols()); SetOutputSymbols(fst.OutputSymbols()); } RandGenFstImpl(const RandGenFstImpl &impl) : CacheImpl(impl), fst_(impl.fst_->Copy(true)), sampler_(new Sampler(*impl.sampler_, fst_.get())), npath_(impl.npath_), weighted_(impl.weighted_), superfinal_(kNoLabel) { SetType("randgen"); SetProperties(impl.Properties(), kCopyProperties); SetInputSymbols(impl.InputSymbols()); SetOutputSymbols(impl.OutputSymbols()); } StateId Start() { if (!HasStart()) { const auto s = fst_->Start(); if (s == kNoStateId) return kNoStateId; SetStart(state_table_.size()); state_table_.emplace_back( new RandState(s, npath_, 0, 0, nullptr)); } return CacheImpl::Start(); } ToWeight Final(StateId s) { if (!HasFinal(s)) Expand(s); return CacheImpl::Final(s); } size_t NumArcs(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumArcs(s); } size_t NumInputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumInputEpsilons(s); } size_t NumOutputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumOutputEpsilons(s); } uint64 Properties() const override { return Properties(kFstProperties); } // Sets error if found, and returns other FST impl properties. uint64 Properties(uint64 mask) const override { if ((mask & kError) && (fst_->Properties(kError, false) || sampler_->Error())) { SetProperties(kError, kError); } return FstImpl::Properties(mask); } void InitArcIterator(StateId s, ArcIteratorData *data) { if (!HasArcs(s)) Expand(s); CacheImpl::InitArcIterator(s, data); } // Computes the outgoing transitions from a state, creating new destination // states as needed. void Expand(StateId s) { if (s == superfinal_) { SetFinal(s, ToWeight::One()); SetArcs(s); return; } SetFinal(s, ToWeight::Zero()); const auto &rstate = *state_table_[s]; sampler_->Sample(rstate); ArcIterator> aiter(*fst_, rstate.state_id); const auto narcs = fst_->NumArcs(rstate.state_id); for (; !sampler_->Done(); sampler_->Next()) { const auto &sample_pair = sampler_->Value(); const auto pos = sample_pair.first; const auto count = sample_pair.second; double prob = static_cast(count) / rstate.nsamples; if (pos < narcs) { // Regular transition. aiter.Seek(sample_pair.first); const auto &aarc = aiter.Value(); const auto weight = weighted_ ? to_weight_(Log64Weight(-log(prob))) : ToWeight::One(); const ToArc barc(aarc.ilabel, aarc.olabel, weight, state_table_.size()); PushArc(s, barc); auto *nrstate = new RandState(aarc.nextstate, count, rstate.length + 1, pos, &rstate); state_table_.emplace_back(nrstate); } else { // Super-final transition. if (weighted_) { const auto weight = remove_total_weight_ ? to_weight_(Log64Weight(-log(prob))) : to_weight_(Log64Weight(-log(prob * npath_))); SetFinal(s, weight); } else { if (superfinal_ == kNoLabel) { superfinal_ = state_table_.size(); state_table_.emplace_back( new RandState(kNoStateId, 0, 0, 0, nullptr)); } for (size_t n = 0; n < count; ++n) { const ToArc barc(0, 0, ToWeight::One(), superfinal_); PushArc(s, barc); } } } } SetArcs(s); } private: const std::unique_ptr> fst_; std::unique_ptr sampler_; const int32 npath_; std::vector>> state_table_; const bool weighted_; bool remove_total_weight_; StateId superfinal_; WeightConvert to_weight_; }; } // namespace internal // FST class to randomly generate paths through an FST, with details controlled // by RandGenOptionsFst. Output format is a tree weighted by the path count. template class RandGenFst : public ImplToFst> { public: using Label = typename FromArc::Label; using StateId = typename FromArc::StateId; using Weight = typename FromArc::Weight; using Store = DefaultCacheStore; using State = typename Store::State; using Impl = internal::RandGenFstImpl; friend class ArcIterator>; friend class StateIterator>; RandGenFst(const Fst &fst, const RandGenFstOptions &opts) : ImplToFst(std::make_shared(fst, opts)) {} // See Fst<>::Copy() for doc. RandGenFst(const RandGenFst &fst, bool safe = false) : ImplToFst(fst, safe) {} // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc. RandGenFst *Copy(bool safe = false) const override { return new RandGenFst(*this, safe); } inline void InitStateIterator(StateIteratorData *data) const override; void InitArcIterator(StateId s, ArcIteratorData *data) const override { GetMutableImpl()->InitArcIterator(s, data); } private: using ImplToFst::GetImpl; using ImplToFst::GetMutableImpl; RandGenFst &operator=(const RandGenFst &) = delete; }; // Specialization for RandGenFst. template class StateIterator> : public CacheStateIterator> { public: explicit StateIterator(const RandGenFst &fst) : CacheStateIterator>( fst, fst.GetMutableImpl()) {} }; // Specialization for RandGenFst. template class ArcIterator> : public CacheArcIterator> { public: using StateId = typename FromArc::StateId; ArcIterator(const RandGenFst &fst, StateId s) : CacheArcIterator>( fst.GetMutableImpl(), s) { if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); } }; template inline void RandGenFst::InitStateIterator( StateIteratorData *data) const { data->base = new StateIterator>(*this); } // Options for random path generation. template struct RandGenOptions { const Selector &selector; // How an arc is selected at a state. int32 max_length; // Maximum path length. int32 npath; // Number of paths to generate. bool weighted; // Is the output tree weighted by path count, or // is it just an unweighted DAG? bool remove_total_weight; // Remove total weight when output is weighted? explicit RandGenOptions(const Selector &selector, int32 max_length = std::numeric_limits::max(), int32 npath = 1, bool weighted = false, bool remove_total_weight = false) : selector(selector), max_length(max_length), npath(npath), weighted(weighted), remove_total_weight(remove_total_weight) {} }; namespace internal { template class RandGenVisitor { public: using StateId = typename FromArc::StateId; using Weight = typename FromArc::Weight; explicit RandGenVisitor(MutableFst *ofst) : ofst_(ofst) {} void InitVisit(const Fst &ifst) { ifst_ = &ifst; ofst_->DeleteStates(); ofst_->SetInputSymbols(ifst.InputSymbols()); ofst_->SetOutputSymbols(ifst.OutputSymbols()); if (ifst.Properties(kError, false)) ofst_->SetProperties(kError, kError); path_.clear(); } constexpr bool InitState(StateId, StateId) const { return true; } bool TreeArc(StateId, const ToArc &arc) { if (ifst_->Final(arc.nextstate) == Weight::Zero()) { path_.push_back(arc); } else { OutputPath(); } return true; } bool BackArc(StateId, const FromArc &) { FSTERROR() << "RandGenVisitor: cyclic input"; ofst_->SetProperties(kError, kError); return false; } bool ForwardOrCrossArc(StateId, const FromArc &) { OutputPath(); return true; } void FinishState(StateId s, StateId p, const FromArc *) { if (p != kNoStateId && ifst_->Final(s) == Weight::Zero()) path_.pop_back(); } void FinishVisit() {} private: void OutputPath() { if (ofst_->Start() == kNoStateId) { const auto start = ofst_->AddState(); ofst_->SetStart(start); } auto src = ofst_->Start(); for (size_t i = 0; i < path_.size(); ++i) { const auto dest = ofst_->AddState(); const ToArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest); ofst_->AddArc(src, arc); src = dest; } ofst_->SetFinal(src, Weight::One()); } const Fst *ifst_; MutableFst *ofst_; std::vector path_; RandGenVisitor(const RandGenVisitor &) = delete; RandGenVisitor &operator=(const RandGenVisitor &) = delete; }; } // namespace internal // Randomly generate paths through an FST; details controlled by // RandGenOptions. template void RandGen(const Fst &ifst, MutableFst *ofst, const RandGenOptions &opts) { using State = typename ToArc::StateId; using Weight = typename ToArc::Weight; using Sampler = ArcSampler; auto *sampler = new Sampler(ifst, opts.selector, opts.max_length); RandGenFstOptions fopts(CacheOptions(true, 0), sampler, opts.npath, opts.weighted, opts.remove_total_weight); RandGenFst rfst(ifst, fopts); if (opts.weighted) { *ofst = rfst; } else { internal::RandGenVisitor rand_visitor(ofst); DfsVisit(rfst, &rand_visitor); } } // Randomly generate a path through an FST with the uniform distribution // over the transitions. template void RandGen(const Fst &ifst, MutableFst *ofst) { const UniformArcSelector uniform_selector; RandGenOptions> opts(uniform_selector); RandGen(ifst, ofst, opts); } } // namespace fst #endif // FST_RANDGEN_H_