Blame view

scripts/rnnlm/validate_text_dir.py 4.83 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
  #!/usr/bin/env python3
  
  # Copyright  2017  Jian Wang
  # License: Apache 2.0.
  
  import os
  import argparse
  import sys
  
  import re
  
  
  # Command-line interface.  Boolean-like options are passed as the strings
  # 'true' / 'false' rather than as store_true flags.
  parser = argparse.ArgumentParser(
      description=("Validates data directory containing text files from one "
                   "or more data sources, including dev.txt."),
      epilog="E.g. {0} data/rnnlm/data".format(sys.argv[0]),
      formatter_class=argparse.ArgumentDefaultsHelpFormatter)

  # Both switches share the same type, default and set of accepted values;
  # only the option name and help text differ.
  for option_name, option_help in (
          ("--spot-check",
           "If true, only do spot check on text files."),
          ("--allow-internal-eos",
           "If true, allow internal </s> in lines of the text.")):
      parser.add_argument(option_name, type=str, default='true',
                          choices=['true', 'false'],
                          help=option_help)

  parser.add_argument("text_dir",
                      help="Directory in which to look for text data")

  # Parsed once at import time; the rest of the script reads `args` directly.
  args = parser.parse_args()
  
  # The EOS symbol, and the other symbols that are never accepted as words in
  # the text files (see check_text_file).
  EOS_SYMBOL = '</s>'
  SPECIAL_SYMBOLS = ['<s>', '<brk>', '<eps>']

  # The data directory itself and its dev.txt must both be present; bail out
  # on the first one that is missing.
  for expected_kind, expected_path in (
          ("directory", args.text_dir),
          ("file", "{0}/dev.txt".format(args.text_dir))):
      if not os.path.exists(expected_path):
          sys.exit(sys.argv[0] + ": Expected {0} {1} to exist".format(
              expected_kind, expected_path))


  # Number of .txt files validated so far (dev.txt included).
  num_text_files = 0
  
  
  def check_text_file(text_file):
      """Validate a single text file from the data directory.

      Two passes are made over the file:

      1. Every examined line is checked for disallowed symbols (the entries
         of SPECIAL_SYMBOLS, plus EOS_SYMBOL when --allow-internal-eos is not
         'true') and for EOS_SYMBOL as the last word of the line.  Lines with
         1000 words or more only trigger a warning on stderr.  When
         --spot-check is 'true' only the first 10 lines are examined.  At
         least one examined line must be non-empty.
      2. The whole file is scanned to see whether the first field of every
         non-empty line is unique and never occurs in any other position; if
         so, a warning is printed on stderr because the lines may start with
         utterance-ids.

      Args:
          text_file: path of the text file to check; it is read as UTF-8.

      Raises:
          SystemExit: via sys.exit() with a diagnostic message, if a
              disallowed symbol or a line-final EOS symbol is found, or if no
              non-empty line was seen among the examined lines.
      """
      # Build the lookup set once instead of rebuilding a list per call site.
      disallowed_symbols = set(SPECIAL_SYMBOLS)
      if args.allow_internal_eos != 'true':
          disallowed_symbols.add(EOS_SYMBOL)
      spot_check = args.spot_check == 'true'

      found_nonempty_line = False
      with open(text_file, 'r', encoding="utf-8") as f:
          for lineno, line in enumerate(f, start=1):
              if spot_check and lineno > 10:
                  break
              # Drop the line terminator so that error messages quote the
              # line without a trailing newline.
              line = line.strip("\n")
              words = line.split()
              if not words:
                  continue
              found_nonempty_line = True
              for word in words:
                  if word in disallowed_symbols:
                      sys.exit(sys.argv[0] + ": Found suspicious line '{0}' in file {1} at {2} ({3} "
                               " symbol is disallowed!)".format(line, text_file, lineno, word))
              if words[-1] == EOS_SYMBOL:
                  sys.exit(sys.argv[0] + ": Found suspicious line '{0}' in file {1} at {2} (EOS symbol "
                           "at the end of a line is disallowed!)".format(line, text_file, lineno))
              if len(words) >= 1000:
                  # Only a warning: very long lines are accepted.
                  print(sys.argv[0] + ": Too long line with {0} words in file "
                        "{1} at {2}".format(len(words), text_file, lineno), file=sys.stderr)
      if not found_nonempty_line:
          sys.exit(sys.argv[0] + ": Input file {0} doesn't look right.".format(text_file))

      # Next we're going to check that it's not the case
      # that the first and second fields have disjoint words on them, and the
      # first field is always unique, which would be the case if the lines started
      # with some kind of utterance-id
      first_field_set = set()
      other_fields_set = set()
      with open(text_file, 'r', encoding="utf-8") as f:
          for line in f:
              fields = line.split()
              if not fields:
                  continue
              first_word = fields[0]
              if first_word in first_field_set or first_word in other_fields_set:
                  # the first field isn't always unique, or is shared with other
                  # fields.
                  return
              first_field_set.add(first_word)
              for other_word in fields[1:]:
                  if other_word in first_field_set:
                      # the first field has a value shared by some word not in the
                      # first position.
                      return
                  other_fields_set.add(other_word)
      print(sys.argv[0] + ": input file {0} looks suspicious; check that you "
            "don't have utterance-ids in the first field (i.e. you shouldn't provide "
            "lines that look like 'utterance-id1 hello there').  Ignore this warning "
            "if you don't have that problem.".format(text_file), file=sys.stderr)
  
  
  # Walk the data directory: sub-directories and *.counts files are ignored,
  # every other entry must be a regular *.txt file, which is then validated.
  for entry in os.listdir(args.text_dir):
      full_path = "{0}/{1}".format(args.text_dir, entry)
      if entry.endswith(".counts") or os.path.isdir(full_path):
          continue
      if not entry.endswith(".txt"):
          sys.exit(sys.argv[0] + ": Text directory should not contain files with suffixes "
                   "other than .txt and .counts: " + entry)
      if not os.path.isfile(full_path):
          sys.exit(sys.argv[0] + ": Expected {0} to be a file.".format(full_path))
      check_text_file(full_path)
      num_text_files += 1

  # dev.txt is counted too, so fewer than two files means there is no
  # .txt file besides dev.txt.
  if num_text_files <= 1:
      sys.exit(sys.argv[0] + ": Directory {0} should contain at least one .txt file "
               "other than dev.txt.".format(args.text_dir))