Blame view
scripts/rnnlm/get_best_model.py
2.44 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
#!/usr/bin/env python3 # Copyright 2017 Johns Hopkins University (author: Daniel Povey) # License: Apache 2.0. import argparse import glob import re import sys parser = argparse.ArgumentParser(description="Works out the best iteration of RNNLM training " "based on dev-set perplexity, and prints the number corresponding " "to that iteration", epilog="E.g. " + sys.argv[0] + " exp/rnnlm_a", formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("rnnlm_dir", help="Directory where the RNNLM has been trained") args = parser.parse_args() num_iters = None try: with open(args.rnnlm_dir + "/info.txt", encoding="utf-8") as f: for line in f: a = line.split("=") if a[0] == "num_iters": num_iters = int(a[1]) break except: sys.exit(sys.argv[0] + ": could not find {0}/info.txt (or wrong format)".format( args.rnnlm_dir)) if num_iters is None: sys.exit(sys.argv[0] + ": could not get num_iters from {0}/info.txt".format( args.rnnlm_dir)) best_objf = -2000 best_iter = -1 for i in range(1, num_iters): this_logfile = "{0}/log/compute_prob.{1}.log".format(args.rnnlm_dir, i) try: f = open(this_logfile, 'r', encoding='utf-8') except: sys.exit(sys.argv[0] + ": could not open log-file {0}".format(this_logfile)) this_objf = -1000 for line in f: m = re.search('Overall objf .* (\S+)$', str(line)) if m is not None: try: this_objf = float(m.group(1)) except Exception as e: sys.exit(sys.argv[0] + ": line in file {0} could not be parsed: {1}, error is: {2}".format( this_logfile, line, str(e))) # verify this iteration still has model files present if len(glob.glob("{0}/{1}.raw".format(args.rnnlm_dir, i))) == 0: # this iteration has log files, but model files have been cleaned up, skip it continue if this_objf == -1000: print(sys.argv[0] + ": warning: could not parse objective function from {0}".format( this_logfile), file=sys.stderr) if this_objf > best_objf: best_objf = this_objf best_iter = i if best_iter == -1: sys.exit(sys.argv[0] + ": error: could not get best iteration.") print(str(best_iter)) |