Blame view
egs/wsj/s5/steps/libs/nnet3/train/dropout_schedule.py
13.1 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 |
#! /usr/bin/env python

# Copyright 2016    Vimal Manohar
# Apache 2.0

"""This module contains methods related to scheduling dropout.

A dropout schedule is written as a string such as
'*=0.0,0.5,0.0 lstm.*=0.0,0.3@0.75,0.0' and is turned into a piecewise
linear function of the fraction of training data seen so far.
See _self_test() for examples of how the functions work.
"""

import logging

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Set to True to log the parsed schedules and per-iteration proportions.
_debug_dropout = False


def _parse_dropout_option(dropout_option):
    """Parse the string option to --trainer.dropout-schedule.

    The option is a whitespace-separated list of items, each either
    'pattern=schedule' or just 'schedule' (in which case the pattern
    defaults to '*').  _parse_dropout_string() is called on the schedule
    part of every item.

    Arguments:
        dropout_option: The string option passed to
            --trainer.dropout-schedule.  See its help for details.
            See _self_test() for examples.

    Returns a list of (component_name, dropout_schedule) tuples, where
    dropout_schedule is itself a list of (data_fraction,
    dropout_proportion) tuples sorted in reverse order of data_fraction.
    A data fraction of 0 corresponds to beginning of training and 1
    corresponds to all data.

    Raises Exception if the option is empty or an item contains more than
    one '='; errors from _parse_dropout_string() propagate unchanged.
    """
    # split() with no argument collapses runs of whitespace, so an
    # accidental double space or tab between items does not produce an
    # empty item that would then fail to parse.
    components = dropout_option.split()
    if not components:
        raise Exception("The dropout schedule must not be empty. "
                        "See help for option --trainer.dropout-schedule.")

    dropout_schedule = []
    for component in components:
        parts = component.split('=')

        if len(parts) == 2:
            component_name, this_dropout_str = parts
        elif len(parts) == 1:
            # No pattern given: the schedule applies to all components.
            component_name = '*'
            this_dropout_str = parts[0]
        else:
            raise Exception("The dropout schedule must be specified in the "
                            "format 'pattern1=func1 pattern2=func2' where "
                            "the pattern can be omitted for a global function "
                            "for all components. "
                            "Got {0} in {1}".format(component, dropout_option))

        this_dropout_values = _parse_dropout_string(this_dropout_str)
        dropout_schedule.append((component_name, this_dropout_values))

    if _debug_dropout:
        logger.info("Dropout schedules for component names is as follows:")
        logger.info("<component-name-pattern>: "
                    "[(data_fraction, dropout_proportion) ...]")
        for name, schedule in dropout_schedule:
            logger.info("{0}: {1}".format(name, schedule))

    return dropout_schedule


def _parse_dropout_string(dropout_str):
    """Parse the dropout schedule of a single component name pattern.

    This is a module-internal function called by _parse_dropout_option().

    The string is a comma-separated list of at least two values.  The
    first and last values are the dropout proportions at data fractions
    0 and 1.  Every value in between is either 'proportion@data_fraction'
    or just 'proportion', which is placed at data fraction 0.5.

    Arguments:
        dropout_str: Specifies dropout schedule for a particular component
            name pattern.
            See help for the option --trainer.dropout-schedule.

    Returns a list of (data_fraction, dropout_proportion) tuples sorted in
    descending order of data_fraction.
    A data fraction of 1 corresponds to all data.

    Raises Exception (ValueError for malformed or out-of-order values) if
    the string cannot be parsed, and ValueError if a data fraction or a
    dropout proportion lies outside [0, 1].
    """
    dropout_values = []
    parts = dropout_str.strip().split(',')

    try:
        if len(parts) < 2:
            raise Exception("dropout proportion string must specify "
                            "at least the start and end dropouts")

        # Starting dropout proportion
        dropout_values.append((0, float(parts[0])))
        for part in parts[1:-1]:
            value_x_pair = part.split('@')
            if len(value_x_pair) == 1:
                # Dropout proportion at half of training
                dropout_proportion = float(value_x_pair[0])
                data_fraction = 0.5
            elif len(value_x_pair) == 2:
                dropout_proportion = float(value_x_pair[0])
                data_fraction = float(value_x_pair[1])
            else:
                # Raised explicitly (not asserted) so that the check is
                # still performed when Python runs with -O.
                raise ValueError(
                    "Expected 'proportion' or 'proportion@data_fraction' "
                    "but got '{0}'".format(part))

            if (data_fraction < dropout_values[-1][0]
                    or data_fraction > 1.0):
                logger.error(
                    "Failed while parsing value %s in dropout-schedule. "
                    "dropout-schedule must be in increasing "
                    "order of data fractions.", value_x_pair)
                raise ValueError(
                    "Data fraction {0} is out of order or greater than 1 "
                    "in dropout schedule '{1}'".format(data_fraction,
                                                       dropout_str))

            dropout_values.append((data_fraction, dropout_proportion))

        # Final dropout proportion
        dropout_values.append((1.0, float(parts[-1])))
    except Exception:
        logger.error("Unable to parse dropout proportion string %s. "
                     "See help for option "
                     "--trainer.dropout-schedule.", dropout_str)
        raise

    # reverse sort so that its easy to retrieve the dropout proportion
    # for a particular data fraction
    dropout_values.reverse()

    # The chained comparisons are written so that NaN is rejected too.
    for data_fraction, proportion in dropout_values:
        if not 0.0 <= data_fraction <= 1.0:
            raise ValueError(
                "Data fraction {0} in dropout schedule '{1}' is outside "
                "the range [0, 1]".format(data_fraction, dropout_str))
        if not 0.0 <= proportion <= 1.0:
            raise ValueError(
                "Dropout proportion {0} in dropout schedule '{1}' is "
                "outside the range [0, 1]".format(proportion, dropout_str))

    return dropout_values


def _get_component_dropout(dropout_schedule, data_fraction):
    """Retrieve the dropout proportion once data_fraction of data is seen.

    The value is obtained by evaluating the piecewise linear function
    defined by the dropout schedule at data_fraction.
    This is a module-internal function called by
    _get_dropout_proportions().

    See help for --trainer.dropout-schedule for how the dropout value
    is obtained from the options.

    Arguments:
        dropout_schedule: A list of (data_fraction, dropout_proportion)
            values sorted in descending order of data_fraction.
        data_fraction: The fraction of data seen until this stage of
            training.

    Returns the interpolated dropout proportion.

    Raises RuntimeError if no schedule entry has a data fraction less
    than or equal to data_fraction.
    """
    if data_fraction == 0:
        # Dropout at start of the iteration is in the last index of
        # dropout_schedule
        assert dropout_schedule[-1][0] == 0
        return dropout_schedule[-1][1]

    try:
        # Find lower bound of the data_fraction. This is the
        # lower end of the piecewise linear function.  Because the
        # schedule is in descending order, the first entry that is not
        # above data_fraction is the closest one below it.
        (dropout_schedule_index, initial_data_fraction,
         initial_dropout) = next(
             (i, tup[0], tup[1])
             for i, tup in enumerate(dropout_schedule)
             if tup[0] <= data_fraction)
    except StopIteration:
        raise RuntimeError(
            "Could not find data_fraction in dropout schedule "
            "corresponding to data_fraction {0}. "
            "Maybe something wrong with the parsed "
            "dropout schedule {1}.".format(data_fraction, dropout_schedule))

    if dropout_schedule_index == 0:
        # Only the end of training can match the first (largest) entry.
        assert dropout_schedule[0][0] == 1 and data_fraction == 1
        return dropout_schedule[0][1]

    # The upper bound of data_fraction is at the index before the
    # lower bound.
    final_data_fraction, final_dropout = dropout_schedule[
        dropout_schedule_index - 1]

    if final_data_fraction == initial_data_fraction:
        # Defensive guard against a zero-width segment, which would
        # otherwise divide by zero below.
        assert data_fraction == initial_data_fraction
        return initial_dropout

    assert (data_fraction >= initial_data_fraction
            and data_fraction < final_data_fraction)

    # Linear interpolation between the two surrounding schedule points.
    return ((data_fraction - initial_data_fraction)
            * (final_dropout - initial_dropout)
            / (final_data_fraction - initial_data_fraction)
            + initial_dropout)


def _get_dropout_proportions(dropout_schedule, data_fraction):
    """Return dropout proportions for the fraction of data seen so far.

    Returns a list of pairs (pattern, dropout_proportion); for instance,
    it might return the list [('*', 0.625)] meaning a dropout proportion
    of 0.625 is to be applied to all dropout components.

    Returns None if dropout_schedule is None.

    dropout_schedule might be (in the sample case using the default
    pattern of '*'): '0.1,0.5@0.5,0.1', meaning a piecewise linear
    function that starts at 0.1 when data_fraction=0.0, rises to 0.5 when
    data_fraction=0.5, and falls again to 0.1 when data_fraction=1.0.

    It can also contain space-separated items of the form
    'pattern=schedule', for instance:
       '*=0.0,0.5,0.0 lstm.*=0.0,0.3@0.75,0.0'
    The more specific patterns should go later, otherwise they will be
    overridden by the less specific patterns' commands.

    Calls _get_component_dropout() for the different component name
    patterns in dropout_schedule.

    Arguments:
        dropout_schedule: Value for the --trainer.dropout-schedule option.
            See help for --trainer.dropout-schedule.
            See _self_test() for examples.
        data_fraction: The fraction of data seen until this stage of
            training.
    """
    if dropout_schedule is None:
        return None

    parsed_schedule = _parse_dropout_option(dropout_schedule)
    return [
        (component_name,
         _get_component_dropout(component_dropout_schedule, data_fraction))
        for component_name, component_dropout_schedule in parsed_schedule
    ]


def get_dropout_edit_string(dropout_schedule, data_fraction, iter_):
    """Return an nnet3-copy --edits command that sets dropout proportions.

    E.g. if _get_dropout_proportions(dropout_schedule, data_fraction)
    returns [('*', 0.625)], this will return the string:
    "nnet3-copy --edits='set-dropout-proportion name=* proportion=0.625' - - |"
    When there are several patterns, their set-dropout-proportion
    directives are joined with ';' in the order given in the schedule.

    Arguments:
        dropout_schedule: Value for the --trainer.dropout-schedule option.
            See help for --trainer.dropout-schedule.
            See _self_test() for examples.
        data_fraction: The fraction of data seen until this stage of
            training.
        iter_: The iteration number; it is only used in the debug log
            message.

    Returns the command string, or "" if dropout_schedule is None.

    See ReadEditConfig() in nnet3/nnet-utils.h to see how
    set-dropout-proportion directive works.
    """
    if dropout_schedule is None:
        return ""

    dropout_proportions = _get_dropout_proportions(
        dropout_schedule, data_fraction)

    edit_config_lines = []
    dropout_info = []

    for component_name, dropout_proportion in dropout_proportions:
        edit_config_lines.append(
            "set-dropout-proportion name={0} proportion={1}".format(
                component_name, dropout_proportion))
        dropout_info.append("pattern/dropout-proportion={0}/{1}".format(
            component_name, dropout_proportion))

    if _debug_dropout:
        logger.info("On iteration %d, %s", iter_, ', '.join(dropout_info))

    return "nnet3-copy --edits='{edits}' - - |".format(
        edits=";".join(edit_config_lines))


def _self_test():
    """Run self-test.

    This method is called if the module is run as a standalone script.
    """

    def assert_approx_equal(list1, list2):
        """Checks that the two dropout proportions lists are equal."""
        assert len(list1) == len(list2)
        for pair1, pair2 in zip(list1, list2):
            assert len(pair1) == 2
            assert len(pair2) == 2
            assert pair1[0] == pair2[0]
            assert abs(pair1[1] - pair2[1]) < 1e-8

    assert (_parse_dropout_option('*=0.0,0.5,0.0 lstm.*=0.0,0.3@0.75,0.0')
            == [('*', [(1.0, 0.0), (0.5, 0.5), (0.0, 0.0)]),
                ('lstm.*', [(1.0, 0.0), (0.75, 0.3), (0.0, 0.0)])])
    # Extra whitespace between the items must not matter.
    assert (_parse_dropout_option(' *=0.0,0.5,0.0   lstm.*=0.0,0.3@0.75,0.0 ')
            == [('*', [(1.0, 0.0), (0.5, 0.5), (0.0, 0.0)]),
                ('lstm.*', [(1.0, 0.0), (0.75, 0.3), (0.0, 0.0)])])
    assert_approx_equal(
        _get_dropout_proportions(
            '*=0.0,0.5,0.0 lstm.*=0.0,0.3@0.75,0.0', 0.75),
        [('*', 0.25), ('lstm.*', 0.3)])
    assert_approx_equal(
        _get_dropout_proportions(
            '*=0.0,0.5,0.0 lstm.*=0.0,0.3@0.75,0.0', 0.5),
        [('*', 0.5), ('lstm.*', 0.2)])
    assert_approx_equal(
        _get_dropout_proportions(
            '*=0.0,0.5,0.0 lstm.*=0.0,0.3@0.75,0.0', 0.25),
        [('*', 0.25), ('lstm.*', 0.1)])

    assert (_parse_dropout_option('0.0,0.3,0.0')
            == [('*', [(1.0, 0.0), (0.5, 0.3), (0.0, 0.0)])])
    assert_approx_equal(_get_dropout_proportions('0.0,0.3,0.0', 0.5),
                        [('*', 0.3)])
    assert_approx_equal(_get_dropout_proportions('0.0,0.3,0.0', 0.0),
                        [('*', 0.0)])
    assert_approx_equal(_get_dropout_proportions('0.0,0.3,0.0', 1.0),
                        [('*', 0.0)])
    assert_approx_equal(_get_dropout_proportions('0.0,0.3,0.0', 0.25),
                        [('*', 0.15)])

    assert (_parse_dropout_option('0.0,0.5@0.25,0.0,0.6@0.75,0.0')
            == [('*', [(1.0, 0.0), (0.75, 0.6), (0.5, 0.0), (0.25, 0.5),
                       (0.0, 0.0)])])
    assert_approx_equal(
        _get_dropout_proportions('0.0,0.5@0.25,0.0,0.6@0.75,0.0', 0.25),
        [('*', 0.5)])
    assert_approx_equal(
        _get_dropout_proportions('0.0,0.5@0.25,0.0,0.6@0.75,0.0', 0.1),
        [('*', 0.2)])

    assert (_parse_dropout_option('lstm.*=0.0,0.3,0.0@0.75,1.0')
            == [('lstm.*', [(1.0, 1.0), (0.75, 0.0), (0.5, 0.3),
                            (0.0, 0.0)])])
    assert_approx_equal(
        _get_dropout_proportions('lstm.*=0.0,0.3,0.0@0.75,1.0', 0.25),
        [('lstm.*', 0.15)])
    assert_approx_equal(
        _get_dropout_proportions('lstm.*=0.0,0.3,0.0@0.75,1.0', 0.5),
        [('lstm.*', 0.3)])
    assert_approx_equal(
        _get_dropout_proportions('lstm.*=0.0,0.3,0.0@0.75,1.0', 0.9),
        [('lstm.*', 0.6)])


if __name__ == '__main__':
    try:
        _self_test()
    except Exception:
        logger.error("Failed self test")
        raise