// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Convenience file for including all PDT operations at once, and/or // registering them for new arc types. #ifndef FST_EXTENSIONS_PDT_PDTSCRIPT_H_ #define FST_EXTENSIONS_PDT_PDTSCRIPT_H_ #include #include #include #include #include // for ComposeOptions #include #include #include #include #include #include #include #include #include #include namespace fst { namespace script { using PdtComposeArgs = std::tuple &, MutableFstClass *, const PdtComposeOptions &, bool>; template void PdtCompose(PdtComposeArgs *args) { const Fst &ifst1 = *(std::get<0>(*args).GetFst()); const Fst &ifst2 = *(std::get<1>(*args).GetFst()); MutableFst *ofst = std::get<3>(*args)->GetMutableFst(); // In case Arc::Label is not the same as FstClass::Label, we make a // copy. Truncation may occur if FstClass::Label has more precision than // Arc::Label. std::vector> typed_parens( std::get<2>(*args).size()); std::copy(std::get<2>(*args).begin(), std::get<2>(*args).end(), typed_parens.begin()); if (std::get<5>(*args)) { Compose(ifst1, typed_parens, ifst2, ofst, std::get<4>(*args)); } else { Compose(ifst1, ifst2, typed_parens, ofst, std::get<4>(*args)); } } void PdtCompose(const FstClass &ifst1, const FstClass &ifst2, const std::vector &parens, MutableFstClass *ofst, const PdtComposeOptions &opts, bool left_pdt); struct PdtExpandOptions { bool connect; bool keep_parentheses; const WeightClass &weight_threshold; PdtExpandOptions(bool c, bool k, const WeightClass &w) : connect(c), keep_parentheses(k), weight_threshold(w) {} }; using PdtExpandArgs = std::tuple &, MutableFstClass *, const PdtExpandOptions &>; template void PdtExpand(PdtExpandArgs *args) { const Fst &fst = *(std::get<0>(*args).GetFst()); MutableFst *ofst = std::get<2>(*args)->GetMutableFst(); // In case Arc::Label is not the same as FstClass::Label, we make a // copy. Truncation may occur if FstClass::Label has more precision than // Arc::Label. std::vector> typed_parens( std::get<1>(*args).size()); std::copy(std::get<1>(*args).begin(), std::get<1>(*args).end(), typed_parens.begin()); Expand(fst, typed_parens, ofst, fst::PdtExpandOptions( std::get<3>(*args).connect, std::get<3>(*args).keep_parentheses, *(std::get<3>(*args) .weight_threshold.GetWeight()))); } void PdtExpand(const FstClass &ifst, const std::vector &parens, MutableFstClass *ofst, const PdtExpandOptions &opts); void PdtExpand(const FstClass &ifst, const std::vector &parens, MutableFstClass *ofst, bool connect, bool keep_parentheses, const WeightClass &weight_threshold); using PdtReplaceArgs = std::tuple &, MutableFstClass *, std::vector *, int64, PdtParserType, int64, const string &, const string &>; template void PdtReplace(PdtReplaceArgs *args) { const auto &untyped_pairs = std::get<0>(*args); auto size = untyped_pairs.size(); std::vector *>> typed_pairs( size); for (size_t i = 0; i < size; ++i) { typed_pairs[i].first = untyped_pairs[i].first; typed_pairs[i].second = untyped_pairs[i].second->GetFst(); } MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); std::vector> typed_parens; const PdtReplaceOptions opts(std::get<3>(*args), std::get<4>(*args), std::get<5>(*args), std::get<6>(*args), std::get<7>(*args)); Replace(typed_pairs, ofst, &typed_parens, opts); // Copies typed parens into arg3. std::get<2>(*args)->resize(typed_parens.size()); std::copy(typed_parens.begin(), typed_parens.end(), std::get<2>(*args)->begin()); } void PdtReplace(const std::vector &pairs, MutableFstClass *ofst, std::vector *parens, int64 root, PdtParserType parser_type = PDT_LEFT_PARSER, int64 start_paren_labels = kNoLabel, const string &left_paren_prefix = "(_", const string &right_paren_prefix = "_)"); using PdtReverseArgs = std::tuple &, MutableFstClass *>; template void PdtReverse(PdtReverseArgs *args) { const Fst &fst = *(std::get<0>(*args).GetFst()); MutableFst *ofst = std::get<2>(*args)->GetMutableFst(); // In case Arc::Label is not the same as FstClass::Label, we make a // copy. Truncation may occur if FstClass::Label has more precision than // Arc::Label. std::vector> typed_parens( std::get<1>(*args).size()); std::copy(std::get<1>(*args).begin(), std::get<1>(*args).end(), typed_parens.begin()); Reverse(fst, typed_parens, ofst); } void PdtReverse(const FstClass &ifst, const std::vector &, MutableFstClass *ofst); // PDT SHORTESTPATH struct PdtShortestPathOptions { QueueType queue_type; bool keep_parentheses; bool path_gc; PdtShortestPathOptions(QueueType qt = FIFO_QUEUE, bool kp = false, bool gc = true) : queue_type(qt), keep_parentheses(kp), path_gc(gc) {} }; using PdtShortestPathArgs = std::tuple &, MutableFstClass *, const PdtShortestPathOptions &>; template void PdtShortestPath(PdtShortestPathArgs *args) { const Fst &fst = *(std::get<0>(*args).GetFst()); MutableFst *ofst = std::get<2>(*args)->GetMutableFst(); const PdtShortestPathOptions &opts = std::get<3>(*args); // In case Arc::Label is not the same as FstClass::Label, we make a // copy. Truncation may occur if FstClass::Label has more precision than // Arc::Label. std::vector> typed_parens( std::get<1>(*args).size()); std::copy(std::get<1>(*args).begin(), std::get<1>(*args).end(), typed_parens.begin()); switch (opts.queue_type) { default: FSTERROR() << "Unknown queue type: " << opts.queue_type; case FIFO_QUEUE: { using Queue = FifoQueue; fst::PdtShortestPathOptions spopts(opts.keep_parentheses, opts.path_gc); ShortestPath(fst, typed_parens, ofst, spopts); return; } case LIFO_QUEUE: { using Queue = LifoQueue; fst::PdtShortestPathOptions spopts(opts.keep_parentheses, opts.path_gc); ShortestPath(fst, typed_parens, ofst, spopts); return; } case STATE_ORDER_QUEUE: { using Queue = StateOrderQueue; fst::PdtShortestPathOptions spopts(opts.keep_parentheses, opts.path_gc); ShortestPath(fst, typed_parens, ofst, spopts); return; } } } void PdtShortestPath(const FstClass &ifst, const std::vector &parens, MutableFstClass *ofst, const PdtShortestPathOptions &opts = PdtShortestPathOptions()); // PRINT INFO using PrintPdtInfoArgs = std::pair &>; template void PrintPdtInfo(PrintPdtInfoArgs *args) { const Fst &fst = *(std::get<0>(*args).GetFst()); // In case Arc::Label is not the same as FstClass::Label, we make a // copy. Truncation may occur if FstClass::Label has more precision than // Arc::Label. std::vector> typed_parens( std::get<1>(*args).size()); std::copy(std::get<1>(*args).begin(), std::get<1>(*args).end(), typed_parens.begin()); PdtInfo pdtinfo(fst, typed_parens); PrintPdtInfo(pdtinfo); } void PrintPdtInfo(const FstClass &ifst, const std::vector &parens); } // namespace script } // namespace fst #define REGISTER_FST_PDT_OPERATIONS(ArcType) \ REGISTER_FST_OPERATION(PdtCompose, ArcType, PdtComposeArgs); \ REGISTER_FST_OPERATION(PdtExpand, ArcType, PdtExpandArgs); \ REGISTER_FST_OPERATION(PdtReplace, ArcType, PdtReplaceArgs); \ REGISTER_FST_OPERATION(PdtReverse, ArcType, PdtReverseArgs); \ REGISTER_FST_OPERATION(PdtShortestPath, ArcType, PdtShortestPathArgs); \ REGISTER_FST_OPERATION(PrintPdtInfo, ArcType, PrintPdtInfoArgs) #endif // FST_EXTENSIONS_PDT_PDTSCRIPT_H_