Blame view

tools/openfst-1.6.7/src/lib/symbol-table-ops.cc 4.09 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
  // See www.openfst.org for extensive documentation on this weighted
  // finite-state transducer library.
  //
  
  #include <fst/symbol-table-ops.h>
  
  #include <string>
  
  namespace fst {
  
  // Returns a newly allocated symbol table, owned by the caller, containing
  // every symbol of `left` and of `right`. Two special cases return a plain
  // copy instead of building a new table:
  //   * if `right` already contains every symbol of `left`, a copy of `right`;
  //   * if `left` already contains every symbol of `right`, a copy of `left`.
  // Otherwise, symbols of `left` keep their keys. A symbol found only in
  // `right` keeps its key when that key is still free in the merged table.
  // If not, it is appended at the end (in `right`'s iteration order) with a
  // key chosen by SymbolTable::AddSymbol.
  //
  // If `right_relabel_output` is non-null, it is set to true when some symbol
  // of `right` does not keep its `right` key in the result, i.e. labels taken
  // from `right` must be relabeled. When a copy of `right` is returned, the
  // flag instead reports whether any symbol of `left` has a different key in
  // `right`. NOTE(review): in that branch it is `left`'s labels that disagree
  // with the result; confirm this is what callers expect.
  SymbolTable *MergeSymbolTable(const SymbolTable &left, const SymbolTable &right,
                                bool *right_relabel_output) {
    std::unique_ptr<SymbolTable> merged(
        new SymbolTable("merge_" + left.Name() + "_" + right.Name()));
    // Copies everything from the left symbol table, checking along the way
    // whether the right table already contains every left symbol.
    bool left_has_all = true;
    bool right_has_all = true;
    bool relabel = false;
    for (SymbolTableIterator liter(left); !liter.Done(); liter.Next()) {
      merged->AddSymbol(liter.Symbol(), liter.Value());
      if (right_has_all) {
        int64 key = right.Find(liter.Symbol());
        if (key == -1) {
          right_has_all = false;
        } else if (!relabel && key != liter.Value()) {
          relabel = true;
        }
      }
    }
    if (right_has_all) {
      if (right_relabel_output) *right_relabel_output = relabel;
      return right.Copy();
    }
    // Adds every right symbol that can keep its own key. Symbols whose key is
    // already held by a different symbol are deferred.
    std::vector<string> conflicts;
    for (SymbolTableIterator riter(right); !riter.Done(); riter.Next()) {
      int64 key = merged->Find(riter.Symbol());
      if (key != -1) {
        // Symbol already present from left, possibly under a different key.
        if (key != riter.Value()) relabel = true;
        continue;
      }
      // Symbol is missing from left.
      left_has_all = false;
      if (!merged->Find(riter.Value()).empty()) {
        // Its key is occupied; add it later, preserving order.
        conflicts.push_back(riter.Symbol());
        continue;
      }
      // The key is free, so the symbol keeps its original key.
      merged->AddSymbol(riter.Symbol(), riter.Value());
    }
    // A deferred symbol can never receive its original key (another symbol
    // holds it), so labels coming from right must be relabeled.
    if (!conflicts.empty()) relabel = true;
    if (right_relabel_output) *right_relabel_output = relabel;
    if (left_has_all) return left.Copy();
    // Adds all deferred symbols, in order, under newly assigned keys.
    for (const auto &conflict : conflicts) merged->AddSymbol(conflict);
    return merged.release();
  }
  
  // Returns a new symbol table, owned by the caller and named
  // "<syms name>_compact". It holds the symbols of `syms` renumbered densely
  // from 0, in ascending order of their original keys. If the iterator reports
  // the same key more than once, only the last symbol seen for it is kept.
  SymbolTable *CompactSymbolTable(const SymbolTable &syms) {
    std::map<int64, string> by_key;
    for (SymbolTableIterator it(syms); !it.Done(); it.Next()) {
      by_key[it.Value()] = it.Symbol();
    }
    std::unique_ptr<SymbolTable> compact(
        new SymbolTable(syms.Name() + "_compact"));
    int64 next_key = 0;
    for (const auto &entry : by_key) {
      compact->AddSymbol(entry.second, next_key);
      ++next_key;
    }
    return compact.release();
  }
  
  // Reads one of the symbol tables stored after the header of the FST file
  // `filename`: the input symbols when `input_symbols` is true, otherwise the
  // output symbols. Returns a newly allocated table owned by the caller. On
  // failure it logs an error and returns nullptr; failures are: the file
  // cannot be opened, the header cannot be read, the requested table is
  // absent, or a symbol table fails to read.
  SymbolTable *FstReadSymbols(const string &filename, bool input_symbols) {
    std::ifstream in(filename, std::ios_base::in | std::ios_base::binary);
    if (!in) {
      LOG(ERROR) << "FstReadSymbols: Can't open file " << filename;
      return nullptr;
    }
    FstHeader hdr;
    if (!hdr.Read(in, filename)) {
      LOG(ERROR) << "FstReadSymbols: Couldn't read header from " << filename;
      return nullptr;
    }
    const bool has_isymbols = (hdr.GetFlags() & FstHeader::HAS_ISYMBOLS) != 0;
    const bool has_osymbols = (hdr.GetFlags() & FstHeader::HAS_OSYMBOLS) != 0;
    // Bail out before parsing anything if the requested table is absent. This
    // avoids reading (and possibly reporting errors for) a table nobody wants.
    if (!(input_symbols ? has_isymbols : has_osymbols)) {
      LOG(ERROR) << "FstReadSymbols: The file " << filename
                 << " doesn't contain the requested symbols";
      return nullptr;
    }
    // When present, the input table precedes the output table in the stream,
    // so it must be consumed even when only the output table is wanted.
    if (has_isymbols) {
      std::unique_ptr<SymbolTable> isymbols(SymbolTable::Read(in, filename));
      if (isymbols == nullptr) {
        LOG(ERROR) << "FstReadSymbols: Couldn't read input symbols from "
                   << filename;
        return nullptr;
      }
      if (input_symbols) return isymbols.release();
    }
    // Only reached when the output table was requested and is present.
    std::unique_ptr<SymbolTable> osymbols(SymbolTable::Read(in, filename));
    if (osymbols == nullptr) {
      LOG(ERROR) << "FstReadSymbols: Couldn't read output symbols from "
                 << filename;
      return nullptr;
    }
    return osymbols.release();
  }
  
  bool AddAuxiliarySymbols(const string &prefix, int64 start_label,
                           int64 nlabels, SymbolTable *syms) {
    for (int64 i = 0; i < nlabels; ++i) {
      auto index = i + start_label;
      if (index != syms->AddSymbol(prefix + std::to_string(i), index)) {
        FSTERROR() << "AddAuxiliarySymbols: Symbol table clash";
        return false;
      }
    }
    return true;
  }
  
  }  // namespace fst