scripts/rnnlm/rnnlm_cleanup.py

  #!/usr/bin/env python3
  
  # Copyright 2018 Tilde
  # License: Apache 2.0
  
  import argparse
  import glob
  import os
  import re
  import sys
  
  script_name = sys.argv[0]
  
  parser = argparse.ArgumentParser(description="Removes models from past training iterations of "
                                               "an RNNLM, using either the 'keep_latest' or the "
                                               "'keep_best' cleanup strategy (exactly one must be "
                                               "given): the former keeps the most recent models, "
                                               "while the latter keeps the models with the best "
                                               "training objective score on the dev set.",
                                   epilog="E.g. " + script_name + " exp/rnnlm_a --keep_best",
                                   formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  
  parser.add_argument("rnnlm_dir",
                      help="Directory where the RNNLM has been trained")
  parser.add_argument("--iters_to_keep",
                      help="Max number of iterations to keep",
                      type=int,
                      default=3)
  parser.add_argument("--keep_latest",
                      help="Keeps the training iterations that are latest by age",
                      action="store_true")
  parser.add_argument("--keep_best",
                      help="Keeps the training iterations that have the best objf",
                      action="store_true")
  
  args = parser.parse_args()
  
  # validate arguments
  if args.keep_latest and args.keep_best:
      sys.exit(script_name + ": can only use one of '--keep_latest' or '--keep_best', not both")
  elif not args.keep_latest and not args.keep_best:
      sys.exit(script_name + ": no cleanup strategy specified: use '--keep_latest' or '--keep_best'")
  
  
  class IterationInfo:
      def __init__(self, model_files, objf, compute_prob_done):
          self.model_files = model_files
          self.objf = objf
          self.compute_prob_done = compute_prob_done
  
      def __str__(self):
          return "{model_files: %s, compute_prob: %s, objf: %2.3f}" % (self.model_files,
                                                                       self.compute_prob_done,
                                                                       self.objf)
  
      def __repr__(self):
          return self.__str__()
  
  
  def get_compute_prob_info(log_file):
      # we want to know 3 things: iteration number, objf and whether compute prob is done
      iteration = int(log_file.split(".")[-2])
      objf = -2000  # sentinel value meaning "objf could not be parsed"
      compute_prob_done = False
      # roughly based on code in get_best_model.py
      try:
          f = open(log_file, "r", encoding="utf-8")
      except OSError:
          print(script_name + ": warning: compute_prob log not found for iteration " +
                str(iteration) + ". Skipping",
                file=sys.stderr)
          return iteration, objf, compute_prob_done
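      # Scan the log for the final objf report. A matching line might look like
      # (hypothetical example, inferred from the regex below, not from an actual Kaldi log):
      #   ... Overall objf is -4.32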
      for line in f:
          objf_m = re.search(r'Overall objf .* (\S+)$', line)
          if objf_m is not None:
              try:
                  objf = float(objf_m.group(1))
              except Exception as e:
                  sys.exit(script_name + ": line in file {0} could not be parsed: {1}, error is: {2}".format(
                      log_file, line, str(e)))
          if "# Ended" in line:
              compute_prob_done = True
      if objf == -2000:
          print(script_name + ": warning: could not parse objective function from " + log_file, file=sys.stderr)
      return iteration, objf, compute_prob_done
  
  
  def get_iteration_files(exp_dir):
      iterations = dict()
      compute_prob_logs = glob.glob(exp_dir + "/log/compute_prob.[0-9]*.log")
      for log in compute_prob_logs:
          iteration, objf, compute_prob_done = get_compute_prob_info(log)
          if iteration == 0:
              # iteration 0 is special, never consider it for cleanup
              continue
          if compute_prob_done:
              # this iteration can be safely considered for cleanup
              # gather all model files belonging to it
              model_files = []
              # when there are multiple jobs per iteration, there can be several model files
              # we need to potentially clean them all up without mixing them up
              model_files.extend(glob.glob("{0}/word_embedding.{1}.mat".format(exp_dir, iteration)))
              model_files.extend(glob.glob("{0}/word_embedding.{1}.[0-9]*.mat".format(exp_dir, iteration)))
              model_files.extend(glob.glob("{0}/feat_embedding.{1}.mat".format(exp_dir, iteration)))
              model_files.extend(glob.glob("{0}/feat_embedding.{1}.[0-9]*.mat".format(exp_dir, iteration)))
              model_files.extend(glob.glob("{0}/{1}.raw".format(exp_dir, iteration)))
              model_files.extend(glob.glob("{0}/{1}.[0-9]*.raw".format(exp_dir, iteration)))
              # compute_prob logs outlive model files; only consider iterations that still have model files
              if len(model_files) > 0:
                  iterations[iteration] = IterationInfo(model_files, objf, compute_prob_done)
      return iterations
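
  # For example (hypothetical file names), get_iteration_files("exp/rnnlm_a") could
  # return something like:
  #   {12: {model_files: ['exp/rnnlm_a/word_embedding.12.mat', 'exp/rnnlm_a/12.raw'],
  #         compute_prob: True, objf: -4.320}}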
  
  
  def remove_model_files_for_iter(iter_info):
      for f in iter_info.model_files:
          os.remove(f)
  
  
  def keep_latest(iteration_dict):
      max_to_keep = args.iters_to_keep
      kept = 0
      # walk from the newest iteration down to the oldest
      for iteration in sorted(iteration_dict, reverse=True):
          if kept < max_to_keep:
              kept += 1
          else:
              remove_model_files_for_iter(iteration_dict[iteration])
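
  # E.g. (hypothetical numbers): with iterations {1, 2, 3, 4, 5} on disk and
  # --iters_to_keep 3, keep_latest keeps iterations 3, 4 and 5 and deletes the
  # model files of iterations 1 and 2.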
  
  
  def keep_best(iteration_dict):
      iters_to_keep = args.iters_to_keep
      best = []
      for iteration, iter_info in iteration_dict.items():
          objf = iter_info.objf
          if objf == -2000:
              print(script_name + ": warning: objf unavailable for iter " + str(iteration), file=sys.stderr)
              continue
          # add potential best, sort by objf (descending), trim to iters_to_keep size
          best.append((iteration, objf))
          best = sorted(best, key=lambda x: -x[1])
          if len(best) > iters_to_keep:
              throwaway = best[iters_to_keep:]
              best = best[:iters_to_keep]
              # remove iterations that we now know are not among the best
              for (bad_iter, _) in throwaway:
                  remove_model_files_for_iter(iteration_dict[bad_iter])
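
  # E.g. (hypothetical objf values): with --iters_to_keep 2 and objf values
  # {1: -4.5, 2: -4.2, 3: -4.4}, keep_best retains iterations 2 and 3 (the two
  # highest objf scores) and removes the model files of iteration 1.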
  
  
  # grab all the iterations mapped to their model files, objf score and compute_prob status
  iterations = get_iteration_files(args.rnnlm_dir)
  # apply chosen cleanup strategy
  if args.keep_latest:
      keep_latest(iterations)
  else:
      keep_best(iterations)
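
  # Example invocations (paths illustrative, mirroring the --help epilog):
  #   rnnlm_cleanup.py exp/rnnlm_a --keep_latest
  #   rnnlm_cleanup.py exp/rnnlm_a --keep_best --iters_to_keep 5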