Blame view
tools/openfst-1.6.7/src/lib/symbol-table-ops.cc
4.09 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // #include <fst/symbol-table-ops.h> #include <string> namespace fst { SymbolTable *MergeSymbolTable(const SymbolTable &left, const SymbolTable &right, bool *right_relabel_output) { // MergeSymbolTable detects several special cases. It will return a reference // copied version of SymbolTable of left or right if either symbol table is // a superset of the other. std::unique_ptr<SymbolTable> merged( new SymbolTable("merge_" + left.Name() + "_" + right.Name())); // Copies everything from the left symbol table. bool left_has_all = true; bool right_has_all = true; bool relabel = false; for (SymbolTableIterator liter(left); !liter.Done(); liter.Next()) { merged->AddSymbol(liter.Symbol(), liter.Value()); if (right_has_all) { int64 key = right.Find(liter.Symbol()); if (key == -1) { right_has_all = false; } else if (!relabel && key != liter.Value()) { relabel = true; } } } if (right_has_all) { if (right_relabel_output) *right_relabel_output = relabel; return right.Copy(); } // add all symbols we can from right symbol table std::vector<string> conflicts; for (SymbolTableIterator riter(right); !riter.Done(); riter.Next()) { int64 key = merged->Find(riter.Symbol()); if (key != -1) { // Symbol already exists, maybe with different value if (key != riter.Value()) relabel = true; continue; } // Symbol doesn't exist from left left_has_all = false; if (!merged->Find(riter.Value()).empty()) { // we can't add this where we want to, add it later, in order conflicts.push_back(riter.Symbol()); continue; } // there is a hole and we can add this symbol with its id merged->AddSymbol(riter.Symbol(), riter.Value()); } if (right_relabel_output) *right_relabel_output = relabel; if (left_has_all) return left.Copy(); // Add all symbols that conflicted, in order for (const auto &conflict : conflicts) merged->AddSymbol(conflict); return merged.release(); } SymbolTable *CompactSymbolTable(const SymbolTable &syms) { std::map<int64, string> sorted; SymbolTableIterator stiter(syms); for (; !stiter.Done(); stiter.Next()) { sorted[stiter.Value()] = stiter.Symbol(); } auto *compact = new SymbolTable(syms.Name() + "_compact"); int64 newkey = 0; for (const auto &kv : sorted) compact->AddSymbol(kv.second, newkey++); return compact; } SymbolTable *FstReadSymbols(const string &filename, bool input_symbols) { std::ifstream in(filename, std::ios_base::in | std::ios_base::binary); if (!in) { LOG(ERROR) << "FstReadSymbols: Can't open file " << filename; return nullptr; } FstHeader hdr; if (!hdr.Read(in, filename)) { LOG(ERROR) << "FstReadSymbols: Couldn't read header from " << filename; return nullptr; } if (hdr.GetFlags() & FstHeader::HAS_ISYMBOLS) { std::unique_ptr<SymbolTable> isymbols(SymbolTable::Read(in, filename)); if (isymbols == nullptr) { LOG(ERROR) << "FstReadSymbols: Couldn't read input symbols from " << filename; return nullptr; } if (input_symbols) return isymbols.release(); } if (hdr.GetFlags() & FstHeader::HAS_OSYMBOLS) { std::unique_ptr<SymbolTable> osymbols(SymbolTable::Read(in, filename)); if (osymbols == nullptr) { LOG(ERROR) << "FstReadSymbols: Couldn't read output symbols from " << filename; return nullptr; } if (!input_symbols) return osymbols.release(); } LOG(ERROR) << "FstReadSymbols: The file " << filename << " doesn't contain the requested symbols"; return nullptr; } bool AddAuxiliarySymbols(const string &prefix, int64 start_label, int64 nlabels, SymbolTable *syms) { for (int64 i = 0; i < nlabels; ++i) { auto index = i + start_label; if (index != syms->AddSymbol(prefix + std::to_string(i), index)) { FSTERROR() << "AddAuxiliarySymbols: Symbol table clash"; return false; } } return true; } } // namespace fst |