Blame view

scripts/rnnlm/get_special_symbol_opts.py 2.18 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
  #!/usr/bin/env python3
  
  # Copyright  2017  Jian Wang
  # License: Apache 2.0.
  
  import io
  import os
  import argparse
  import sys
  
  import re
  
  
  # Command-line interface: the script takes no options of its own (argparse
  # still provides --help); it reads words.txt on stdin and writes the options
  # to stdout.
  parser = argparse.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter,
      description=("This script checks whether the special symbols appear in "
                   "words.txt with expected values, if not, it will print out "
                   "the options with correct value to stdout, which may look "
                   "like '--bos-symbol=14312 --eos-symbol=14313 "
                   "--brk-symbol=14320'."),
      epilog=("E.g. {0} < exp/rnnlm/config/words.txt > "
              "exp/rnnlm/special_symbol_opts.txt").format(sys.argv[0]))

  args = parser.parse_args()
  
  # Special symbol -> (expected integer id, option used to override that id).
  # Insertion order determines the order in which options are emitted.
  special_symbols = {
      '<s>': (1, '--bos-symbol'),
      '</s>': (2, '--eos-symbol'),
      '<brk>': (3, '--brk-symbol'),
  }
  # Uppercase spellings ('<S>', '</S>', '<BRK>') are also recognised on input.
  upper_special_symbols = list(map(str.upper, special_symbols))
  
  # Read "<word> <integer-id>" lines from stdin (decoded as UTF-8) and record
  # the id of every special symbol, keyed by its lowercase form, separately for
  # the lowercase spelling (lower_ids) and the uppercase spelling (upper_ids).
  # Malformed or duplicate entries raise ValueError explicitly; `assert` is not
  # used because asserts are stripped when Python runs with -O.
  lower_ids = {}
  upper_ids = {}
  input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
  for line_number, line in enumerate(input_stream, start=1):
      fields = line.split()
      if len(fields) != 2:
          raise ValueError("line {0}: expected '<word> <id>', got {1!r}".format(
              line_number, line.rstrip('\n')))
      sym = fields[0]
      if sym in special_symbols:
          if sym in lower_ids:
              raise ValueError("line {0}: duplicate entry for {1}".format(
                  line_number, sym))
          lower_ids[sym] = int(fields[1])
      elif sym in upper_special_symbols:
          key = sym.lower()
          if key in upper_ids:
              raise ValueError("line {0}: duplicate entry for {1}".format(
                  line_number, sym))
          upper_ids[key] = int(fields[1])
  
  # Resolve the id each special symbol actually has. The lowercase spelling
  # takes precedence and the uppercase spelling is only a fallback; a warning
  # goes to stderr when both are present. Every symbol is resolved, and a
  # missing one reported, before anything is written to stdout. An error
  # therefore never leaves a partial option line in the output.
  resolved_ids = {}
  for sym in special_symbols:
      if sym in lower_ids:
          if sym in upper_ids:
              print(sys.argv[0] + ": both uppercase and lowercase are present for " + sym,
                    file=sys.stderr)
          resolved_ids[sym] = lower_ids[sym]
      elif sym in upper_ids:
          resolved_ids[sym] = upper_ids[sym]
      else:
          raise ValueError("Special symbol does not appear in the input: " + sym)

  # Emit an override option only for symbols whose id differs from the
  # expected one. Each option is followed by a space, and the terminating
  # newline is printed only if at least one option was emitted.
  opts = ['{0}={1} '.format(option, resolved_ids[sym])
          for sym, (expected_id, option) in special_symbols.items()
          if resolved_ids[sym] != expected_id]
  if opts:
      print(''.join(opts))