Blame view

scripts/rnnlm/validate_word_features.py 3.37 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
  #!/usr/bin/env python3
  
  # Copyright  2017  Jian Wang
  # License: Apache 2.0.
  
  import os
  import argparse
  import sys
  
  import re
  
  
  # Command-line interface: a required --features-file option plus a single
  # positional argument naming the word-features file to validate.
  _usage_example = ("E.g. " + sys.argv[0] + " --features-file=exp/rnnlm/features.txt "
                    "exp/rnnlm/word_feats.txt")

  parser = argparse.ArgumentParser(
      description="Validates word features file, produced by rnnlm/get_word_features.py.",
      epilog=_usage_example,
      formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  )
  parser.add_argument(
      "--features-file",
      type=str,
      default='',
      required=True,
      help="File containing features",
  )
  parser.add_argument(
      "word_features_file",
      help="File containing word features",
  )

  # Parsed options; read below as args.features_file and args.word_features_file.
  args = parser.parse_args()
  
  # Scan the features file and record what the word-features check needs:
  # the ids of the 'special', 'constant', 'unigram' and 'length' features, the
  # value attached to the 'constant' feature, and the largest feat_id seen.
  # Each line must have 3, 4 or 5 whitespace-separated fields: the first is the
  # integer feat_id, the second the feature type, and the last the scale.
  special_feat_ids = []
  constant_feat_id = -1
  constant_feat_value = None
  unigram_feat_id = -1
  length_feat_id = -1
  max_feat_id = -1
  with open(args.features_file, 'r', encoding="utf-8") as f:
      for line in f:
          fields = line.split()
          # Explicit exits rather than asserts: asserts are removed when Python
          # runs with -O, which would silently disable the validation.
          if len(fields) not in (3, 4, 5):
              sys.exit(sys.argv[0] + ": Expected 3, 4 or 5 fields in features file: {0}".format(line))

          try:
              feat_id = int(fields[0])
          except ValueError:
              sys.exit(sys.argv[0] + ": feat_id should be an integer: {0}".format(line))

          # every feature should contain a scale
          try:
              scale = float(fields[-1])
          except ValueError:
              sys.exit(sys.argv[0] + ": Scale of feature should be a float number: {0}".format(line))
          # Written as a negated chained comparison so that NaN is rejected too.
          if not (0.0 < scale <= 1.0):
              sys.exit(sys.argv[0] + ": Scale of feature should be in (0, 1]: {0}".format(line))

          feat_type = fields[1]
          if feat_type == "special":
              special_feat_ids.append(feat_id)
          elif feat_type == "constant":
              constant_feat_id = feat_id
              # For the constant feature the last field is kept as its value.
              constant_feat_value = scale
          elif feat_type == "unigram":
              unigram_feat_id = feat_id
          elif feat_type == "length":
              length_feat_id = feat_id

          max_feat_id = max(max_feat_id, feat_id)
  
  # Validate the word-features file line by line.  Each line is a word-id
  # followed by zero or more (feat_id, feat_value) pairs, so it must contain an
  # odd, non-zero number of fields.  The script exits with an error message on
  # the first line that violates any of the checks below.
  with open(args.word_features_file, 'r', encoding="utf-8") as f:
      for line in f:
          fields = line.split()
          # Explicit exit rather than assert: asserts are removed under -O.
          # An empty line has 0 fields, which is even, so it is rejected here.
          if len(fields) % 2 != 1:
              sys.exit(sys.argv[0] + ": Expected a word-id followed by (feat_id, feat_value) pairs: {0}".format(line))

          try:
              word_id = int(fields[0])
          except ValueError:
              sys.exit(sys.argv[0] + ": word-id should be an integer: {0}".format(line))

          # Only word-id 0 (<eps>) may appear without any feature.
          if len(fields) == 1 and word_id != 0:
              sys.exit(sys.argv[0] + ": Only <eps> can have no feature: {0}.".format(line))

          last_feat_id = -1
          # Walk the (feat_id, feat_value) pairs that follow the word-id.
          for i in range(1, len(fields), 2):
              try:
                  feat_id = int(fields[i])
              except ValueError:
                  sys.exit(sys.argv[0] + ": feat_id should be an integer: {0}".format(line))
              feat_value = fields[i + 1]

              # Ids must be strictly increasing; since last_feat_id starts at
              # -1 this also rejects negative ids.
              if feat_id <= last_feat_id:
                  sys.exit(sys.argv[0] + ": features must be listed in increasing order: {0} <= {1} in {2}.".format(feat_id, last_feat_id, line))
              last_feat_id = feat_id

              if feat_id > max_feat_id:
                  sys.exit(sys.argv[0] + ": Wrong feat_id: {0}.".format(line))

              if feat_id in special_feat_ids:
                  if len(fields) not in (3, 5):
                      sys.exit(sys.argv[0] + ": Special word can only have one or 2 features: {0}.".format(line))
                  try:
                      float(feat_value)
                  except ValueError:
                      sys.exit(sys.argv[0] + ": Value of special word feature should be a float number: {0}.".format(line))
              elif feat_id == constant_feat_id:
                  # Guard the parse like the sibling branches do, so a malformed
                  # value gives a clean error instead of a traceback.
                  try:
                      constant_value = float(feat_value)
                  except ValueError:
                      sys.exit(sys.argv[0] + ": Value of constant feature should be a float number: {0}.".format(line))
                  # Negated "<=" so that a NaN value is rejected as well.
                  if not abs(constant_value - constant_feat_value) <= 1e-6:
                      sys.exit(sys.argv[0] + ": Value of constant feature is not right: {0}".format(line))
              else: # all feat_value would be float
                  try:
                      float(feat_value)
                  except ValueError:
                      sys.exit(sys.argv[0] + ": Value of unigram feature should be a float number: {0}.".format(line))