pdtreplace.cc 3.06 KB
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Converts an RTN represented by FSTs and non-terminal labels into a PDT.

#include <cstring>

#include <string>
#include <vector>

#include <fst/flags.h>

#include <fst/extensions/pdt/getters.h>
#include <fst/extensions/pdt/pdtscript.h>
#include <fst/util.h>
#include <fst/vector-fst.h>

DEFINE_string(pdt_parentheses, "", "PDT parenthesis label pairs");
DEFINE_string(pdt_parser_type, "left",
              "Construction method, one of: \"left\", \"left_sr\"");
DEFINE_int64(start_paren_labels, fst::kNoLabel,
             "Index to use for the first inserted parentheses; if not "
             "specified, the next available label beyond the highest output "
             "label is used");
DEFINE_string(left_paren_prefix, "(_", "Prefix to attach to SymbolTable "
              "labels for inserted left parentheses");
DEFINE_string(right_paren_prefix, ")_", "Prefix to attach to SymbolTable "
              "labels for inserted right parentheses");

void Cleanup(std::vector<fst::script::LabelFstClassPair> *pairs) {
  for (const auto &pair : *pairs) {
    delete pair.second;
  }
  pairs->clear();
}

int main(int argc, char **argv) {
  namespace s = fst::script;
  using fst::script::FstClass;
  using fst::script::VectorFstClass;
  using fst::PdtParserType;
  using fst::WriteLabelPairs;

  string usage = "Converts an RTN represented by FSTs";
  usage += " and non-terminal labels into PDT.\n\n  Usage: ";
  usage += argv[0];
  usage += " root.fst rootlabel [rule1.fst label1 ...] [out.fst]\n";

  std::set_new_handler(FailedNewHandler);
  SET_FLAGS(usage.c_str(), &argc, &argv, true);
  if (argc < 4) {
    ShowUsage();
    return 1;
  }

  const string in_name = argv[1];
  const string out_name = argc % 2 == 0 ? argv[argc - 1] : "";

  auto *ifst = FstClass::Read(in_name);
  if (!ifst) return 1;

  PdtParserType parser_type;
  if (!s::GetPdtParserType(FLAGS_pdt_parser_type, &parser_type)) {
    LOG(ERROR) << argv[0] << ": Unknown PDT parser type: "
               << FLAGS_pdt_parser_type;
    delete ifst;
    return 1;
  }

  std::vector<s::LabelFstClassPair> pairs;
  // Note that if the root label is beyond the range of the underlying FST's
  // labels, truncation will occur.
  const auto root = atoll(argv[2]);
  pairs.emplace_back(root, ifst);

  for (auto i = 3; i < argc - 1; i += 2) {
    ifst = FstClass::Read(argv[i]);
    if (!ifst) {
      Cleanup(&pairs);
      return 1;
    }
    // Note that if the root label is beyond the range of the underlying FST's
    // labels, truncation will occur.
    const auto label = atoll(argv[i + 1]);
    pairs.emplace_back(label, ifst);
  }

  VectorFstClass ofst(ifst->ArcType());
  std::vector<s::LabelPair> parens;
  s::PdtReplace(pairs, &ofst, &parens, root, parser_type,
                FLAGS_start_paren_labels, FLAGS_left_paren_prefix,
                FLAGS_right_paren_prefix);
  Cleanup(&pairs);

  if (!FLAGS_pdt_parentheses.empty()) {
    if (!WriteLabelPairs(FLAGS_pdt_parentheses, parens)) return 1;
  }

  ofst.Write(out_name);

  return 0;
}