get_vocab.py
#!/usr/bin/env python3
# Copyright 2017 Jian Wang
# License: Apache 2.0.

import os
import argparse
import sys

# Reopen stdout in UTF-8 mode so the vocab prints correctly regardless of locale.
sys.stdout = open(1, 'w', encoding='utf-8', closefd=False)

parser = argparse.ArgumentParser(description="This script gets a vocab from unigram counts "
                                             "of words produced by get_unigram_counts.sh",
                                 epilog="E.g. " + sys.argv[0] + " data/rnnlm/data > data/rnnlm/vocab/words.txt",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("data_dir",
                    help="Directory in which to look for unigram counts.")
args = parser.parse_args()

eos_symbol = '</s>'
special_symbols = ['<s>', '<brk>', '<eps>']
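# <eps> is the conventional OpenFst epsilon symbol, <s> marks sentence start,
# </s> (eos_symbol) marks sentence end, and <brk> appears to be the break
# symbol used elsewhere in the rnnlm scripts; all four get reserved ids below.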

# Accumulate the counts from counts_file into the word_counts dict
# (counts for a word that is already present are added to its existing total).
def add_counts(word_counts, counts_file):
    with open(counts_file, 'r', encoding="utf-8") as f:
        for line in f:
            word_and_count = line.split()
            assert len(word_and_count) == 2, \
                "bad line in {0}: {1}".format(counts_file, line)
            word, count = word_and_count
            word_counts[word] = word_counts.get(word, 0) + int(count)
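
# Each *.counts file is expected to contain whitespace-separated
# "<word> <count>" pairs, one per line, e.g. (hypothetical contents):
#
#   hello 3
#   world 2
#   </s> 5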
word_counts = {}
for f in os.listdir(args.data_dir):
    full_path = os.path.join(args.data_dir, f)
    if os.path.isdir(full_path):
        continue
    if f.endswith(".counts"):
        add_counts(word_counts, full_path)

if len(word_counts) == 0:
    sys.exit(sys.argv[0] + ": Directory {0} should contain at least one .counts file."
             .format(args.data_dir))
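
# The vocab is written to stdout in the words.txt format, "<word> <integer-id>"
# per line: the four special symbols take the fixed ids 0-3, and the remaining
# words are numbered from 4 onwards in order of decreasing count.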
print("<eps> 0")
print("<s> 1")
print("</s> 2")
print("<brk> 3")
idx = 4
for word, _ in sorted(word_counts.items(), key=lambda x: x[1], reverse=True):
if word == "</s>":
continue
print("{0} {1}".format(word, idx))
idx += 1
print(sys.argv[0] + ": vocab is generated with {0} words.".format(idx), file=sys.stderr)
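
# Minimal usage sketch (the directory and file names are hypothetical and just
# follow the pattern given in the epilog above):
#
#   printf 'hello 3\nworld 2\n</s> 5\n' > data/rnnlm/data/train.counts
#   rnnlm/get_vocab.py data/rnnlm/data > data/rnnlm/vocab/words.txt
#
# which would yield a words.txt along the lines of:
#
#   <eps> 0
#   <s> 1
#   </s> 2
#   <brk> 3
#   hello 4
#   world 5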