// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Functions and classes to relabel an FST (either on input or output). #ifndef FST_RELABEL_H_ #define FST_RELABEL_H_ #include #include #include #include #include #include #include #include namespace fst { // Relabels either the input labels or output labels. The old to // new labels are specified using a vector of std::pair. // Any label associations not specified are assumed to be identity // mapping. The destination labels must be valid labels (e.g., not kNoLabel). template void Relabel( MutableFst *fst, const std::vector> &ipairs, const std::vector> &opairs) { using Label = typename Arc::Label; const auto props = fst->Properties(kFstProperties, false); // Constructs label-to-label maps. std::unordered_map input_map; for (auto &ipair : ipairs) input_map[ipair.first] = ipair.second; std::unordered_map output_map; for (auto &opair : opairs) output_map[opair.first] = opair.second; for (StateIterator> siter(*fst); !siter.Done(); siter.Next()) { for (MutableArcIterator> aiter(fst, siter.Value()); !aiter.Done(); aiter.Next()) { auto arc = aiter.Value(); // Relabels input. auto it = input_map.find(arc.ilabel); if (it != input_map.end()) { if (it->second == kNoLabel) { FSTERROR() << "Input symbol ID " << arc.ilabel << " missing from target vocabulary"; fst->SetProperties(kError, kError); return; } arc.ilabel = it->second; } // Relabels output. it = output_map.find(arc.olabel); if (it != output_map.end()) { if (it->second == kNoLabel) { FSTERROR() << "Output symbol id " << arc.olabel << " missing from target vocabulary"; fst->SetProperties(kError, kError); return; } arc.olabel = it->second; } aiter.SetValue(arc); } } fst->SetProperties(RelabelProperties(props), kFstProperties); } // Relabels either the input labels or output labels. The old to // new labels are specified using pairs of old and new symbol tables. // The tables must contain (at least) all labels on the appropriate side of the // FST. If the 'unknown_i(o)symbol' is non-empty, it is used to label any // missing symbol in new_i(o)symbols table. template void Relabel(MutableFst *fst, const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, const string &unknown_isymbol, bool attach_new_isymbols, const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, const string &unknown_osymbol, bool attach_new_osymbols) { using Label = typename Arc::Label; // Constructs vectors of input-side label pairs. std::vector> ipairs; if (old_isymbols && new_isymbols) { size_t num_missing_syms = 0; Label unknown_ilabel = kNoLabel; if (!unknown_isymbol.empty()) { unknown_ilabel = new_isymbols->Find(unknown_isymbol); if (unknown_ilabel == kNoLabel) { VLOG(1) << "Input symbol '" << unknown_isymbol << "' missing from target symbol table"; ++num_missing_syms; } } for (SymbolTableIterator siter(*old_isymbols); !siter.Done(); siter.Next()) { const auto old_index = siter.Value(); const auto symbol = siter.Symbol(); auto new_index = new_isymbols->Find(siter.Symbol()); if (new_index == kNoLabel) { if (unknown_ilabel != kNoLabel) { new_index = unknown_ilabel; } else { VLOG(1) << "Input symbol ID " << old_index << " symbol '" << symbol << "' missing from target symbol table"; ++num_missing_syms; } } ipairs.push_back(std::make_pair(old_index, new_index)); } if (num_missing_syms > 0) { LOG(WARNING) << "Target symbol table missing: " << num_missing_syms << " input symbols"; } if (attach_new_isymbols) fst->SetInputSymbols(new_isymbols); } // Constructs vectors of output-side label pairs. std::vector> opairs; if (old_osymbols && new_osymbols) { size_t num_missing_syms = 0; Label unknown_olabel = kNoLabel; if (!unknown_osymbol.empty()) { unknown_olabel = new_osymbols->Find(unknown_osymbol); if (unknown_olabel == kNoLabel) { VLOG(1) << "Output symbol '" << unknown_osymbol << "' missing from target symbol table"; ++num_missing_syms; } } for (SymbolTableIterator siter(*old_osymbols); !siter.Done(); siter.Next()) { const auto old_index = siter.Value(); const auto symbol = siter.Symbol(); auto new_index = new_osymbols->Find(siter.Symbol()); if (new_index == kNoLabel) { if (unknown_olabel != kNoLabel) { new_index = unknown_olabel; } else { VLOG(1) << "Output symbol ID " << old_index << " symbol '" << symbol << "' missing from target symbol table"; ++num_missing_syms; } } opairs.push_back(std::make_pair(old_index, new_index)); } if (num_missing_syms > 0) { LOG(WARNING) << "Target symbol table missing: " << num_missing_syms << " output symbols"; } if (attach_new_osymbols) fst->SetOutputSymbols(new_osymbols); } // Calls relabel using vector of relabel pairs. Relabel(fst, ipairs, opairs); } // Same as previous but no special allowance for unknown symbols. Kept // for backward compat. template void Relabel(MutableFst *fst, const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, bool attach_new_isymbols, const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, bool attach_new_osymbols) { Relabel(fst, old_isymbols, new_isymbols, "" /* no unknown isymbol */, attach_new_isymbols, old_osymbols, new_osymbols, "" /* no unknown ioymbol */, attach_new_osymbols); } // Relabels either the input labels or output labels. The old to // new labels are specified using symbol tables. Any label associations not // specified are assumed to be identity mapping. template void Relabel(MutableFst *fst, const SymbolTable *new_isymbols, const SymbolTable *new_osymbols) { Relabel(fst, fst->InputSymbols(), new_isymbols, true, fst->OutputSymbols(), new_osymbols, true); } using RelabelFstOptions = CacheOptions; template class RelabelFst; namespace internal { // Relabels an FST from one symbol set to another. Relabeling can either be on // input or output space. RelabelFst implements a delayed version of the // relabel. Arcs are relabeled on the fly and not cached; i.e., each request is // recomputed. template class RelabelFstImpl : public CacheImpl { public: using Label = typename Arc::Label; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; using Store = DefaultCacheStore; using State = typename Store::State; using FstImpl::SetType; using FstImpl::SetProperties; using FstImpl::WriteHeader; using FstImpl::SetInputSymbols; using FstImpl::SetOutputSymbols; using CacheImpl::PushArc; using CacheImpl::HasArcs; using CacheImpl::HasFinal; using CacheImpl::HasStart; using CacheImpl::SetArcs; using CacheImpl::SetFinal; using CacheImpl::SetStart; friend class StateIterator>; RelabelFstImpl(const Fst &fst, const std::vector> &ipairs, const std::vector> &opairs, const RelabelFstOptions &opts) : CacheImpl(opts), fst_(fst.Copy()), relabel_input_(false), relabel_output_(false) { SetProperties(RelabelProperties(fst.Properties(kCopyProperties, false))); SetType("relabel"); // Creates input label map. if (!ipairs.empty()) { for (auto &ipair : ipairs) input_map_[ipair.first] = ipair.second; relabel_input_ = true; } // Creates output label map. if (!opairs.empty()) { for (auto &opair : opairs) output_map_[opair.first] = opair.second; relabel_output_ = true; } } RelabelFstImpl(const Fst &fst, const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts) : CacheImpl(opts), fst_(fst.Copy()), relabel_input_(false), relabel_output_(false) { SetType("relabel"); SetProperties(RelabelProperties(fst.Properties(kCopyProperties, false))); SetInputSymbols(old_isymbols); SetOutputSymbols(old_osymbols); if (old_isymbols && new_isymbols && old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) { for (SymbolTableIterator siter(*old_isymbols); !siter.Done(); siter.Next()) { input_map_[siter.Value()] = new_isymbols->Find(siter.Symbol()); } SetInputSymbols(new_isymbols); relabel_input_ = true; } if (old_osymbols && new_osymbols && old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) { for (SymbolTableIterator siter(*old_osymbols); !siter.Done(); siter.Next()) { output_map_[siter.Value()] = new_osymbols->Find(siter.Symbol()); } SetOutputSymbols(new_osymbols); relabel_output_ = true; } } RelabelFstImpl(const RelabelFstImpl &impl) : CacheImpl(impl), fst_(impl.fst_->Copy(true)), input_map_(impl.input_map_), output_map_(impl.output_map_), relabel_input_(impl.relabel_input_), relabel_output_(impl.relabel_output_) { SetType("relabel"); SetProperties(impl.Properties(), kCopyProperties); SetInputSymbols(impl.InputSymbols()); SetOutputSymbols(impl.OutputSymbols()); } StateId Start() { if (!HasStart()) SetStart(fst_->Start()); return CacheImpl::Start(); } Weight Final(StateId s) { if (!HasFinal(s)) SetFinal(s, fst_->Final(s)); return CacheImpl::Final(s); } size_t NumArcs(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumArcs(s); } size_t NumInputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumInputEpsilons(s); } size_t NumOutputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumOutputEpsilons(s); } uint64 Properties() const override { return Properties(kFstProperties); } // Sets error if found, and returns other FST impl properties. uint64 Properties(uint64 mask) const override { if ((mask & kError) && fst_->Properties(kError, false)) { SetProperties(kError, kError); } return FstImpl::Properties(mask); } void InitArcIterator(StateId s, ArcIteratorData *data) { if (!HasArcs(s)) Expand(s); CacheImpl::InitArcIterator(s, data); } void Expand(StateId s) { for (ArcIterator> aiter(*fst_, s); !aiter.Done(); aiter.Next()) { auto arc = aiter.Value(); if (relabel_input_) { auto it = input_map_.find(arc.ilabel); if (it != input_map_.end()) arc.ilabel = it->second; } if (relabel_output_) { auto it = output_map_.find(arc.olabel); if (it != output_map_.end()) { arc.olabel = it->second; } } PushArc(s, arc); } SetArcs(s); } private: std::unique_ptr> fst_; std::unordered_map input_map_; std::unordered_map output_map_; bool relabel_input_; bool relabel_output_; }; } // namespace internal // This class attaches interface to implementation and handles // reference counting, delegating most methods to ImplToFst. template class RelabelFst : public ImplToFst> { public: using Arc = A; using Label = typename Arc::Label; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; using Store = DefaultCacheStore; using State = typename Store::State; using Impl = internal::RelabelFstImpl; friend class ArcIterator>; friend class StateIterator>; RelabelFst(const Fst &fst, const std::vector> &ipairs, const std::vector> &opairs, const RelabelFstOptions &opts = RelabelFstOptions()) : ImplToFst(std::make_shared(fst, ipairs, opairs, opts)) {} RelabelFst(const Fst &fst, const SymbolTable *new_isymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts = RelabelFstOptions()) : ImplToFst( std::make_shared(fst, fst.InputSymbols(), new_isymbols, fst.OutputSymbols(), new_osymbols, opts)) {} RelabelFst(const Fst &fst, const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts = RelabelFstOptions()) : ImplToFst(std::make_shared(fst, old_isymbols, new_isymbols, old_osymbols, new_osymbols, opts)) {} // See Fst<>::Copy() for doc. RelabelFst(const RelabelFst &fst, bool safe = false) : ImplToFst(fst, safe) {} // Gets a copy of this RelabelFst. See Fst<>::Copy() for further doc. RelabelFst *Copy(bool safe = false) const override { return new RelabelFst(*this, safe); } void InitStateIterator(StateIteratorData *data) const override; void InitArcIterator(StateId s, ArcIteratorData *data) const override { return GetMutableImpl()->InitArcIterator(s, data); } private: using ImplToFst::GetImpl; using ImplToFst::GetMutableImpl; RelabelFst &operator=(const RelabelFst &) = delete; }; // Specialization for RelabelFst. template class StateIterator> : public StateIteratorBase { public: using StateId = typename Arc::StateId; explicit StateIterator(const RelabelFst &fst) : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {} bool Done() const final { return siter_.Done(); } StateId Value() const final { return s_; } void Next() final { if (!siter_.Done()) { ++s_; siter_.Next(); } } void Reset() final { s_ = 0; siter_.Reset(); } private: const internal::RelabelFstImpl* impl_; StateIterator> siter_; StateId s_; StateIterator(const StateIterator &) = delete; StateIterator &operator=(const StateIterator &) = delete; }; // Specialization for RelabelFst. template class ArcIterator> : public CacheArcIterator> { public: using StateId = typename Arc::StateId; ArcIterator(const RelabelFst &fst, StateId s) : CacheArcIterator>(fst.GetMutableImpl(), s) { if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); } }; template inline void RelabelFst::InitStateIterator( StateIteratorData *data) const { data->base = new StateIterator>(*this); } // Useful alias when using StdArc. using StdRelabelFst = RelabelFst; } // namespace fst #endif // FST_RELABEL_H_