validate_word_features.py
3.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python3
# Copyright 2017 Jian Wang
# License: Apache 2.0.
import os
import argparse
import sys
import re
parser = argparse.ArgumentParser(description="Validates word features file, produced by rnnlm/get_word_features.py.",
epilog="E.g. " + sys.argv[0] + " --features-file=exp/rnnlm/features.txt "
"exp/rnnlm/word_feats.txt",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--features-file", type=str, default='', required=True,
help="File containing features")
parser.add_argument("word_features_file", help="File containing word features")
args = parser.parse_args()
# we only need to know the feat_id for 'special', 'unigram' and 'length'
special_feat_ids = []
constant_feat_id = -1
constant_feat_value = None
unigram_feat_id = -1
length_feat_id = -1
max_feat_id = -1
with open(args.features_file, 'r', encoding="utf-8") as f:
for line in f:
fields = line.split()
assert(len(fields) in [3, 4, 5])
feat_id = int(fields[0])
# every feature should contain a scale
scale = float(fields[-1])
assert scale > 0.0 and scale <= 1.0
if fields[1] == "special":
special_feat_ids.append(feat_id)
elif fields[1] == "constant":
constant_feat_id = feat_id
constant_feat_value = scale
elif fields[1] == "unigram":
unigram_feat_id = feat_id
elif fields[1] == "length":
length_feat_id = feat_id
if feat_id > max_feat_id:
max_feat_id = feat_id
with open(args.word_features_file, 'r', encoding="utf-8") as f:
for line in f:
fields = line.split()
assert len(fields) > 0 and len(fields) % 2 == 1
word_id = int(fields[0])
if len(fields) == 1:
if word_id != 0:
sys.exit(sys.argv[0] + ": Only <eps> can have no feature: {0}.".format(line))
i = 1
last_feat_id = -1
while i < len(fields):
feat_id = int(fields[i])
feat_value = fields[i + 1]
if feat_id <= last_feat_id:
sys.exit(sys.argv[0] + ": features must be listed in increasing order: {0} <= {1} in {2}.".format(feat_id, last_feat_id, line))
last_feat_id = feat_id
if feat_id > max_feat_id:
sys.exit(sys.argv[0] + ": Wrong feat_id: {0}.".format(line))
elif feat_id in special_feat_ids:
if len(fields) != 3 and len(fields) != 5:
sys.exit(sys.argv[0] + ": Special word can only have one or 2 features: {0}.".format(line))
try:
float(feat_value)
except ValueError:
sys.exit(sys.argv[0] + ": Value of special word feature should be a float number: {0}.".format(line))
elif feat_id == constant_feat_id:
if abs(float(feat_value) - constant_feat_value) > 1e-6:
sys.exit(sys.argv[0] + ": Value of constant feature is not right: {0}".format(line))
else: # all feat_value would be float
try:
float(feat_value)
except ValueError:
sys.exit(sys.argv[0] + ": Value of unigram feature should be a float number: {0}.".format(line))
i += 2