Blame view

scripts/rnnlm/get_best_model.py 2.44 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
  #!/usr/bin/env python3
  
  # Copyright  2017  Johns Hopkins University (author: Daniel Povey)
  # License: Apache 2.0.
  
  import argparse
  import glob
  import os
  import re
  import sys
  
  # Command-line interface: one positional argument, the RNNLM training
  # directory.  The parsed namespace is stored in `args`; the rest of the
  # script reads args.rnnlm_dir from it.
  parser = argparse.ArgumentParser(
      description=("Works out the best iteration of RNNLM training based on "
                   "dev-set perplexity, and prints the number corresponding to "
                   "that iteration"),
      epilog="E.g. {0} exp/rnnlm_a".format(sys.argv[0]),
      formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  )
  parser.add_argument(
      "rnnlm_dir",
      help="Directory where the RNNLM has been trained",
  )
  args = parser.parse_args()
  
  # Read the total number of training iterations from <rnnlm_dir>/info.txt.
  # The first line whose text before "=" is exactly "num_iters" is used, and
  # the text after "=" is parsed as an int.
  num_iters = None
  try:
      with open(args.rnnlm_dir + "/info.txt", encoding="utf-8") as f:
          for line in f:
              # partition() never raises: a "num_iters" line with no "=" yields
              # an empty value, which int() rejects with ValueError below.
              key, _, value = line.partition("=")
              if key == "num_iters":
                  # int() tolerates surrounding whitespace such as the newline.
                  num_iters = int(value)
                  break
  except (OSError, ValueError):
      # OSError: info.txt missing or unreadable.  ValueError: non-integer
      # value, or content that is not valid UTF-8 (UnicodeDecodeError is a
      # ValueError subclass).  Catching only these, rather than a bare
      # except, lets Ctrl-C and other unexpected errors surface normally.
      sys.exit(sys.argv[0] + ": could not find {0}/info.txt (or wrong format)".format(
          args.rnnlm_dir))

  if num_iters is None:
      sys.exit(sys.argv[0] + ": could not get num_iters from {0}/info.txt".format(
          args.rnnlm_dir))
  
  # Choose the iteration with the highest dev-set objective among iterations
  # 1 .. num_iters-1 (range() excludes num_iters) whose model file
  # <rnnlm_dir>/<i>.raw still exists, and print its number on stdout.
  # NOTE(review): iteration num_iters itself is never considered; confirm
  # against the training script that this is intended.
  #
  # The objective of iteration i comes from the last line of
  # <rnnlm_dir>/log/compute_prob.<i>.log that matches this pattern; the
  # line's final whitespace-separated token is parsed as a float.  The raw
  # string keeps "\S" a regex escape instead of an invalid string escape.
  objf_pattern = re.compile(r'Overall objf .* (\S+)$')

  best_objf = float('-inf')  # objective of best_iter so far
  best_iter = -1             # -1: no iteration with a parsed objective yet
  fallback_iter = -1         # first candidate whose objective was unparseable
  for i in range(1, num_iters):
      # Check for the model before touching the log: an iteration whose model
      # files were cleaned up can never be selected, so its log is irrelevant.
      # os.path.exists() replaces glob.glob() so that glob metacharacters
      # ("[", "*", "?") in rnnlm_dir are taken literally.
      if not os.path.exists("{0}/{1}.raw".format(args.rnnlm_dir, i)):
          continue

      this_logfile = "{0}/log/compute_prob.{1}.log".format(args.rnnlm_dir, i)
      try:
          f = open(this_logfile, 'r', encoding='utf-8')
      except OSError:
          sys.exit(sys.argv[0] + ": could not open log-file {0}".format(this_logfile))

      this_objf = None  # stays None if no line matches the pattern
      with f:  # ensures the log file is closed
          for line in f:
              m = objf_pattern.search(line)
              if m is not None:
                  try:
                      this_objf = float(m.group(1))
                  except ValueError as e:
                      sys.exit(sys.argv[0] + ": line in file {0} could not be parsed: {1}, error is: {2}".format(
                          this_logfile, line, str(e)))

      if this_objf is None:
          print(sys.argv[0] + ": warning: could not parse objective function from {0}".format(
              this_logfile), file=sys.stderr)
          if fallback_iter == -1:
              fallback_iter = i
      elif this_objf > best_objf:
          # Strict ">" keeps the earliest iteration on ties; a NaN objective
          # never compares greater, so it is never selected.
          best_objf = this_objf
          best_iter = i

  if best_iter == -1:
      # Best-effort: if no log could be parsed, use the first iteration that
      # still has a model rather than failing outright.
      best_iter = fallback_iter

  if best_iter == -1:
      sys.exit(sys.argv[0] + ": error: could not get best iteration.")

  print(best_iter)