Blame view

scripts/rnnlm/validate_features.py 3.06 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
  #!/usr/bin/env python3
  
  # Copyright  2017  Jian Wang
  # License: Apache 2.0.
  
  import os
  import argparse
  import sys
  
  import re
  
  
  # Command-line interface: one positional argument naming the features file
  # that the rest of this script validates.
  _epilog = "E.g. " + sys.argv[0] + " exp/rnnlm/features.txt"
  parser = argparse.ArgumentParser(
      description="Validates features file, produced by rnnlm/choose_features.py.",
      epilog=_epilog,
      formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  )
  parser.add_argument("features_file", help="File containing features")

  # Parsed eagerly at module level; `args.features_file` is read below.
  args = parser.parse_args()
  
  # End-of-sentence symbol.  Kept as a module-level constant; the validation
  # code in this script does not reference it.
  EOS_SYMBOL = '</s>'


  def _die(message):
      """Abort the script, printing `message` prefixed with the script name.

      Passing a string to sys.exit() writes it to stderr and exits with a
      non-zero status.
      """
      sys.exit(sys.argv[0] + ": " + message)


  if not os.path.isfile(args.features_file):
      _die("Expected file {0} to exist".format(args.features_file))

  # Validate the features file one line at a time.  Each line is split on
  # whitespace and must have 3, 4 or 5 fields:
  #   fields[0]   feature index; must equal the zero-based line number
  #   fields[1]   feature type: length, constant, special, unigram, word,
  #               initial, final or match
  #   fields[-1]  scale; must satisfy 0 < scale <= 1
  # All checks use explicit tests rather than `assert`, so they still run under
  # `python -O` and every failure reports the offending line.
  with open(args.features_file, 'r', encoding="utf-8") as f:
      has_unigram = False
      has_length = False
      # Values of fields[2] already seen, for each feature type whose entries
      # must not be repeated.
      seen_feats = {"word": set(), "initial": set(), "final": set(), "match": set()}
      for expected_idx, line in enumerate(f):
          fields = line.split()
          if len(fields) not in (3, 4, 5):
              _die("bad line (expected 3, 4 or 5 fields): {0}".format(line))

          # Feature indexes must be consecutive integers starting from zero.
          try:
              index_ok = int(fields[0]) == expected_idx
          except ValueError:
              index_ok = False
          if not index_ok:
              _die("expected feature index {0} but got '{1}': {2}".format(
                  expected_idx, fields[0], line))

          # every feature should contain a scale
          try:
              scale_ok = 0.0 < float(fields[-1]) <= 1.0
          except ValueError:
              scale_ok = False
          if not scale_ok:
              _die("scale should be in the range (0, 1]: {0}".format(line))

          kind = fields[1]
          if kind == "length" and len(fields) == 3:
              if has_length:
                  _die("Too many 'length' features")
              has_length = True
          elif kind == "constant":
              # Exactly three fields, so the constant's value is fields[2]
              # (the same token as the scale checked above).
              try:
                  constant_ok = len(fields) == 3 and float(fields[2]) > 0.0
              except ValueError:
                  constant_ok = False
              if not constant_ok:
                  _die("bad line: {0}".format(line))
          elif kind == "special":
              if len(fields) != 4:
                  _die("bad line: {0}".format(line))
          elif kind == "unigram":
              try:
                  log_unigram_ppl = float(fields[2])
              except ValueError:
                  _die("log-unigram-ppl should be a positive value: {0}".format(fields[2]))
              if log_unigram_ppl <= 0.0:
                  _die("log-unigram-ppl should be a positive value: {0}".format(fields[2]))
              if has_unigram:
                  _die("Too many 'unigram' features")
              has_unigram = True
          elif kind in seen_feats:
              # word / initial / final / match: fields[2] must be unique
              # within its own feature type.
              feat = fields[2]
              if feat in seen_feats[kind]:
                  _die("duplicated {0} feature: {1}".format(kind, feat))
              seen_feats[kind].add(feat)
          else:
              # Unknown feature type, or a 'length' line without exactly
              # three fields.
              _die("Error line format: {0}".format(line))