symbol-table-ops.cc
4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
#include <fst/symbol-table-ops.h>
#include <string>
namespace fst {
// Merges the symbols of `left` and `right` into a single symbol table.
//
// Every symbol of `left` is placed in the merged table under its `left`
// label. A symbol that occurs only in `right` keeps its `right` label when
// that label is still unused in the merged table; otherwise it is queued and
// added at the end with AddSymbol(symbol), letting the table choose the label.
//
// Special cases:
//   * If `right` contains every symbol of `left`, right.Copy() is returned.
//   * If `left` contains every symbol of `right`, left.Copy() is returned.
//
// If `right_relabel_output` is non-null it is set to true when
//   * a symbol present in both tables carries different labels in the two, or
//   * a symbol found only in `right` could not keep its `right` label because
//     that label was already taken in the merged table;
// and to false otherwise.
//
// Returns a raw pointer to the resulting table; on the merged path ownership
// is released from the local unique_ptr to the caller.
SymbolTable *MergeSymbolTable(const SymbolTable &left, const SymbolTable &right,
                              bool *right_relabel_output) {
  std::unique_ptr<SymbolTable> merged(
      new SymbolTable("merge_" + left.Name() + "_" + right.Name()));
  // Copies everything from the left symbol table, and at the same time checks
  // whether right already contains all of left's symbols.
  bool left_has_all = true;
  bool right_has_all = true;
  bool relabel = false;
  for (SymbolTableIterator liter(left); !liter.Done(); liter.Next()) {
    merged->AddSymbol(liter.Symbol(), liter.Value());
    if (right_has_all) {
      int64 key = right.Find(liter.Symbol());
      if (key == -1) {
        right_has_all = false;
      } else if (!relabel && key != liter.Value()) {
        relabel = true;
      }
    }
  }
  if (right_has_all) {
    // Right is a superset of left: a copy of right serves as the merge.
    if (right_relabel_output) *right_relabel_output = relabel;
    return right.Copy();
  }
  // Adds all symbols we can from the right symbol table.
  std::vector<string> conflicts;
  for (SymbolTableIterator riter(right); !riter.Done(); riter.Next()) {
    int64 key = merged->Find(riter.Symbol());
    if (key != -1) {
      // Symbol already exists, maybe with a different value.
      if (key != riter.Value()) relabel = true;
      continue;
    }
    // Symbol doesn't exist in left.
    left_has_all = false;
    if (!merged->Find(riter.Value()).empty()) {
      // The label is already occupied by another symbol, so this symbol is
      // added later, in order. It will necessarily receive a label different
      // from the one it has in right, hence right needs relabeling.
      relabel = true;
      conflicts.push_back(riter.Symbol());
      continue;
    }
    // There is a hole and we can add this symbol with its own label.
    merged->AddSymbol(riter.Symbol(), riter.Value());
  }
  if (right_relabel_output) *right_relabel_output = relabel;
  // Left is a superset of right: a copy of left serves as the merge.
  if (left_has_all) return left.Copy();
  // Adds all symbols that conflicted, in order.
  for (const auto &conflict : conflicts) merged->AddSymbol(conflict);
  return merged.release();
}
// Builds a new symbol table, named "<syms name>_compact", holding the symbols
// of `syms` relabeled with consecutive labels 0, 1, 2, ... assigned in
// increasing order of their labels in `syms`. Returns a pointer to the newly
// allocated table.
SymbolTable *CompactSymbolTable(const SymbolTable &syms) {
  // std::map iterates in ascending key order, which sorts symbols by label.
  std::map<int64, string> by_label;
  for (SymbolTableIterator it(syms); !it.Done(); it.Next()) {
    by_label[it.Value()] = it.Symbol();
  }
  auto *result = new SymbolTable(syms.Name() + "_compact");
  int64 next_label = 0;
  for (auto entry = by_label.begin(); entry != by_label.end(); ++entry) {
    result->AddSymbol(entry->second, next_label);
    ++next_label;
  }
  return result;
}
// Opens `filename` in binary mode, reads an FstHeader from it, and returns the
// symbol table stored after the header: the input table when `input_symbols`
// is true, the output table otherwise. Returns nullptr, after logging an
// error, if the file cannot be opened, the header or a flagged table cannot be
// read, or the requested table is not flagged as present in the header.
SymbolTable *FstReadSymbols(const string &filename, bool input_symbols) {
  std::ifstream strm(filename, std::ios_base::in | std::ios_base::binary);
  if (!strm) {
    LOG(ERROR) << "FstReadSymbols: Can't open file " << filename;
    return nullptr;
  }
  FstHeader header;
  if (!header.Read(strm, filename)) {
    LOG(ERROR) << "FstReadSymbols: Couldn't read header from " << filename;
    return nullptr;
  }
  // The input table, when present, is read even if the output table was
  // requested: the output table is read from the same stream afterwards.
  if (header.GetFlags() & FstHeader::HAS_ISYMBOLS) {
    std::unique_ptr<SymbolTable> table(SymbolTable::Read(strm, filename));
    if (table == nullptr) {
      LOG(ERROR) << "FstReadSymbols: Couldn't read input symbols from "
                 << filename;
      return nullptr;
    }
    if (input_symbols) return table.release();
  }
  if (header.GetFlags() & FstHeader::HAS_OSYMBOLS) {
    std::unique_ptr<SymbolTable> table(SymbolTable::Read(strm, filename));
    if (table == nullptr) {
      LOG(ERROR) << "FstReadSymbols: Couldn't read output symbols from "
                 << filename;
      return nullptr;
    }
    if (!input_symbols) return table.release();
  }
  // Reaching this point means the requested table was not returned above.
  LOG(ERROR) << "FstReadSymbols: The file " << filename
             << " doesn't contain the requested symbols";
  return nullptr;
}
bool AddAuxiliarySymbols(const string &prefix, int64 start_label,
int64 nlabels, SymbolTable *syms) {
for (int64 i = 0; i < nlabels; ++i) {
auto index = i + start_label;
if (index != syms->AddSymbol(prefix + std::to_string(i), index)) {
FSTERROR() << "AddAuxiliarySymbols: Symbol table clash";
return false;
}
}
return true;
}
} // namespace fst