validate_text_dir.py
4.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python3
# Copyright 2017 Jian Wang
# License: Apache 2.0.
import os
import argparse
import sys
import re
# Command-line interface: two string-valued true/false options plus the
# positional directory that will be validated.
parser = argparse.ArgumentParser(
    description=("Validates data directory containing text files from one "
                 "or more data sources, including dev.txt."),
    epilog="E.g. {0} data/rnnlm/data".format(sys.argv[0]),
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# Both options accept only the literal strings 'true' or 'false' and default
# to 'true'; they are declared in the same order as they appear in --help.
for option_name, option_help in (
        ("--spot-check",
         "If true, only do spot check on text files."),
        ("--allow-internal-eos",
         "If true, allow internal </s> in lines of the text."),
):
    parser.add_argument(option_name, type=str, default='true',
                        choices=['true', 'false'], help=option_help)

parser.add_argument("text_dir",
                    help="Directory in which to look for text data")

args = parser.parse_args()
# The EOS symbol is never allowed at the end of a line; whether it may appear
# inside a line is controlled by --allow-internal-eos.
EOS_SYMBOL = '</s>'
# Symbols that must never appear as a word in any text file.
SPECIAL_SYMBOLS = ['<s>', '<brk>', '<eps>']

# Fail early if the layout is wrong.  isdir/isfile (rather than exists) are
# used so that a regular file passed as the directory, or a directory that
# happens to be named dev.txt, is rejected here instead of slipping through.
if not os.path.isdir(args.text_dir):
    sys.exit(sys.argv[0] + ": Expected directory {0} to exist".format(args.text_dir))

if not os.path.isfile("{0}/dev.txt".format(args.text_dir)):
    sys.exit(sys.argv[0] + ": Expected file {0}/dev.txt to exist".format(args.text_dir))

# Number of .txt files that have been checked so far (dev.txt included).
num_text_files = 0
def check_text_file(text_file: str) -> None:
    """Validate one text file, exiting on errors and warning on oddities.

    The file is read twice as UTF-8:

    1. A content pass (limited to the first 10 lines when ``--spot-check``
       is 'true') that rejects disallowed symbols, a trailing EOS symbol,
       and files without any non-empty line, and warns about lines with
       1000 or more words.
    2. A pass over the whole file that warns when the first field of every
       non-empty line is unique and never occurs in any other position,
       which is what lines prefixed with an utterance-id would look like.

    Reads the module-level ``args``, ``SPECIAL_SYMBOLS`` and ``EOS_SYMBOL``.

    Args:
        text_file: Path of the text file to check.

    Raises:
        SystemExit: If a checked line contains a disallowed symbol or ends
            with the EOS symbol, or if no non-empty line was found.
    """
    _check_line_contents(text_file)

    # Next we check that it's not the case that the first and other fields
    # have disjoint words on them, and the first field is always unique,
    # which would be the case if the lines started with some kind of
    # utterance-id.
    if _first_field_looks_like_utterance_id(text_file):
        print(sys.argv[0] + ": input file {0} looks suspicious; check that you "
              "don't have utterance-ids in the first field (i.e. you shouldn't provide "
              "lines that look like 'utterance-id1 hello there'). Ignore this warning "
              "if you don't have that problem.".format(text_file), file=sys.stderr)


def _check_line_contents(text_file: str) -> None:
    """Exit if the checked lines of text_file hold disallowed content."""
    if args.allow_internal_eos == 'true':
        disallowed_symbols = SPECIAL_SYMBOLS
    else:
        disallowed_symbols = SPECIAL_SYMBOLS + [EOS_SYMBOL]
    spot_check = args.spot_check == 'true'

    found_nonempty_line = False
    with open(text_file, 'r', encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            # In spot-check mode only the first 10 lines are examined.
            if spot_check and lineno > 10:
                break
            line = line.strip("\n")
            words = line.split()
            if not words:
                continue
            found_nonempty_line = True
            for word in words:
                if word in disallowed_symbols:
                    sys.exit(sys.argv[0] + ": Found suspicious line '{0}' in file {1} at {2} ({3} "
                             "symbol is disallowed!)".format(line, text_file, lineno, word))
            # A trailing EOS is rejected even when internal EOS is allowed.
            if words[-1] == EOS_SYMBOL:
                sys.exit(sys.argv[0] + ": Found suspicious line '{0}' in file {1} at {2} (EOS symbol "
                         "at the end of a line is disallowed!)".format(line, text_file, lineno))
            # Very long lines are only reported, not treated as an error.
            if len(words) >= 1000:
                print(sys.argv[0] + ": Too long line with {0} words in file "
                      "{1} at {2}".format(len(words), text_file, lineno), file=sys.stderr)

    if not found_nonempty_line:
        sys.exit(sys.argv[0] + ": Input file {0} doesn't look right.".format(text_file))


def _first_field_looks_like_utterance_id(text_file: str) -> bool:
    """Return True if every first field is unique and appears nowhere else."""
    first_field_set = set()
    other_fields_set = set()
    with open(text_file, 'r', encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            first_word = fields[0]
            if first_word in first_field_set or first_word in other_fields_set:
                # The first field isn't always unique, or is shared with
                # other fields.
                return False
            first_field_set.add(first_word)
            for other_word in fields[1:]:
                if other_word in first_field_set:
                    # The first field has a value shared by some word not
                    # in the first position.
                    return False
                other_fields_set.add(other_word)
    return True
# Scan the text directory: sub-directories and *.counts files are ignored,
# any other non-.txt entry is an error, and every .txt file is validated
# and counted.
for entry_name in os.listdir(args.text_dir):
    entry_path = "{0}/{1}".format(args.text_dir, entry_name)

    if entry_name.endswith(".counts") or os.path.isdir(entry_path):
        continue

    if not entry_name.endswith(".txt"):
        sys.exit(sys.argv[0] + ": Text directory should not contain files with "
                 "suffixes other than .txt and .counts: " + entry_name)

    if not os.path.isfile(entry_path):
        sys.exit(sys.argv[0] + ": Expected {0} to be a file.".format(entry_path))

    check_text_file(entry_path)
    num_text_files += 1

# Fewer than two validated .txt files is reported as having nothing besides
# dev.txt.
if num_text_files < 2:
    sys.exit(sys.argv[0] + ": Directory {0} should contain at least one "
             ".txt file other than dev.txt.".format(args.text_dir))