Blame view
scripts/rnnlm/get_vocab.py
1.97 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
#!/usr/bin/env python3

# Copyright 2017 Jian Wang
# License: Apache 2.0.

"""Build an RNNLM vocabulary (``word id`` lines) from unigram counts.

Reads every ``*.counts`` file directly inside DATA_DIR (as produced by
get_unigram_counts.sh) and sums the counts per word. It then writes one
``word id`` line per vocabulary entry to stdout:

* the special symbols come first with fixed ids:
  <eps>=0, <s>=1, </s>=2, <brk>=3;
* every other word follows, in order of decreasing total count.

A summary line is written to stderr.
"""

import os
import argparse
import sys
import re

eos_symbol = '</s>'
special_symbols = ['<s>', '<brk>', '<eps>']

# Entries that always open the vocabulary, with their fixed ids.
_FIXED_ENTRIES = [('<eps>', 0), ('<s>', 1), ('</s>', 2), ('<brk>', 3)]


def add_counts(word_counts, counts_file):
    """Add the count for every word in counts_file into word_counts.

    Args:
        word_counts: dict mapping word -> int count; updated in place.
        counts_file: path of a UTF-8 file with one ``<word> <count>`` pair
            per line. Blank lines are ignored.

    Raises:
        ValueError: if a non-blank line does not have exactly two fields,
            or its count is not an integer.
    """
    with open(counts_file, 'r', encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            if len(fields) != 2:
                # Raise rather than assert: asserts vanish under `python -O`.
                raise ValueError(
                    "{0}:{1}: expected '<word> <count>', got {2!r}".format(
                        counts_file, line_number, line.rstrip("\r\n")))
            word, count = fields
            word_counts[word] = word_counts.get(word, 0) + int(count)


def read_counts_dir(data_dir):
    """Return a dict of word -> summed count over all *.counts files in data_dir.

    Subdirectories are skipped, even if their name ends in ".counts".
    Files are processed in sorted name order. os.listdir order is arbitrary,
    and the order in which words are first seen decides how equal counts are
    ranked in vocab_entries().
    """
    word_counts = {}
    for name in sorted(os.listdir(data_dir)):
        full_path = os.path.join(data_dir, name)
        if os.path.isdir(full_path) or not name.endswith(".counts"):
            continue
        add_counts(word_counts, full_path)
    return word_counts


def vocab_entries(word_counts):
    """Return the vocabulary as a list of (word, id) pairs.

    The fixed special-symbol entries come first. Every other word then gets
    consecutive ids in order of decreasing count; ties keep the dict's
    insertion order because sorted() is stable. Reserved symbols found in
    word_counts are skipped, so no word appears twice with different ids.
    """
    entries = list(_FIXED_ENTRIES)
    reserved = set(special_symbols) | {eos_symbol}
    ranked = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)
    for word, _ in ranked:
        if word in reserved:
            continue
        entries.append((word, len(entries)))
    return entries


def main():
    """Parse the command line, build the vocabulary and print it."""
    # Force UTF-8 stdout whatever the locale; closefd=False leaves fd 1 open.
    sys.stdout = open(1, 'w', encoding='utf-8', closefd=False)

    parser = argparse.ArgumentParser(
        description="This script get a vocab from unigram counts "
                    "of words produced by get_unigram_counts.sh",
        epilog="E.g. " + sys.argv[0] + " data/rnnlm/data > data/rnnlm/vocab/words.txt",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("data_dir",
                        help="Directory in which to look for unigram counts.")
    args = parser.parse_args()

    word_counts = read_counts_dir(args.data_dir)
    if not word_counts:
        sys.exit(sys.argv[0] + ": Directory {0} should contain at least one .counts file "
                 .format(args.data_dir))

    entries = vocab_entries(word_counts)
    for word, idx in entries:
        print("{0} {1}".format(word, idx))

    # The reported total includes the special symbols.
    print(sys.argv[0] + ": vocab is generated with {0} words.".format(len(entries)),
          file=sys.stderr)


if __name__ == "__main__":
    main()