Blame view
scripts/rnnlm/validate_word_features.py
3.37 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
#!/usr/bin/env python3

# Copyright  2017  Jian Wang
# License: Apache 2.0.

"""Validate a word-features file against the features file it refers to.

The features file has one feature per line: an integer feature id, the
feature type (e.g. "special", "constant", "unigram", "length"), optional
type-specific fields, and finally a scale in (0, 1].

The word-features file has one word per line: an integer word id followed by
zero or more (feature-id, value) pairs, with feature ids strictly increasing.

On the first problem found, the script exits with a non-zero status and a
message naming the offending line.
"""

import os
import argparse
import sys
import re


def die(message):
    """Exit with a non-zero status, printing `message` prefixed by the script name."""
    sys.exit(sys.argv[0] + ": " + message)


def get_args(argv=None):
    """Parse the command line.

    Args:
        argv: list of argument strings; None means use sys.argv[1:].

    Returns:
        The argparse namespace with `features_file` and `word_features_file`.
    """
    parser = argparse.ArgumentParser(
        description="Validates word features file, produced by rnnlm/get_word_features.py.",
        epilog="E.g. " + sys.argv[0] + " --features-file=exp/rnnlm/features.txt "
               "exp/rnnlm/word_feats.txt",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--features-file", type=str, default='', required=True,
                        help="File containing features")
    parser.add_argument("word_features_file",
                        help="File containing word features")
    return parser.parse_args(argv)


def read_features(features_path):
    """Read the features file and collect what the validation needs.

    Only the 'special' and 'constant' feature types need dedicated handling
    later; every other type is only checked for a parseable float value.

    Args:
        features_path: path of the features file (read as UTF-8).

    Returns:
        A tuple (special_feat_ids, constant_feat_id, constant_feat_value,
        max_feat_id).  constant_feat_id is -1 and constant_feat_value is None
        when the file declares no constant feature; max_feat_id is -1 when the
        file is empty.

    Exits (SystemExit) if a line has the wrong number of fields, an
    unparseable id or scale, or a scale outside (0, 1].
    """
    special_feat_ids = set()
    constant_feat_id = -1
    constant_feat_value = None
    max_feat_id = -1
    with open(features_path, 'r', encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            # Explicit checks rather than assert: asserts vanish under -O.
            if len(fields) not in (3, 4, 5):
                die("Expected 3, 4 or 5 fields in features file line: {0}".format(line))
            try:
                feat_id = int(fields[0])
                # every feature should contain a scale
                scale = float(fields[-1])
            except ValueError:
                die("Bad feature id or scale in features file line: {0}".format(line))
            if not 0.0 < scale <= 1.0:
                die("Scale must be in (0, 1] in features file line: {0}".format(line))

            feat_type = fields[1]
            if feat_type == "special":
                special_feat_ids.add(feat_id)
            elif feat_type == "constant":
                constant_feat_id = feat_id
                constant_feat_value = scale
            max_feat_id = max(max_feat_id, feat_id)
    return special_feat_ids, constant_feat_id, constant_feat_value, max_feat_id


def validate_word_features(word_features_path, special_feat_ids,
                           constant_feat_id, constant_feat_value, max_feat_id):
    """Check every line of the word-features file; exit on the first error.

    Args:
        word_features_path: path of the word-features file (read as UTF-8).
        special_feat_ids: collection of ids of 'special' features.
        constant_feat_id: id of the constant feature, or -1 if there is none.
        constant_feat_value: expected value of the constant feature.
        max_feat_id: largest feature id declared in the features file.

    Exits (SystemExit) with a message if a line is malformed.
    """
    with open(word_features_path, 'r', encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            # A line is a word id followed by (feat_id, value) pairs, so the
            # number of fields must be odd (and non-zero).
            if len(fields) % 2 != 1:
                die("Expected a word-id followed by (feature-id, value) pairs: {0}".format(line))
            try:
                word_id = int(fields[0])
            except ValueError:
                die("Word-id should be an integer: {0}.".format(line))

            if len(fields) == 1 and word_id != 0:
                die("Only <eps> can have no feature: {0}.".format(line))

            last_feat_id = -1
            for i in range(1, len(fields), 2):
                feat_value = fields[i + 1]
                try:
                    feat_id = int(fields[i])
                except ValueError:
                    die("Feature-id should be an integer: {0}.".format(line))

                # Starting last_feat_id at -1 also rejects negative ids here.
                if feat_id <= last_feat_id:
                    die("features must be listed in increasing order: {0} <= {1} in {2}."
                        .format(feat_id, last_feat_id, line))
                last_feat_id = feat_id

                if feat_id > max_feat_id:
                    die("Wrong feat_id: {0}.".format(line))
                elif feat_id in special_feat_ids:
                    if len(fields) not in (3, 5):
                        die("Special word can only have one or 2 features: {0}.".format(line))
                    try:
                        float(feat_value)
                    except ValueError:
                        die("Value of special word feature should be a float number: {0}.".format(line))
                elif feat_id == constant_feat_id:
                    try:
                        value = float(feat_value)
                    except ValueError:
                        die("Value of constant feature should be a float number: {0}.".format(line))
                    if abs(value - constant_feat_value) > 1e-6:
                        die("Value of constant feature is not right: {0}".format(line))
                else:
                    # all feat_value would be float
                    try:
                        float(feat_value)
                    except ValueError:
                        die("Value of unigram feature should be a float number: {0}.".format(line))


def main(argv=None):
    """Run the validation; returns None on success, exits non-zero on error.

    Args:
        argv: list of argument strings; None means use sys.argv[1:].
    """
    args = get_args(argv)
    (special_feat_ids, constant_feat_id,
     constant_feat_value, max_feat_id) = read_features(args.features_file)
    validate_word_features(args.word_features_file, special_feat_ids,
                           constant_feat_id, constant_feat_value, max_feat_id)


if __name__ == "__main__":
    main()