egs/wsj/s5/steps/nnet3/get_successful_models.py
#!/usr/bin/env python

from __future__ import print_function
import re
import argparse
import sys
  
  if __name__ == "__main__":
    # compulsory arguments are added as positional arguments; optional ones are named
      parser = argparse.ArgumentParser(description="Create a list of models suitable for averaging "
                                                   "based on their train objf values.",
                                       epilog="See steps/nnet3/lstm/train.sh for example.")
  
      parser.add_argument("--difference-threshold", type=float,
                          help="The threshold for discarding models, "
                          "when objf of the model differs more than this value from the best model "
                          "it is discarded.",
                          default=1.0)
  
      parser.add_argument("num_models", type=int,
                          help="Number of models.")
  
      parser.add_argument("logfile_pattern", type=str,
                          help="Pattern for identifying the log-file names. "
                          "It specifies the entire log file name, except for the job number, "
                          "which is replaced with '%'. e.g. exp/nneet3/tdnn_sp/log/train.4.%.log")
  
  
      args = parser.parse_args()
  
    assert args.num_models > 0
  
    parse_regex = re.compile(
        r"LOG .* Overall average objective function for 'output' is "
        r"([0-9e.\-+]+) over ([0-9e.\-+]+) frames")
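    # For reference, a matching log line looks roughly like the following
    # (the numbers here are made up for illustration):
    #   LOG ... Overall average objective function for 'output' is -2.0513 over 397500 frames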
      loss = []
    for i in range(args.num_models):
        model_num = i + 1
        logfile = re.sub('%', str(model_num), args.logfile_pattern)
        with open(logfile, 'r') as f:
            lines = f.readlines()
        # Sentinel used when no objective-function line is found in the log.
        this_loss = -100000
        for line_num in range(1, len(lines) + 1):
            # We search from the end of the file, since the line we want is
            # near the end; this results in fewer regex searches. Python
            # regex is slow!
            mat_obj = parse_regex.search(lines[-line_num])
            if mat_obj is not None:
                this_loss = float(mat_obj.groups()[0])
                break
        loss.append(this_loss)
      max_index = loss.index(max(loss))
      accepted_models = []
      for i in range(args.num_models):
          if (loss[max_index] - loss[i]) <= args.difference_threshold:
              accepted_models.append(i+1)
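    # Worked example with made-up numbers: if loss == [-1.95, -2.10, -4.70, -2.05]
    # and the threshold is the default 1.0, the best value is -1.95 (model 1);
    # models 1, 2 and 4 are accepted, while model 3 is 2.75 below the best and
    # is dropped, so accepted_models == [1, 2, 4].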
  
      model_list = " ".join([str(x) for x in accepted_models])
      print(model_list)
  
    if len(accepted_models) != args.num_models:
        print("WARNING: Only {0}/{1} of the models have been accepted "
              "for averaging, based on log files {2}.".format(
                  len(accepted_models), args.num_models,
                  args.logfile_pattern),
              file=sys.stderr)
        print("         Using models {0}".format(model_list), file=sys.stderr)