get_best_model.py
2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (author: Daniel Povey)
# License: Apache 2.0.
import argparse
import glob
import re
import sys
# Matches the summary line of a compute_prob log; the last
# whitespace-delimited token on that line is taken as the objective value.
_OBJF_RE = re.compile(r'Overall objf .* (\S+)$')

# Objective assumed for an iteration whose log contains no parsable
# "Overall objf" line.  It is above _INITIAL_BEST_OBJF, so such an iteration
# can still be selected (after a warning) when nothing better is available.
_UNPARSED_OBJF = -1000
_INITIAL_BEST_OBJF = -2000


def _get_args(argv=None):
    """Parse the command line.

    Args:
        argv: Argument list without the program name; None means sys.argv[1:].

    Returns:
        The argparse namespace; its only field is ``rnnlm_dir``.
    """
    parser = argparse.ArgumentParser(
        description="Works out the best iteration of RNNLM training "
        "based on dev-set perplexity, and prints the number corresponding "
        "to that iteration",
        epilog="E.g. " + sys.argv[0] + " exp/rnnlm_a",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("rnnlm_dir",
                        help="Directory where the RNNLM has been trained")
    return parser.parse_args(argv)


def _read_num_iters(rnnlm_dir):
    """Return the integer stored on the ``num_iters=...`` line of info.txt.

    Exits the program with a message if <rnnlm_dir>/info.txt cannot be read,
    if the value is not an integer, or if no ``num_iters`` line exists.
    """
    num_iters = None
    try:
        with open(rnnlm_dir + "/info.txt", encoding="utf-8") as f:
            for line in f:
                fields = line.split("=")
                if fields[0] == "num_iters":
                    num_iters = int(fields[1])
                    break
    except (OSError, ValueError):
        # OSError: file missing or unreadable.  ValueError: non-integer value
        # or undecodable bytes (UnicodeDecodeError is a ValueError).
        sys.exit(sys.argv[0] + ": could not find {0}/info.txt (or wrong format)".format(
            rnnlm_dir))
    if num_iters is None:
        sys.exit(sys.argv[0] + ": could not get num_iters from {0}/info.txt".format(
            rnnlm_dir))
    return num_iters


def _read_objf(logfile):
    """Return the objective from the last "Overall objf" line of *logfile*.

    Returns None if no such line is present.  Exits the program if the file
    cannot be opened or if the matched token is not a valid float.
    """
    try:
        f = open(logfile, 'r', encoding='utf-8')
    except OSError:
        sys.exit(sys.argv[0] + ": could not open log-file {0}".format(logfile))
    objf = None
    with f:  # make sure the handle is released before the next iteration
        for line in f:
            m = _OBJF_RE.search(line)
            if m is not None:
                try:
                    objf = float(m.group(1))
                except ValueError as e:
                    sys.exit(sys.argv[0] + ": line in file {0} could not be parsed: {1}, error is: {2}".format(
                        logfile, line, str(e)))
    return objf


def _find_best_iter(rnnlm_dir, num_iters):
    """Return the iteration in [1, num_iters) with the highest objective.

    Only iterations whose model file <rnnlm_dir>/<iter>.raw still exists are
    considered.  Returns -1 if no iteration qualifies.
    """
    best_objf = _INITIAL_BEST_OBJF
    best_iter = -1
    for i in range(1, num_iters):
        this_logfile = "{0}/log/compute_prob.{1}.log".format(rnnlm_dir, i)
        # The log is read before the model check, so a missing or malformed
        # log is fatal even for an iteration whose model was cleaned up.
        this_objf = _read_objf(this_logfile)
        # verify this iteration still has model files present
        if not glob.glob("{0}/{1}.raw".format(rnnlm_dir, i)):
            # this iteration has log files, but model files have been cleaned up, skip it
            continue
        if this_objf is None:
            print(sys.argv[0] + ": warning: could not parse objective function from {0}".format(
                this_logfile), file=sys.stderr)
            this_objf = _UNPARSED_OBJF
        if this_objf > best_objf:
            best_objf = this_objf
            best_iter = i
    return best_iter


def main(argv=None):
    """Print the number of the best RNNLM training iteration to stdout.

    Args:
        argv: Argument list without the program name; None means sys.argv[1:].

    Exits the program with an error message if no best iteration is found.
    """
    args = _get_args(argv)
    num_iters = _read_num_iters(args.rnnlm_dir)
    best_iter = _find_best_iter(args.rnnlm_dir, num_iters)
    if best_iter == -1:
        sys.exit(sys.argv[0] + ": error: could not get best iteration.")
    print(str(best_iter))


if __name__ == "__main__":
    main()