Blame view
scripts/rnnlm/validate_features.py
3.06 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
#!/usr/bin/env python3

# Copyright  2017  Jian Wang
# License: Apache 2.0.

"""Validate a features file produced by rnnlm/choose_features.py.

Every line must have 3 to 5 whitespace-separated fields::

    <index> <type> [<key-or-value> ...] <scale>

where <index> counts up from 0, one per line, and <scale> (always the last
field) lies in (0, 1].  Per-type rules enforced here:

* ``length``   -- exactly 3 fields; at most one such line.
* ``constant`` -- exactly 3 fields; the third field must be a number > 0.
* ``special``  -- exactly 4 fields.
* ``unigram``  -- third field (the log-unigram-ppl) must be a number > 0;
  at most one such line.
* ``word``, ``initial``, ``final``, ``match`` -- the third field must not
  repeat within the same type.

Any other type is rejected.  On the first error the script exits with a
message prefixed by its own name.
"""

import os
import argparse
import sys
import re

EOS_SYMBOL = '</s>'

# Feature types whose third field is a key that must be unique per type.
_KEYED_TYPES = ("word", "initial", "final", "match")


def _is_positive(text):
    """Return True if *text* parses as a float strictly greater than zero.

    Non-numeric text and NaN both return False.
    """
    try:
        return float(text) > 0.0
    except ValueError:
        return False


def validate_features(lines):
    """Check the lines of a features file, raising on the first problem.

    Args:
        lines: iterable of text lines (e.g. an open file object); trailing
            newlines are allowed.

    Raises:
        ValueError: describing the first invalid line.  The message does not
            include the program name; ``main`` adds that prefix.
    """
    has_unigram = False
    has_length = False
    # One set of already-seen keys per keyed feature type.
    seen_keys = {feat_type: set() for feat_type in _KEYED_TYPES}

    for expected_idx, line in enumerate(lines):
        fields = line.split()
        # Also rejects blank lines (zero fields).
        if len(fields) not in (3, 4, 5):
            raise ValueError("bad line: {0}".format(line))

        # Feature indices must be consecutive, starting from 0.
        try:
            idx = int(fields[0])
        except ValueError:
            raise ValueError("bad feature index in line: {0}".format(line)) from None
        if idx != expected_idx:
            raise ValueError("expected feature index {0}, got {1}: {2}".format(
                expected_idx, idx, line))

        # Every feature should contain a scale in (0, 1]; NaN is rejected
        # because both comparisons are False for it.
        try:
            scale = float(fields[-1])
        except ValueError:
            raise ValueError("bad scale in line: {0}".format(line)) from None
        if not 0.0 < scale <= 1.0:
            raise ValueError("scale should be in (0, 1]: {0}".format(line))

        feat_type = fields[1]
        # A 'length' line with the wrong field count falls through to the
        # final "Error line format" branch.
        if feat_type == "length" and len(fields) == 3:
            if has_length:
                raise ValueError("Too many 'length' features")
            has_length = True
        elif feat_type == "constant":
            if len(fields) != 3 or not _is_positive(fields[2]):
                raise ValueError("bad line: {0}".format(line))
        elif feat_type == "special":
            if len(fields) != 4:
                raise ValueError("bad line: {0}".format(line))
        elif feat_type == "unigram":
            if not _is_positive(fields[2]):
                raise ValueError(
                    "log-unigram-ppl should be a positive value: {0}".format(fields[2]))
            if has_unigram:
                raise ValueError("Too many 'unigram' features")
            has_unigram = True
        elif feat_type in seen_keys:
            key = fields[2]
            if key in seen_keys[feat_type]:
                raise ValueError("duplicated {0} feature: {1}".format(feat_type, key))
            seen_keys[feat_type].add(key)
        else:
            raise ValueError("Error line format: {0}".format(line))


def main():
    """Parse the command line and validate the given features file.

    Exits with a non-zero status and an ``argv[0]``-prefixed message if the
    file is missing or invalid.
    """
    parser = argparse.ArgumentParser(
        description="Validates features file, produced by rnnlm/choose_features.py.",
        epilog="E.g. " + sys.argv[0] + " exp/rnnlm/features.txt",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("features_file", help="File containing features")
    args = parser.parse_args()

    if not os.path.isfile(args.features_file):
        sys.exit(sys.argv[0] + ": Expected file {0} to exist".format(args.features_file))

    with open(args.features_file, 'r', encoding="utf-8") as f:
        try:
            validate_features(f)
        except ValueError as err:
            # Also covers UnicodeDecodeError (a ValueError subclass) for a
            # file that is not valid UTF-8.
            sys.exit(sys.argv[0] + ": {0}".format(err))


if __name__ == "__main__":
    main()