get_successful_models.py
2.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#!/usr/bin/env python
from __future__ import print_function
import re
import os
import argparse
import sys
import warnings
import copy
import glob
# Objf assigned to a model whose log file contains no objective-function line.
# It is low enough that such a model is discarded whenever at least one other
# model reports a real objf (for any reasonable --difference-threshold).
MISSING_OBJF = -100000

# Matches the summary line of a training log and captures the objf value
# (group 1) and the frame count (group 2).
OBJF_REGEX = re.compile(
    r"LOG .* Overall average objective function for 'output' is "
    r"([0-9e.\-+]+) over ([0-9e.\-+]+) frames")


def build_parser():
    """Return the command-line parser for this script."""
    # we add compulsory arguments as named arguments for readability
    parser = argparse.ArgumentParser(
        description="Create a list of models suitable for averaging "
                    "based on their train objf values.",
        epilog="See steps/nnet3/lstm/train.sh for example.")
    parser.add_argument("--difference-threshold", type=float,
                        help="The threshold for discarding models, "
                        "when objf of the model differs more than this value from the best model "
                        "it is discarded.",
                        default=1.0)
    parser.add_argument("num_models", type=int,
                        help="Number of models.")
    # argparse %-formats help strings, so a literal '%' must be written '%%';
    # a bare '%' here makes --help fail with a ValueError.
    parser.add_argument("logfile_pattern", type=str,
                        help="Pattern for identifying the log-file names. "
                        "It specifies the entire log file name, except for the job number, "
                        "which is replaced with '%%'. e.g. exp/nneet3/tdnn_sp/log/train.4.%%.log")
    return parser


def get_final_objf(lines):
    """Return the objf from the last matching line of a log.

    Args:
        lines: sequence of log lines (strings).

    Returns:
        The objf of the last line matching OBJF_REGEX as a float, or
        MISSING_OBJF if no line matches.
    """
    # Search from the end: the summary line is expected near the bottom of the
    # log, so this needs fewer regex searches.
    for line in reversed(lines):
        match = OBJF_REGEX.search(line)
        if match is not None:
            return float(match.group(1))
    return MISSING_OBJF


def read_losses(num_models, logfile_pattern):
    """Return the final objf of models 1..num_models, in model order.

    The log file of model k is logfile_pattern with every '%' replaced by k.
    """
    loss = []
    for model_num in range(1, num_models + 1):
        logfile = logfile_pattern.replace('%', str(model_num))
        # 'with' guarantees the handle is closed even if reading fails.
        with open(logfile, 'r') as log_fh:
            loss.append(get_final_objf(log_fh.readlines()))
    return loss


def select_models(loss, difference_threshold):
    """Return the 1-based indices of the models to keep.

    A model is kept when its objf is within difference_threshold of the best
    (largest) objf in loss. loss must be non-empty.
    """
    best = max(loss)
    return [index + 1 for index, objf in enumerate(loss)
            if (best - objf) <= difference_threshold]


def main(argv=None):
    """Print the accepted model numbers, space separated, on stdout.

    Args:
        argv: argument list to parse; None means sys.argv[1:].

    A warning listing the accepted models is written to stderr when some
    models were rejected.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    # Explicit check instead of 'assert', which is stripped under 'python -O'.
    if args.num_models <= 0:
        parser.error("num_models must be a positive integer")

    loss = read_losses(args.num_models, args.logfile_pattern)
    accepted_models = select_models(loss, args.difference_threshold)

    model_list = " ".join([str(x) for x in accepted_models])
    print(model_list)

    if len(accepted_models) != args.num_models:
        print("WARNING: Only {0}/{1} of the models have been accepted for averaging, based on log files {2}.".format(len(accepted_models), args.num_models, args.logfile_pattern), file=sys.stderr)
        print(" Using models {0}".format(model_list), file=sys.stderr)


if __name__ == "__main__":
    main()