Blame view

egs/wsj/s5/steps/libs/nnet3/train/dropout_schedule.py 13.1 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
  #! /usr/bin/env python
  
  # Copyright 2016    Vimal Manohar
  # Apache 2.0
  
  """This module contains methods related to scheduling dropout.
  See _self_test() for examples of how the functions work.
  """
  
  import logging
  
  logger = logging.getLogger(__name__)
  # Library-module convention: attach a NullHandler so importing this module
  # does not itself configure logging output; whoever runs the training
  # decides where records go.
  logger.addHandler(logging.NullHandler())


  # Set to True to log the parsed dropout schedules and the per-iteration
  # dropout proportions (see _parse_dropout_option() and
  # get_dropout_edit_string()).
  _debug_dropout = False
  
  def _parse_dropout_option(dropout_option):
      """Parses the string option to --trainer.dropout-schedule and
      returns a list of dropout schedules for different component name patterns.
      Calls _parse_dropout_string() function for each component name pattern
      in the option.

      Arguments:
          dropout_option: The string option passed to --trainer.dropout-schedule.
              It is a whitespace-separated list of items, each of the form
              'pattern=schedule', or just 'schedule', which applies to the
              default pattern '*'.  See its help for details.
              See _self_test() for examples.

      Returns a list of (component_name, dropout_schedule) tuples, in the
      order the items appear in the option,
      where dropout_schedule is itself a list of
      (data_fraction, dropout_proportion) tuples sorted in reverse order of
      data_fraction.
      A data fraction of 0 corresponds to beginning of training
      and 1 corresponds to all data.

      Raises an Exception if the option contains no items or if an item
      contains more than one '='; errors raised by _parse_dropout_string()
      propagate unchanged.
      """
      # split() with no argument tolerates repeated spaces, tabs and newlines
      # between items.  split(' ') would produce empty items that
      # _parse_dropout_string() cannot parse.
      components = dropout_option.split()
      if not components:
          raise Exception("The dropout schedule must be a non-empty string; "
                          "got {0!r}".format(dropout_option))

      dropout_schedule = []
      for component in components:
          parts = component.split('=')

          if len(parts) == 2:
              component_name, this_dropout_str = parts
          elif len(parts) == 1:
              # No pattern given: the schedule applies to all components.
              component_name = '*'
              this_dropout_str = parts[0]
          else:
              raise Exception("The dropout schedule must be specified in the "
                              "format 'pattern1=func1 pattern2=func2' where "
                              "the pattern can be omitted for a global function "
                              "for all components.\n"
                              "Got {0} in {1}".format(component, dropout_option))

          this_dropout_values = _parse_dropout_string(this_dropout_str)
          dropout_schedule.append((component_name, this_dropout_values))

      if _debug_dropout:
          logger.info("Dropout schedules for component names is as follows:")
          logger.info("<component-name-pattern>: [(data_fraction), "
                      "(dropout_proportion) ...]")
          for name, schedule in dropout_schedule:
              logger.info("{0}: {1}".format(name, schedule))

      return dropout_schedule
  
  
  def _parse_dropout_string(dropout_str):
      """Parses the dropout schedule from the string corresponding to a
      single component in --trainer.dropout-schedule.
      This is a module-internal function called by parse_dropout_function().
  
      Arguments:
          dropout_str: Specifies dropout schedule for a particular component
              name pattern.
              See help for the option --trainer.dropout-schedule.
  
      Returns a list of (data_fraction_processed, dropout_proportion) tuples
      sorted in descending order of num_archives_processed.
      A data fraction of 1 corresponds to all data.
      """
      dropout_values = []
      parts = dropout_str.strip().split(',')
  
      try:
          if len(parts) < 2:
              raise Exception("dropout proportion string must specify "
                              "at least the start and end dropouts")
  
          # Starting dropout proportion
          dropout_values.append((0, float(parts[0])))
          for i in range(1, len(parts) - 1):
              value_x_pair = parts[i].split('@')
              if len(value_x_pair) == 1:
                  # Dropout proportion at half of training
                  dropout_proportion = float(value_x_pair[0])
                  data_fraction = 0.5
              else:
                  assert len(value_x_pair) == 2
  
                  dropout_proportion = float(value_x_pair[0])
                  data_fraction = float(value_x_pair[1])
  
              if (data_fraction < dropout_values[-1][0]
                      or data_fraction > 1.0):
                  logger.error(
                      "Failed while parsing value %s in dropout-schedule. "
                      "dropout-schedule must be in incresing "
                      "order of data fractions.", value_x_pair)
                  raise ValueError
  
              dropout_values.append((data_fraction, float(dropout_proportion)))
  
          dropout_values.append((1.0, float(parts[-1])))
      except Exception:
          logger.error("Unable to parse dropout proportion string %s. "
                       "See help for option "
                       "--trainer.dropout-schedule.", dropout_str)
          raise
  
      # reverse sort so that its easy to retrieve the dropout proportion
      # for a particular data fraction
      dropout_values.reverse()
      for data_fraction, proportion in dropout_values:
          assert data_fraction <= 1.0 and data_fraction >= 0.0
          assert proportion <= 1.0 and proportion >= 0.0
  
      return dropout_values
  
  
  def _get_component_dropout(dropout_schedule, data_fraction):
      """Retrieve dropout proportion from schedule when data_fraction
      proportion of data is seen. This value is obtained by using a
      piecewise linear function on the dropout schedule.
      This is a module-internal function called by _get_dropout_proportions().
  
      See help for --trainer.dropout-schedule for how the dropout value
      is obtained from the options.
  
      Arguments:
          dropout_schedule: A list of (data_fraction, dropout_proportion) values
              sorted in descending order of data_fraction.
          data_fraction: The fraction of data seen until this stage of
              training.
      """
      if data_fraction == 0:
          # Dropout at start of the iteration is in the last index of
          # dropout_schedule
          assert dropout_schedule[-1][0] == 0
          return dropout_schedule[-1][1]
      try:
          # Find lower bound of the data_fraction. This is the
          # lower end of the piecewise linear function.
          (dropout_schedule_index, initial_data_fraction,
           initial_dropout) = next((i, tup[0], tup[1])
                                   for i, tup in enumerate(dropout_schedule)
                                   if tup[0] <= data_fraction)
      except StopIteration:
          raise RuntimeError(
              "Could not find data_fraction in dropout schedule "
              "corresponding to data_fraction {0}.
  "
              "Maybe something wrong with the parsed "
              "dropout schedule {1}.".format(data_fraction, dropout_schedule))
  
      if dropout_schedule_index == 0:
          assert dropout_schedule[0][0] == 1 and data_fraction == 1
          return dropout_schedule[0][1]
  
      # The upper bound of data_fraction is at the index before the
      # lower bound.
      final_data_fraction, final_dropout = dropout_schedule[
          dropout_schedule_index - 1]
  
      if final_data_fraction == initial_data_fraction:
          assert data_fraction == initial_data_fraction
          return initial_dropout
  
      assert (data_fraction >= initial_data_fraction
              and data_fraction < final_data_fraction)
  
      return ((data_fraction - initial_data_fraction)
              * (final_dropout - initial_dropout)
              / (final_data_fraction - initial_data_fraction)
              + initial_dropout)
  
  
  def _get_dropout_proportions(dropout_schedule, data_fraction):
      """Returns dropout proportions based on the dropout_schedule for the
      fraction of data seen at this stage of training.  Returns a list of
      pairs (pattern, dropout_proportion); for instance, it might return
      the list ['*', 0.625] meaning a dropout proportion of 0.625 is to
      be applied to all dropout components.
  
      Returns None if dropout_schedule is None.
  
      dropout_schedule might be (in the sample case using the default pattern of
      '*'): '0.1,0.5@0.5,0.1', meaning a piecewise linear function that starts at
      0.1 when data_fraction=0.0, rises to 0.5 when data_fraction=0.5, and falls
      again to 0.1 when data_fraction=1.0.   It can also contain space-separated
      items of the form 'pattern=schedule', for instance:
         '*=0.0,0.5,0.0 lstm.*=0.0,0.3@0.75,0.0'
      The more specific patterns should go later, otherwise they will be overridden
      by the less specific patterns' commands.
  
      Calls _get_component_dropout() for the different component name patterns
      in dropout_schedule.
  
      Arguments:
          dropout_schedule: Value for the --trainer.dropout-schedule option.
              See help for --trainer.dropout-schedule.
              See _self_test() for examples.
          data_fraction: The fraction of data seen until this stage of
              training.
  
      """
      if dropout_schedule is None:
          return None
      dropout_schedule = _parse_dropout_option(dropout_schedule)
      dropout_proportions = []
      for component_name, component_dropout_schedule in dropout_schedule:
          dropout_proportions.append(
              (component_name, _get_component_dropout(
                  component_dropout_schedule, data_fraction)))
      return dropout_proportions
  
  
  def get_dropout_edit_string(dropout_schedule, data_fraction, iter_):
      """Build the nnet3-copy pipe command that sets dropout proportions.

      The proportions are the ones _get_dropout_proportions() computes for
      dropout_schedule at data_fraction.  E.g. if that returns
      [('*', 0.625)], this returns the string:
       "nnet3-copy --edits='set-dropout-proportion name=* proportion=0.625' - - |"
      With several patterns, the edit directives are joined with ';'.
      Returns the empty string if dropout_schedule is None.

      Arguments:
          dropout_schedule: Value for the --trainer.dropout-schedule option,
              or None.  See help for --trainer.dropout-schedule.
              See _self_test() for examples.
          data_fraction: The fraction of data seen until this stage of
              training.
          iter_: The current training iteration; only used in the debug
              log message.

      See ReadEditConfig() in nnet3/nnet-utils.h to see how
      set-dropout-proportion directive works.
      """
      if dropout_schedule is None:
          return ""

      proportions = _get_dropout_proportions(dropout_schedule, data_fraction)

      edits = ";".join(
          "set-dropout-proportion name={0} proportion={1}".format(pattern,
                                                                  value)
          for pattern, value in proportions)

      if _debug_dropout:
          summary = ', '.join(
              "pattern/dropout-proportion={0}/{1}".format(pattern, value)
              for pattern, value in proportions)
          logger.info("On iteration %d, %s", iter_, summary)

      return "nnet3-copy --edits='{edits}' - - |".format(edits=edits)
  
  
  def _self_test():
      """Run self-test.

      This method is called if the module is run as a standalone script.
      Each case checks the parsed form of a --trainer.dropout-schedule
      string, then the dropout proportions it yields at several data
      fractions.  An AssertionError is raised on the first mismatch.
      """

      def assert_approx_equal(actual, expected):
          """Checks that the two dropout proportions lists are equal."""
          assert len(actual) == len(expected)
          for got, want in zip(actual, expected):
              assert len(got) == 2
              assert len(want) == 2
              assert got[0] == want[0]
              assert abs(got[1] - want[1]) < 1e-8

      # Each case: (option string,
      #             expected _parse_dropout_option() result,
      #             [(data_fraction, expected proportions), ...]).
      cases = [
          ('*=0.0,0.5,0.0 lstm.*=0.0,0.3@0.75,0.0',
           [('*', [(1.0, 0.0), (0.5, 0.5), (0.0, 0.0)]),
            ('lstm.*', [(1.0, 0.0), (0.75, 0.3), (0.0, 0.0)])],
           [(0.75, [('*', 0.25), ('lstm.*', 0.3)]),
            (0.5, [('*', 0.5), ('lstm.*', 0.2)]),
            (0.25, [('*', 0.25), ('lstm.*', 0.1)])]),
          ('0.0,0.3,0.0',
           [('*', [(1.0, 0.0), (0.5, 0.3), (0.0, 0.0)])],
           [(0.5, [('*', 0.3)]),
            (0.0, [('*', 0.0)]),
            (1.0, [('*', 0.0)]),
            (0.25, [('*', 0.15)])]),
          ('0.0,0.5@0.25,0.0,0.6@0.75,0.0',
           [('*', [(1.0, 0.0), (0.75, 0.6), (0.5, 0.0), (0.25, 0.5),
                   (0.0, 0.0)])],
           [(0.25, [('*', 0.5)]),
            (0.1, [('*', 0.2)])]),
          ('lstm.*=0.0,0.3,0.0@0.75,1.0',
           [('lstm.*', [(1.0, 1.0), (0.75, 0.0), (0.5, 0.3), (0.0, 0.0)])],
           [(0.25, [('lstm.*', 0.15)]),
            (0.5, [('lstm.*', 0.3)]),
            (0.9, [('lstm.*', 0.6)])]),
      ]

      for option, expected_schedule, expectations in cases:
          assert _parse_dropout_option(option) == expected_schedule
          for data_fraction, expected_proportions in expectations:
              assert_approx_equal(
                  _get_dropout_proportions(option, data_fraction),
                  expected_proportions)
  
  
  if __name__ == '__main__':
      # The module logger only carries a NullHandler.  That counts as a
      # handler, so logging's last-resort output is skipped and records would
      # be silently dropped.  Configure the root logger so the failure message
      # below is actually printed when this file is run as a script.
      logging.basicConfig()
      try:
          _self_test()
      except Exception:
          logger.error("Failed self test")
          raise