// symbol-table.h
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Classes to provide symbol-to-integer and integer-to-symbol mappings.

#ifndef FST_SYMBOL_TABLE_H_
#define FST_SYMBOL_TABLE_H_

#include <cstring>
#include <functional>
#include <ios>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <fst/compat.h>
#include <fst/flags.h>
#include <fst/log.h>
#include <fstream>
#include <map>

// Declares (does not define) the boolean flag fst_compat_symbols so that code
// including this header can read it; the matching DEFINE_bool presumably lives
// in a .cc file. NOTE(review): likely consulted by CompatSymbols() — confirm.
DECLARE_bool(fst_compat_symbols);

namespace fst {

// Sentinel key returned by lookups such as Find(const string &) and
// GetNthKey() when no matching symbol exists.
constexpr int64 kNoSymbol{-1};

// Options for SymbolTable::Read().
//
// WARNING: Reading via symbol table read options should not be used. This is
// a temporary work around for reading symbol ranges of previously stored
// symbol sets.
struct SymbolTableReadOptions {
  SymbolTableReadOptions() = default;

  SymbolTableReadOptions(
      std::vector<std::pair<int64, int64>> string_hash_ranges,
      const string &source)
      : string_hash_ranges(std::move(string_hash_ranges)),
        source(source) {}

  // Ranges of string hashes to read. NOTE(review): the meaning of each pair's
  // two values is defined by the reader, not visible here — confirm.
  std::vector<std::pair<int64, int64>> string_hash_ranges;

  // Name of the input being read; SymbolTable::Read(filename) sets this to
  // the filename.
  string source;
};

// Options for the text format handled by SymbolTable::ReadText() and
// SymbolTable::WriteText().
struct SymbolTableTextOptions {
  // Defined out of line. NOTE(review): presumably also initializes
  // fst_field_separator (e.g. from a flag) — confirm in the .cc file.
  explicit SymbolTableTextOptions(bool allow_negative_labels = false);

  // Whether negative labels are accepted; presumably taken from the
  // constructor argument, which defaults to false.
  bool allow_negative_labels;
  // NOTE(review): judging by its name, the character(s) separating fields in
  // the text format — confirm against the reader/writer.
  string fst_field_separator;
};

namespace internal {

// List of symbols with a dense hash for looking up symbol index.
// The hash uses linear probing and rehashes at a load factor of 0.75 (75%
// occupancy), with an average overhead of about 6 bytes per entry. Rehashing
// is done in place from the symbol list.
//
// Symbols are stored as C strings to avoid adding memory overhead, but the
// performance penalty for this is high because rehash must call strlen on
// every symbol. AddSymbol could be another 2x faster if symbol lengths were
// stored.
class DenseSymbolMap {
 public:
  DenseSymbolMap();

  DenseSymbolMap(const DenseSymbolMap &x);

  // DenseSymbolMap declares its own copy constructor and destructor,
  // apparently to manage the C strings held in symbols_ (see NewSymbol()).
  // The implicitly generated assignment would copy only the raw pointers
  // (presumably causing a double free), so it is disabled.
  DenseSymbolMap &operator=(const DenseSymbolMap &) = delete;

  ~DenseSymbolMap();

  // NOTE(review): presumably returns (index of key, whether key was newly
  // inserted) — confirm against the definition.
  std::pair<int64, bool> InsertOrFind(const string &key);

  // Returns the index of `key`, or kNoSymbol if it is absent
  // (SymbolTableImpl::Find checks the result against kNoSymbol).
  int64 Find(const string &key) const;

  // Number of stored symbols.
  size_t size() const { return symbols_.size(); }

  // Returns a copy of the symbol stored at `idx`. Requires idx < size(); the
  // index is not checked.
  string GetSymbol(size_t idx) const { return string(symbols_[idx]); }

  // Removes the symbol stored at `idx`.
  void RemoveSymbol(size_t idx);

 private:
  // num_buckets must be power of 2.
  void Rehash(size_t num_buckets);

  const char *NewSymbol(const string &sym);

  int64 empty_;
  std::vector<const char *> symbols_;
  std::hash<string> str_hash_;
  std::vector<int64> buckets_;
  uint64 hash_mask_;
};

class SymbolTableImpl {
 public:
  explicit SymbolTableImpl(const string &name)
      : name_(name),
        available_key_(0),
        dense_key_limit_(0),
        check_sum_finalized_(false) {}

  SymbolTableImpl(const SymbolTableImpl &impl)
      : name_(impl.name_),
        available_key_(impl.available_key_),
        dense_key_limit_(impl.dense_key_limit_),
        symbols_(impl.symbols_),
        idx_key_(impl.idx_key_),
        key_map_(impl.key_map_),
        check_sum_finalized_(false) {}

  int64 AddSymbol(const string &symbol, int64 key);

  int64 AddSymbol(const string &symbol) {
    return AddSymbol(symbol, available_key_);
  }

  // Removes the symbol with the given key. The removal is costly
  // (O(NumSymbols)) and may reduce the efficiency of Find() because of a
  // potentially reduced size of the dense key interval.
  void RemoveSymbol(int64 key);

  static SymbolTableImpl *ReadText(
      std::istream &strm, const string &name,
      const SymbolTableTextOptions &opts = SymbolTableTextOptions());

  static SymbolTableImpl* Read(std::istream &strm,
                               const SymbolTableReadOptions &opts);

  bool Write(std::ostream &strm) const;

  // Return the string associated with the key. If the key is out of
  // range (<0, >max), return an empty string.
  string Find(int64 key) const {
    int64 idx = key;
    if (key < 0 || key >= dense_key_limit_) {
      const auto it = key_map_.find(key);
      if (it == key_map_.end()) return "";
      idx = it->second;
    }
    if (idx < 0 || idx >= symbols_.size()) return "";
    return symbols_.GetSymbol(idx);
  }

  // Returns the key associated with the symbol; if the symbol
  // does not exists, returns kNoSymbol.
  int64 Find(const string &symbol) const {
    int64 idx = symbols_.Find(symbol);
    if (idx == kNoSymbol || idx < dense_key_limit_) return idx;
    return idx_key_[idx - dense_key_limit_];
  }

  bool Member(int64 key) const { return !Find(key).empty(); }

  bool Member(const string &symbol) const { return Find(symbol) != kNoSymbol; }

  int64 GetNthKey(ssize_t pos) const {
    if (pos < 0 || pos >= symbols_.size()) return kNoSymbol;
    if (pos < dense_key_limit_) return pos;
    return Find(symbols_.GetSymbol(pos));
  }

  const string &Name() const { return name_; }

  void SetName(const string &new_name) { name_ = new_name; }

  const string &CheckSum() const {
    MaybeRecomputeCheckSum();
    return check_sum_string_;
  }

  const string &LabeledCheckSum() const {
    MaybeRecomputeCheckSum();
    return labeled_check_sum_string_;
  }

  int64 AvailableKey() const { return available_key_; }

  size_t NumSymbols() const { return symbols_.size(); }

 private:
  // Recomputes the checksums (both of them) if we've had changes since the last
  // computation (i.e., if check_sum_finalized_ is false).
  // Takes ~2.5 microseconds (dbg) or ~230 nanoseconds (opt) on a 2.67GHz Xeon
  // if the checksum is up-to-date (requiring no recomputation).
  void MaybeRecomputeCheckSum() const;

  string name_;
  int64 available_key_;
  int64 dense_key_limit_;

  DenseSymbolMap symbols_;
  // Maps index to key for index >= dense_key_limit:
  //   key = idx_key_[index - dense_key_limit]
  std::vector<int64> idx_key_;
  // Maps key to index for key >= dense_key_limit_.
  //  index = key_map_[key]
  map<int64, int64> key_map_;

  mutable bool check_sum_finalized_;
  mutable string check_sum_string_;
  mutable string labeled_check_sum_string_;
  mutable Mutex check_sum_mutex_;
};

}  // namespace internal

// Symbol (string) to integer (and reverse) mapping.
//
// The SymbolTable implements the mappings of labels to strings and reverse.
// SymbolTables are used to describe the alphabet of the input and output
// labels for arcs in a Finite State Transducer.
//
// SymbolTables are reference-counted and can therefore be shared across
// multiple machines (FSTs). For example, a language model grammar G with a
// SymbolTable for the words in the language model can share this symbol
// table with the lexical representation L o G.
//
// Copies share one implementation object. Every mutating method first calls
// MutateCheck(), which gives this table a private implementation if the
// current one is shared (copy-on-write).
class SymbolTable {
 public:
  // Constructs symbol table with an optional name.
  explicit SymbolTable(const string &name = "<unspecified>")
      : impl_(std::make_shared<internal::SymbolTableImpl>(name)) {}

  virtual ~SymbolTable() = default;

  // Reads a text representation of the symbol table from an istream. Pass a
  // name to give the resulting SymbolTable. Returns nullptr if reading fails.
  static SymbolTable *ReadText(
      std::istream &strm, const string &name,
      const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
    auto *impl = internal::SymbolTableImpl::ReadText(strm, name, opts);
    return impl ? new SymbolTable(impl) : nullptr;
  }

  // Reads a text representation of the symbol table from the named file; the
  // filename also becomes the table's name. Returns nullptr on failure.
  static SymbolTable *ReadText(
      const string &filename,
      const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
    std::ifstream strm(filename, std::ios_base::in);
    if (!strm.good()) {
      LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename;
      return nullptr;
    }
    return ReadText(strm, filename, opts);
  }

  // WARNING: Reading via symbol table read options should not be used. This is
  // a temporary work-around.
  static SymbolTable *Read(std::istream &strm,
                           const SymbolTableReadOptions &opts) {
    auto *impl = internal::SymbolTableImpl::Read(strm, opts);
    return impl ? new SymbolTable(impl) : nullptr;
  }

  // Reads a binary dump of the symbol table from a stream; `source` names the
  // input (e.g. a filename). Returns nullptr on failure.
  static SymbolTable *Read(std::istream &strm, const string &source) {
    SymbolTableReadOptions opts;
    opts.source = source;
    return Read(strm, opts);
  }

  // Reads a binary dump of the symbol table from the named file. Returns
  // nullptr on failure.
  static SymbolTable *Read(const string &filename) {
    std::ifstream strm(filename, std::ios_base::in | std::ios_base::binary);
    if (!strm.good()) {
      LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
      return nullptr;
    }
    return Read(strm, filename);
  }

  //--------------------------------------------------------
  // Derivable Interface (final)
  //--------------------------------------------------------
  // Creates a reference counted copy.
  virtual SymbolTable *Copy() const { return new SymbolTable(*this); }

  // Adds a symbol with given key to table. The table also keeps track of the
  // available key; see AvailableKey().
  virtual int64 AddSymbol(const string &symbol, int64 key) {
    MutateCheck();
    return impl_->AddSymbol(symbol, key);
  }

  // Adds a symbol to the table. The associated value key is automatically
  // assigned by the symbol table.
  virtual int64 AddSymbol(const string &symbol) {
    MutateCheck();
    return impl_->AddSymbol(symbol);
  }

  // Adds another symbol table to this table. All key values will be offset
  // by the current available key. Note string symbols with the same key value
  // will still have the same key value after the symbol table has been
  // merged, but a different value. Adding a symbol table does not change the
  // table being added.
  // NOTE(review): the key-offset description could not be verified from this
  // header — confirm against the definition.
  virtual void AddTable(const SymbolTable &table);

  // Removes the symbol with the given key.
  virtual void RemoveSymbol(int64 key) {
    MutateCheck();
    impl_->RemoveSymbol(key);
  }

  // Returns the name of the symbol table.
  virtual const string &Name() const { return impl_->Name(); }

  // Sets the name of the symbol table.
  virtual void SetName(const string &new_name) {
    MutateCheck();
    impl_->SetName(new_name);
  }

  // Return the label-agnostic MD5 check-sum for this table. All new symbols
  // added to the table will result in an updated checksum. Deprecated.
  virtual const string &CheckSum() const { return impl_->CheckSum(); }

  // Same as CheckSum(), but returns a label-dependent version.
  virtual const string &LabeledCheckSum() const {
    return impl_->LabeledCheckSum();
  }

  // Writes a binary dump of the symbol table to a stream.
  virtual bool Write(std::ostream &strm) const { return impl_->Write(strm); }

  // Writes a binary dump of the symbol table to the named file.
  bool Write(const string &filename) const {
    std::ofstream strm(filename, std::ios_base::out | std::ios_base::binary);
    if (!strm.good()) {
      LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
      return false;
    }
    return Write(strm);
  }

  // Dumps a text representation of the symbol table to a stream.
  virtual bool WriteText(std::ostream &strm,
      const SymbolTableTextOptions &opts = SymbolTableTextOptions()) const;

  // Dumps a text representation of the symbol table to the named file.
  bool WriteText(const string &filename) const {
    std::ofstream strm(filename);
    if (!strm.good()) {
      LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
      return false;
    }
    return WriteText(strm);
  }

  // Returns the string associated with the key; if the key is not in the
  // table, returns an empty string.
  virtual string Find(int64 key) const { return impl_->Find(key); }

  // Returns the key associated with the symbol; if the symbol does not exist,
  // kNoSymbol is returned.
  virtual int64 Find(const string &symbol) const { return impl_->Find(symbol); }

  // Returns the key associated with the symbol; if the symbol does not exist
  // (or `symbol` is null), kNoSymbol is returned.
  virtual int64 Find(const char *symbol) const {
    // Building a std::string from a null pointer is undefined behavior.
    if (symbol == nullptr) return kNoSymbol;
    return impl_->Find(symbol);
  }

  virtual bool Member(int64 key) const { return impl_->Member(key); }

  virtual bool Member(const string &symbol) const {
    return impl_->Member(symbol);
  }

  // Returns the current available key (i.e., highest key + 1) in the symbol
  // table.
  virtual int64 AvailableKey() const { return impl_->AvailableKey(); }

  // Returns the current number of symbols in table (not necessarily equal to
  // AvailableKey()).
  virtual size_t NumSymbols() const { return impl_->NumSymbols(); }

  // Returns the key of the symbol at position `pos`, or kNoSymbol if `pos` is
  // out of range.
  virtual int64 GetNthKey(ssize_t pos) const { return impl_->GetNthKey(pos); }

 private:
  // Takes ownership of `impl`.
  explicit SymbolTable(internal::SymbolTableImpl *impl) : impl_(impl) {}

  // Gives this table its own implementation if the current one is shared with
  // other SymbolTable objects, so that a mutation does not affect them.
  void MutateCheck() {
    // shared_ptr::unique() is deprecated in C++17 and removed in C++20;
    // use_count() != 1 is the equivalent test.
    if (impl_.use_count() != 1) {
      impl_ = std::make_shared<internal::SymbolTableImpl>(*impl_);
    }
  }

  const internal::SymbolTableImpl *Impl() const { return impl_.get(); }

  std::shared_ptr<internal::SymbolTableImpl> impl_;
};

// Iterator class for symbols in a symbol table, visiting positions
// 0 .. NumSymbols() - 1 via SymbolTable::GetNthKey().
//
// The iterator holds a reference to `table`, which must outlive it.
// NumSymbols() is read once at construction, so symbols added or removed
// afterwards are not reflected in Done().
class SymbolTableIterator {
 public:
  explicit SymbolTableIterator(const SymbolTable &table)
      : table_(table),
        pos_(0),
        nsymbols_(table.NumSymbols()),
        key_(table.GetNthKey(0)) {}

  ~SymbolTableIterator() = default;

  // Returns whether iterator is done. Uses >= so that extra calls to Next()
  // past the end leave the iterator done. pos_ never goes negative, so the
  // cast is safe.
  bool Done() const { return static_cast<size_t>(pos_) >= nsymbols_; }

  // Return the key of the current symbol.
  int64 Value() const { return key_; }

  // Return the string of the current symbol.
  string Symbol() const { return table_.Find(key_); }

  // Advances iterator.
  void Next() {
    ++pos_;
    if (static_cast<size_t>(pos_) < nsymbols_) key_ = table_.GetNthKey(pos_);
  }

  // Resets iterator to the first position (nsymbols_ is not re-read).
  void Reset() {
    pos_ = 0;
    key_ = table_.GetNthKey(0);
  }

 private:
  const SymbolTable &table_;
  ssize_t pos_;
  size_t nsymbols_;
  int64 key_;
};

// Relabels a symbol table as specified by the input vector of pairs
// (old label, new label). The new symbol table only retains symbols
// for which a relabeling is explicitly specified: for each pair, the string
// table->Find(old label) is added under key new label. The returned table is
// heap-allocated and owned by the caller.
// NOTE(review): an old label absent from `table` contributes the empty string,
// since Find(int64) returns "" for unknown keys.
//
// TODO(allauzen): consider adding options to allow for some form of implicit
// identity relabeling.
template <class Label>
SymbolTable *RelabelSymbolTable(const SymbolTable *table,
    const std::vector<std::pair<Label, Label>> &pairs) {
  // An unnamed source yields an unnamed result; otherwise the result is named
  // "relabeled_<source name>".
  string relabeled_name;
  if (!table->Name().empty()) relabeled_name = "relabeled_" + table->Name();
  auto *relabeled = new SymbolTable(relabeled_name);
  for (const auto &relabel : pairs) {
    const string symbol = table->Find(relabel.first);
    relabeled->AddSymbol(symbol, relabel.second);
  }
  return relabeled;
}

// Returns true if the two symbol tables have equal checksums. Passing in
// nullptr for either table always returns true.
// NOTE(review): `warning` presumably controls whether a mismatch is logged,
// and the fst_compat_symbols flag declared at the top of this file presumably
// gates the comparison — confirm against the definition.
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
                   bool warning = true);

// Symbol Table serialization.

// Serializes `table` into `*result`. NOTE(review): presumably the binary
// format produced by SymbolTable::Write() — confirm.
void SymbolTableToString(const SymbolTable *table, string *result);

// Deserializes a symbol table from `str`, presumably the inverse of
// SymbolTableToString(). NOTE(review): presumably returns a newly allocated
// table owned by the caller, or nullptr on failure — confirm.
SymbolTable *StringToSymbolTable(const string &str);

}  // namespace fst

#endif  // FST_SYMBOL_TABLE_H_