Blame view

egs/wsj/s5/steps/nnet3/xconfig_to_configs.py 14.4 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
  #!/usr/bin/env python
  
  # Copyright 2016    Johns Hopkins University (Dan Povey)
  #           2016    Vijayaditya Peddinti
  #           2017    Google Inc. (vpeddinti@google.com)
  # Apache 2.0.
  
  # we're using python 3.x style print but want it to work in python 2.x,
  from __future__ import print_function
  import argparse
  import os
  import sys
  from collections import defaultdict
  
  sys.path.insert(0, 'steps/')
  # the following is in case we weren't running this from the normal directory.
  sys.path.insert(0, os.path.realpath(os.path.dirname(sys.argv[0])) + '/')
  
  import libs.nnet3.xconfig.parser as xparser
  import libs.common as common_lib
  
  
  def get_args():
      """Parse and validate the command-line arguments.

      The full command line is echoed to stderr (for logging) before parsing,
      and the parsed namespace is passed through check_args().

      Returns:
          argparse.Namespace with attributes xconfig_file, existing_model,
          config_dir and nnet_edits.
      """
      # Compulsory arguments are declared as named (--option) arguments
      # rather than positional ones, for readability at the call site.
      parser = argparse.ArgumentParser(
          description="Reads an xconfig file and creates config files "
                      "for neural net creation and training",
          epilog='Search egs/*/*/local/{nnet3,chain}/*sh for examples')
      parser.add_argument('--xconfig-file', required=True,
                          help='Filename of input xconfig file')
      # NOTE: adjacent string literals are concatenated with no separator,
      # so each sentence-ending fragment needs its own trailing space.
      parser.add_argument('--existing-model',
                          help='Filename of previously trained neural net '
                               '(e.g. final.mdl) which is useful in case of '
                               'using nodes from list of component-nodes in '
                               'already trained model '
                               'to generate new config file for new model. '
                               'The context info is also generated using '
                               'a model generated by adding final.config '
                               'to the existing model. '
                               'e.g. In Transfer learning: generate new model using '
                               'component nodes in existing model.')
      parser.add_argument('--config-dir', required=True,
                          help='Directory to write config files and variables')
      parser.add_argument('--nnet-edits', type=str, default=None,
                          action=common_lib.NullstrToNoneAction,
                          help="""This option is useful in case the network you
                          are creating does not have an output node called
                          'output' (e.g. for multilingual setups).  You can set
                          this to an edit-string like: 'rename-node old-name=xxx
                          new-name=output' if node xxx plays the role of the
                          output node in this network.  This is only used for
                          computing the left/right context.""")

      print(' '.join(sys.argv), file=sys.stderr)

      args = parser.parse_args()
      args = check_args(args)

      return args
  
  
  def check_args(args):
      """Post-process parsed arguments: make sure --config-dir exists.

      The directory (including any missing parent directories) is created
      when it is absent.  The same namespace object is handed back.
      """
      config_dir = args.config_dir
      if os.path.exists(config_dir):
          return args
      os.makedirs(config_dir)
      return args
  
  
  def backup_xconfig_file(xconfig_file, config_dir):
      """Write a copy of the xconfig file to <config_dir>/xconfig.

      The copy is kept as a record of the original input.  It starts with a
      comment header giving the command line that produced it; each input
      line is then copied with leading/trailing whitespace stripped.

      Args:
          xconfig_file: path of the input xconfig file.
          config_dir: directory in which to write the 'xconfig' copy.

      Raises:
          Exception: if the input file cannot be opened for reading, or the
              output file cannot be opened for writing.
      """
      # Open the input first so that a missing/unreadable input does not
      # leave behind a truncated (empty) <config_dir>/xconfig.
      try:
          xconfig_file_in = open(xconfig_file)
      except (IOError, OSError):
          # Report the input filename, which is what failed to open.
          raise Exception('{0}: error opening file {1} for input'
                          ''.format(sys.argv[0], xconfig_file))
      with xconfig_file_in:
          try:
              xconfig_file_out = open(config_dir + '/xconfig', 'w')
          except (IOError, OSError):
              raise Exception('{0}: error opening file '
                              '{1}/xconfig for output'.format(
                                  sys.argv[0], config_dir))
          # 'with' closes the output even if writing fails part-way.
          with xconfig_file_out:
              print("# This file was created by the command:\n"
                    "# {0}\n"
                    "# It is a copy of the source from which the config "
                    "files in\n"
                    "# this directory were generated.\n"
                    "".format(' '.join(sys.argv)),
                    file=xconfig_file_out)
              for line in xconfig_file_in:
                  print(line.strip(), file=xconfig_file_out)
  
  
  def write_expanded_xconfig_files(config_dir, all_layers):
      """Write config_dir/xconfig.expanded.1 and config_dir/xconfig.expanded.2.

      These files show internal stages of processing the xconfig file before
      it is turned into config files:
        * xconfig.expanded.1: every layer as parsed, with default config
          values filled in;
        * xconfig.expanded.2: every layer after normalize_descriptors() has
          been called on it.  This presumably modifies the layer objects in
          place (their str() afterwards shows normalized descriptors), so
          it affects later users of all_layers -- TODO confirm.

      Args:
          config_dir: directory in which to write both files.
          all_layers: list of layer objects; each is written using str() and
              must provide normalize_descriptors().  It is iterated twice.

      Raises:
          Exception: if either output file cannot be opened for writing.
      """
      command_line = ' '.join(sys.argv)
      stages = [
          # (basename, call normalize_descriptors() first?, header text)
          ('xconfig.expanded.1', False,
           '# This file was created by the command:\n'
           '# ' + command_line + '\n'
           '# It contains the same content as ./xconfig but it was parsed and\n'
           '# default config values were set.\n'
           '# See also ./xconfig.expanded.2\n'),
          ('xconfig.expanded.2', True,
           '# This file was created by the command:\n'
           '# ' + command_line + '\n'
           '# It contains the same content as ./xconfig but it was parsed,\n'
           '# default config values were set,\n'
           '# and Descriptors (input=xxx) were normalized.\n'
           '# See also ./xconfig.expanded.1\n'),
      ]
      for basename, normalize, header in stages:
          filename = '{0}/{1}'.format(config_dir, basename)
          try:
              xconfig_file_out = open(filename, 'w')
          except (IOError, OSError):
              raise Exception('{0}: error opening file {1} for output'
                              ''.format(sys.argv[0], filename))
          # 'with' guarantees the file is closed even if str(layer) or
          # normalize_descriptors() raises.
          with xconfig_file_out:
              print(header, file=xconfig_file_out)
              for layer in all_layers:
                  if normalize:
                      layer.normalize_descriptors()
                  print('{}'.format(layer), file=xconfig_file_out)
  
  
  def get_config_headers():
      """Return a map from config-file basename to its header comment.

      Keys are config basenames such as 'init', 'ref' or 'final'.  Each value
      is a comment block (every line starts with '#' and ends in a newline)
      that goes at the top of <basename>.config.  The result is a
      defaultdict(str), so any basename not listed here maps to ''.
      """
      # Every header starts by recording the command line that created it.
      created_by = ('# This file was created by the command:\n'
                    '# ' + ' '.join(sys.argv) + '\n')

      headers = defaultdict(str)
      headers['init'] = created_by + (
          '# It contains the input of the network and is used in\n'
          '# accumulating stats for an LDA-like transform of the\n'
          '# input features.\n')
      headers['ref'] = created_by + (
          '# It contains the entire neural network, but with those\n'
          '# components that would normally require fixed vectors/matrices\n'
          '# read from disk, replaced with random initialization\n'
          '# (this applies to the LDA-like transform and the\n'
          '# presoftmax-prior-scale, if applicable).  This file\n'
          '# is used only to work out the left-context and right-context\n'
          '# of the network.\n')
      headers['final'] = created_by + (
          '# It contains the entire neural network.\n')

      return headers
  
  
  # This is where most of the work of this program happens.
  def write_config_files(config_dir, all_layers):
      """Write <config_dir>/<basename>.config for every config basename
      produced by the layers.

      Each layer's get_full_config() returns (config_basename, line) pairs;
      the lines are grouped by basename and written to
      <config_dir>/<basename>.config, preceded by the header returned by
      get_config_headers() for that basename.

      A config with no 'output-node' line is an error, except for 'init',
      which is simply not written in that case.  Every config is validated
      before any file is written, so an error does not leave a partial set
      of config files behind.

      Raises:
          Exception: re-raised from a layer's get_full_config() or from
              writing a config file; raised directly when a config other
              than 'init' has no output-node line.
      """
      # Map from config basename (e.g. 'ref', 'final', 'init') to the list
      # of lines to write to that config file.
      config_basename_to_lines = defaultdict(list)

      config_basename_to_header = get_config_headers()

      for layer in all_layers:
          try:
              pairs = layer.get_full_config()
              for config_basename, line in pairs:
                  config_basename_to_lines[config_basename].append(line)
          except Exception as e:
              print("{0}: error producing config lines from xconfig "
                    "line '{1}': error was: {2}".format(sys.argv[0],
                                                        str(layer), repr(e)),
                    file=sys.stderr)
              # we use raise rather than raise(e) as using a blank raise
              # preserves the backtrace
              raise

      # Remove any previous init.config: the 'init' config is skipped below
      # when it has no output-node, and a stale file from an earlier run
      # would otherwise still be seen by check_model_contexts().
      try:
          os.remove(config_dir + '/init.config')
      except OSError:
          pass

      # Validate all configs before writing any of them.
      basenames_to_write = []
      for basename, lines in config_basename_to_lines.items():
          if any(line.startswith('output-node') for line in lines):
              basenames_to_write.append(basename)
          elif basename != 'init':
              raise Exception(
                  '{0}: error in xconfig file: {1}.config has no output-node '
                  'line; the xconfig may be missing an output layer'
                  ''.format(sys.argv[0], basename))
          # else: an 'init' config without an output-node is not written.

      for basename in basenames_to_write:
          header = config_basename_to_header[basename]
          filename = '{0}/{1}.config'.format(config_dir, basename)
          try:
              with open(filename, 'w') as f:
                  print(header, file=f)
                  for line in config_basename_to_lines[basename]:
                      print(line, file=f)
          except Exception as e:
              print('{0}: error writing to config file {1}: error is {2}'
                    ''.format(sys.argv[0], filename, repr(e)), file=sys.stderr)
              # we use raise rather than raise(e) as using a blank raise
              # preserves the backtrace
              raise
  
  
  def add_nnet_context_info(config_dir, nnet_edits=None,
                            existing_model=None):
      """Create the 'vars' file that specifies model_left_context, etc.

      <config_dir>/ref.config is initialized into <config_dir>/ref.raw with
      nnet3-init (prefixed by existing_model, when given).  When nnet_edits
      is given, it is applied to that model via nnet3-copy.  The left/right
      context is read from the first lines of nnet3-info's output and
      written to <config_dir>/vars as model_left_context and
      model_right_context.

      Raises:
          Exception: if nnet3-info did not report both the left and the
              right context.
      """
      common_lib.execute_command("nnet3-init {0} {1}/ref.config "
                                 "{1}/ref.raw"
                                 "".format(existing_model if
                                           existing_model is not None else "",
                                           config_dir))
      model = "{0}/ref.raw".format(config_dir)
      if nnet_edits is not None:
          model = "nnet3-copy --edits='{0}' {1} - |".format(nnet_edits,
                                                            model)
      out = common_lib.get_command_stdout('nnet3-info "{0}"'.format(model))
      # out looks like this
      # left-context: 7
      # right-context: 0
      # num-parameters: 90543902
      # modulus: 1
      # ...
      context_keys = ('left-context', 'right-context')
      info = {}
      for line in out.split("\n")[:4]:  # take 4 initial lines
          parts = line.split(":")
          if len(parts) != 2:
              continue
          key = parts[0].strip()
          # Only convert the values that are needed, so an unexpected
          # non-integer field in these lines cannot raise ValueError.
          if key in context_keys:
              info[key] = int(parts[1].strip())

      missing = [key for key in context_keys if key not in info]
      if missing:
          raise Exception('{0}: could not find {1} in the output of '
                          'nnet3-info on {2}'.format(sys.argv[0],
                                                     ', '.join(missing),
                                                     model))

      # The 'vars' file looks like:
      #   model_left_context=7
      #   model_right_context=0
      with open('{0}/vars'.format(config_dir), 'w') as vf:
          vf.write('model_left_context={0}\n'.format(info['left-context']))
          vf.write('model_right_context={0}\n'.format(info['right-context']))
  
  def check_model_contexts(config_dir, nnet_edits=None, existing_model=None):
      """Check that init.config does not need more context than ref.config.

      Each of <config_dir>/init.config and <config_dir>/ref.config that
      exists is initialized into <name>.raw with nnet3-init (prefixed by
      existing_model, when given).  Its left/right context is then read from
      the first lines of nnet3-info's output.  nnet_edits, when given, is
      applied (via nnet3-copy) to the ref model only.  If init.config does
      not exist, nothing is checked.

      Raises:
          Exception: if init.config exists but ref.config does not, or if
              the init model needs more left or right context than the ref
              model.
      """
      context_keys = ('left-context', 'right-context')
      contexts = {}
      for file_name in ['init', 'ref']:
          if not os.path.exists('{0}/{1}.config'.format(config_dir, file_name)):
              continue
          contexts[file_name] = {}
          common_lib.execute_command("nnet3-init {0} {1}/{2}.config "
                                     "{1}/{2}.raw"
                                     "".format(existing_model if
                                               existing_model is not
                                               None else '',
                                               config_dir, file_name))
          model = "{0}/{1}.raw".format(config_dir, file_name)
          if nnet_edits is not None and file_name != 'init':
              model = "nnet3-copy --edits='{0}' {1} - |".format(nnet_edits,
                                                                model)
          out = common_lib.get_command_stdout('nnet3-info "{0}"'.format(model))
          # out looks like this
          # left-context: 7
          # right-context: 0
          # num-parameters: 90543902
          # modulus: 1
          # ...
          for line in out.split("\n")[:4]:  # take 4 initial lines
              parts = line.split(":")
              if len(parts) != 2:
                  continue
              key = parts[0].strip()
              if key in context_keys:
                  contexts[file_name][key] = int(parts[1].strip())

      if 'init' not in contexts:
          return
      if 'ref' not in contexts:
          # An explicit check rather than 'assert', which is stripped by -O.
          raise Exception('{0}: {1}/init.config exists but {1}/ref.config '
                          'does not'.format(sys.argv[0], config_dir))

      init_contexts = contexts['init']
      ref_contexts = contexts['ref']
      for key in context_keys:
          # Compare only the contexts nnet3-info reported for both models.
          if (key in init_contexts and key in ref_contexts
                  and init_contexts[key] > ref_contexts[key]):
              raise Exception(
                  "Model specified in {0}/init.config requires greater"
                  " context than the model specified in {0}/ref.config."
                  " This might be due to use of label-delay at the output"
                  " in ref.config. Please use delay=$label_delay in the"
                  " initial fixed-affine-layer of the network, to avoid"
                  " this issue.".format(config_dir))
  
  
  
  def main():
      """Turn an xconfig file into nnet3 config files plus a 'vars' file.

      Steps: parse the command line; back up the xconfig; parse it
      (optionally reusing component-nodes listed in --existing-model);
      write the expanded xconfigs and the config files; check init vs ref
      contexts; and write the model context into <config-dir>/vars.
      """
      args = get_args()
      config_dir = args.config_dir
      xconfig_file = args.xconfig_file
      existing_model = args.existing_model
      nnet_edits = args.nnet_edits

      backup_xconfig_file(xconfig_file, config_dir)

      if existing_model is None:
          existing_layers = []
      else:
          existing_layers = xparser.get_model_component_info(existing_model)
      all_layers = xparser.read_xconfig_file(xconfig_file, existing_layers)

      write_expanded_xconfig_files(config_dir, all_layers)
      write_config_files(config_dir, all_layers)
      check_model_contexts(config_dir, nnet_edits,
                           existing_model=existing_model)
      add_nnet_context_info(config_dir, nnet_edits,
                            existing_model=existing_model)


  if __name__ == '__main__':
      main()
  
  
  # test (the script takes named arguments --xconfig-file and --config-dir):
  # mkdir -p foo; (echo 'input dim=40 name=input'; echo 'output name=output input=Append(-1,0,1)')  >xconfig; ./xconfig_to_configs.py --xconfig-file xconfig --config-dir foo
  # mkdir -p foo; (echo 'input dim=40 name=input'; echo 'output-layer name=output dim=1924 input=Append(-1,0,1)')  >xconfig; ./xconfig_to_configs.py --xconfig-file xconfig --config-dir foo

  # mkdir -p foo; (echo 'input dim=40 name=input'; echo 'relu-renorm-layer name=affine1 dim=1024'; echo 'output-layer name=output dim=1924 input=Append(-1,0,1)')  >xconfig; ./xconfig_to_configs.py --xconfig-file xconfig --config-dir foo

  # mkdir -p foo; (echo 'input dim=100 name=ivector'; echo 'input dim=40 name=input'; echo 'fixed-affine-layer name=lda input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=foo/bar/lda.mat'; echo 'output-layer name=output dim=1924 input=Append(-1,0,1)')  >xconfig; ./xconfig_to_configs.py --xconfig-file xconfig --config-dir foo