Blame view
scripts/rnnlm/get_special_symbol_opts.py
2.18 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
#!/usr/bin/env python3

# Copyright  2017  Jian Wang
# License: Apache 2.0.

"""Check the ids of the RNNLM special symbols listed in a words.txt file.

The symbol table is read from stdin, one "<symbol> <integer-id>" pair per
line.  For every special symbol whose id differs from the expected one, the
corresponding command-line option (e.g. '--bos-symbol=14312') is written to
stdout.  Nothing is written when all ids already have the expected values.
"""

import io
import os
import argparse
import sys
import re

# Maps each special symbol to its (expected_id, option_name).
SPECIAL_SYMBOLS = {'<s>': (1, '--bos-symbol'),
                   '</s>': (2, '--eos-symbol'),
                   '<brk>': (3, '--brk-symbol')}


def read_special_symbol_ids(lines):
    """Collect the ids of the special symbols found in a symbol table.

    Args:
        lines: iterable of text lines, each holding exactly two
            whitespace-separated fields: the symbol and its integer id.

    Returns:
        A pair (lower_ids, upper_ids) of dicts.  Both are keyed by the
        lowercase spelling of the symbol; lower_ids holds the ids of symbols
        that appeared in lowercase (e.g. '<s>'), upper_ids the ids of symbols
        that appeared in uppercase (e.g. '<S>').

    Raises:
        ValueError: if a line does not have exactly two fields, if a special
            symbol is listed twice with the same casing, or if the id of a
            special symbol is not an integer.
    """
    # Uppercase spelling -> canonical (lowercase) key of SPECIAL_SYMBOLS.
    upper_to_canonical = {sym.upper(): sym for sym in SPECIAL_SYMBOLS}
    lower_ids = {}
    upper_ids = {}
    for line_number, line in enumerate(lines, start=1):
        fields = line.split()
        # Explicit checks instead of 'assert', which is stripped under -O.
        if len(fields) != 2:
            raise ValueError(
                "Expected 2 fields on line {0} but got {1}: {2!r}".format(
                    line_number, len(fields), line))
        sym, sym_id = fields
        if sym in SPECIAL_SYMBOLS:
            seen, key = lower_ids, sym
        elif sym in upper_to_canonical:
            seen, key = upper_ids, upper_to_canonical[sym]
        else:
            # Ordinary word: its id is irrelevant here and is not validated.
            continue
        if key in seen:
            raise ValueError(
                "Special symbol {0} appears more than once "
                "(line {1})".format(sym, line_number))
        seen[key] = int(sym_id)
    return lower_ids, upper_ids


def get_special_symbol_opts(lower_ids, upper_ids):
    """Work out which special-symbol options have to be overridden.

    The lowercase spelling of a symbol takes precedence over the uppercase
    one when both are present.

    Args:
        lower_ids: dict from special symbol to the id of its lowercase form.
        upper_ids: dict from special symbol (lowercase key) to the id of its
            uppercase form.

    Returns:
        A pair (opts, both_cases).  opts is a list of 'option=id' strings for
        the symbols whose id differs from the expected one; both_cases lists
        the symbols that were present in both lowercase and uppercase.

    Raises:
        ValueError: if a special symbol is present in neither dict.
    """
    opts = []
    both_cases = []
    for sym, (expected_id, option_name) in SPECIAL_SYMBOLS.items():
        if sym in lower_ids:
            actual_id = lower_ids[sym]
            if sym in upper_ids:
                both_cases.append(sym)
        elif sym in upper_ids:
            actual_id = upper_ids[sym]
        else:
            raise ValueError("Special symbol is not appeared: " + sym)
        if actual_id != expected_id:
            opts.append('{0}={1}'.format(option_name, actual_id))
    return opts, both_cases


def main():
    """Read words.txt from stdin and print the required options to stdout."""
    parser = argparse.ArgumentParser(description="This script checks whether the special symbols "
                                     "appear in words.txt with expected values, if not, it will "
                                     "print out the options with correct value to stdout, which may look like "
                                     "'--bos-symbol=14312 --eos-symbol=14313 --brk-symbol=14320'.",
                                     epilog="E.g. " + sys.argv[0] + " < exp/rnnlm/config/words.txt > exp/rnnlm/special_symbol_opts.txt",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # No options are defined; this only provides --help and rejects stray
    # arguments.
    parser.parse_args()

    # Decode stdin as UTF-8 regardless of the locale.
    input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
    lower_ids, upper_ids = read_special_symbol_ids(input_stream)

    # Everything is validated before anything is printed, so a missing
    # symbol never leaves a partial option string on stdout.
    opts, both_cases = get_special_symbol_opts(lower_ids, upper_ids)

    for sym in both_cases:
        print(sys.argv[0] + ": both uppercase and lowercase are present for " + sym,
              file=sys.stderr)
    if opts:
        # Each option is followed by a single space, then one final newline.
        print(''.join(opt + ' ' for opt in opts))


if __name__ == "__main__":
    main()