Blame view

egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py 6.31 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
  # Copyright 2016    Johns Hopkins University (Author: Daniel Povey)
  #           2016    Vimal Manohar
  # Apache 2.0.
  
  """ This module contains the statistics extraction and pooling layer.
  """
  
  from __future__ import print_function
  import re
  from libs.nnet3.xconfig.basic_layers import XconfigLayerBase
  
  
  class XconfigStatsLayer(XconfigLayerBase):
      """This class is for parsing lines like
      stats-layer name=tdnn1-stats config=mean+stddev(-99:3:9:99) input=tdnn1

      This adds statistics-pooling and statistics-extraction components.  An
      example string is 'mean(-99:3:9:99)', which means, compute the mean of
      data within a window of -99 to +99, with distinct means computed every 9
      frames (we round to get the appropriate one), and with the input extracted
      on multiples of 3 frames (so this will force the input to this layer to be
      evaluated every 3 frames).  Another example string is
      'mean+stddev(-99:3:9:99)', which will also cause the standard deviation to
      be computed.

      The config value must be exactly of the form
          STATS(LEFT:INPUT_PERIOD:STATS_PERIOD:RIGHT)
      where STATS is one of 'mean', 'mean+stddev', 'mean+count' or
      'mean+stddev+count', and the four fields are integers with LEFT <= 0,
      RIGHT >= 0, INPUT_PERIOD > 0 and STATS_PERIOD > 0; LEFT and RIGHT must be
      multiples of STATS_PERIOD, and STATS_PERIOD a multiple of INPUT_PERIOD.

      The dimension is worked out from the input. mean and stddev add a
      dimension of input_dim each to the output dimension. If count is
      specified, an additional dimension is added to the output to store log
      counts.

      Parameters of the class, and their defaults:
          input='[-1]'    [Descriptor giving the input of the layer.]
          dim=-1      [Output dimension of layer. If provided, must match the
                       dimension computed from input]
          config=''   [Required. Defines what stats must be computed.]
      """

      # Grammar of the 'config' value.  Group 1 is the set of statistics;
      # groups 2-5 are LEFT, INPUT_PERIOD, STATS_PERIOD and RIGHT.  Raw strings
      # avoid invalid-escape warnings, and the trailing '$' (used with
      # match()) rejects extra text that an unanchored search silently ignored.
      _CONFIG_PATTERN = re.compile(
          r"(mean|mean\+stddev|mean\+count|mean\+stddev\+count)"
          r"\((-?\d+):(-?\d+):(-?\d+):(-?\d+)\)$")

      def __init__(self, first_token, key_to_value, prev_names=None):
          assert first_token in ['stats-layer']
          XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names)

      def set_default_configs(self):
          """Initialise self.config with the defaults listed in the class
          docstring."""
          self.config = {'input': '[-1]',
                         'dim': -1,
                         'config': ''}

      def set_derived_configs(self):
          """Parse self.config['config'] and derive the output dimension.

          Sets self._output_stddev, self._output_log_counts,
          self._left_context (the negated LEFT field), self._input_period,
          self._stats_period and self._right_context, then stores the computed
          output dimension in self.config['dim'].

          Raises:
              RuntimeError: if the config string is empty or malformed, or if
                  a positive 'dim' was supplied that differs from the computed
                  output dimension.
          """
          config_string = self.config['config']
          if config_string == '':
              raise RuntimeError("config has to be non-empty: {0}".format(
                  self.str()))
          m = self._CONFIG_PATTERN.match(config_string)
          if m is None:
              raise RuntimeError(
                  "Invalid statistic-config string: {0} in layer: {1}".format(
                      config_string, self.str()))

          stats = m.group(1)
          self._output_stddev = stats in ('mean+stddev', 'mean+stddev+count')
          self._output_log_counts = stats in ('mean+count', 'mean+stddev+count')
          # LEFT is written as a negative offset (e.g. -99); it is stored negated
          # so that check_configs can require a non-negative left context.
          self._left_context = -int(m.group(2))
          self._input_period = int(m.group(3))
          self._stats_period = int(m.group(4))
          self._right_context = int(m.group(5))

          input_dim = self.descriptors['input']['dim']
          # The mean contributes input_dim; stddev adds another input_dim and
          # the optional log-count adds one more dimension.
          output_dim = input_dim * (2 if self._output_stddev else 1)
          if self._output_log_counts:
              output_dim += 1

          if self.config['dim'] > 0 and self.config['dim'] != output_dim:
              raise RuntimeError(
                  "Invalid dim supplied {0:d} != "
                  "actual output dim {1:d}".format(
                      self.config['dim'], output_dim))
          self.config['dim'] = output_dim

      def check_configs(self):
          """Validate the window and periods parsed by set_derived_configs.

          Raises:
              RuntimeError: unless both contexts are non-negative multiples of
                  the stats period, both periods are positive, and the stats
                  period is a multiple of the input period.
          """
          # Order matters: the period > 0 tests must short-circuit before the
          # modulo tests so a zero period never causes ZeroDivisionError.
          valid = (self._left_context >= 0 and self._right_context >= 0
                   and self._input_period > 0 and self._stats_period > 0
                   and self._left_context % self._stats_period == 0
                   and self._right_context % self._stats_period == 0
                   and self._stats_period % self._input_period == 0)
          if not valid:
              raise RuntimeError(
                  "Invalid configuration of statistics-extraction: {0} "
                  "in layer: {1}".format(self.config['config'], self.str()))
          super(XconfigStatsLayer, self).check_configs()

      def _generate_config(self):
          """Return the config lines for this layer's components and nodes.

          Four lines are produced, in order: a StatisticsExtractionComponent,
          its component-node (fed by the input descriptor), a
          StatisticsPoolingComponent, and its component-node (fed by the
          extraction node).  Node and component names are suffixed with the
          left and right context.
          """
          input_desc = self.descriptors['input']['final-string']
          input_dim = self.descriptors['input']['dim']
          extraction = '{0}-extraction-{1}-{2}'.format(
              self.name, self._left_context, self._right_context)
          pooling = '{0}-pooling-{1}-{2}'.format(
              self.name, self._left_context, self._right_context)
          variance = 'true' if self._output_stddev else 'false'
          # Dimension of the extraction output / pooling input: one extra
          # element (presumably the count -- TODO confirm against the C++
          # StatisticsExtractionComponent) plus the mean, plus the variance
          # when stddev output is requested.
          stats_dim = 1 + input_dim * (2 if self._output_stddev else 1)

          return [
              'component name={0} '
              'type=StatisticsExtractionComponent input-dim={1} '
              'input-period={2} output-period={3} '
              'include-variance={4} '.format(
                  extraction, input_dim, self._input_period,
                  self._stats_period, variance),
              'component-node name={0} component={0} input={1} '.format(
                  extraction, input_desc),
              'component name={0} '
              'type=StatisticsPoolingComponent input-dim={1} '
              'input-period={2} left-context={3} right-context={4} '
              'num-log-count-features={5} output-stddevs={6} '.format(
                  pooling, stats_dim, self._stats_period,
                  self._left_context, self._right_context,
                  1 if self._output_log_counts else 0, variance),
              'component-node name={0} component={0} input={1} '.format(
                  pooling, extraction),
          ]

      def output_name(self, auxiliary_output=None):
          """Return the descriptor string for this layer's output.

          The pooling node is wrapped in Round(..., stats_period), matching the
          period at which statistics are computed.  auxiliary_output is
          ignored.
          """
          return 'Round({0}-pooling-{1}-{2}, {3})'.format(
              self.name, self._left_context, self._right_context,
              self._stats_period)

      def output_dim(self, auxiliary_outputs=None):
          """Return the output dimension computed in set_derived_configs;
          auxiliary_outputs is ignored."""
          return self.config['dim']

      def get_full_config(self):
          """Return (config_name, line) pairs, emitting every generated line
          for both the 'ref' and the 'final' config, line by line."""
          return [(config_name, line)
                  for line in self._generate_config()
                  for config_name in ['ref', 'final']]