Blame view
egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py
6.31 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
# Copyright 2016 Johns Hopkins University (Author: Daniel Povey) # 2016 Vimal Manohar # Apache 2.0. """ This module contains the statistics extraction and pooling layer. """ from __future__ import print_function import re from libs.nnet3.xconfig.basic_layers import XconfigLayerBase class XconfigStatsLayer(XconfigLayerBase): """This class is for parsing lines like stats-layer name=tdnn1-stats config=mean+stddev(-99:3:9:99) input=tdnn1 This adds statistics-pooling and statistics-extraction components. An example string is 'mean(-99:3:9::99)', which means, compute the mean of data within a window of -99 to +99, with distinct means computed every 9 frames (we round to get the appropriate one), and with the input extracted on multiples of 3 frames (so this will force the input to this layer to be evaluated every 3 frames). Another example string is 'mean+stddev(-99:3:9:99)', which will also cause the standard deviation to be computed. The dimension is worked out from the input. mean and stddev add a dimension of input_dim each to the output dimension. If counts is specified, an additional dimension is added to the output to store log counts. Parameters of the class, and their defaults: input='[-1]' [Descriptor giving the input of the layer.] dim=-1 [Output dimension of layer. If provided, must match the dimension computed from input] config='' [Required. Defines what stats must be computed.] """ def __init__(self, first_token, key_to_value, prev_names=None): assert first_token in ['stats-layer'] XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names) def set_default_configs(self): self.config = {'input': '[-1]', 'dim': -1, 'config': ''} def set_derived_configs(self): config_string = self.config['config'] if config_string == '': raise RuntimeError("config has to be non-empty", self.str()) m = re.search("(mean|mean\+stddev|mean\+count|mean\+stddev\+count)" "\((-?\d+):(-?\d+):(-?\d+):(-?\d+)\)", config_string) if m is None: raise RuntimeError("Invalid statistic-config string: {0}".format( config_string), self) self._output_stddev = (m.group(1) in ['mean+stddev', 'mean+stddev+count']) self._output_log_counts = (m.group(1) in ['mean+count', 'mean+stddev+count']) self._left_context = -int(m.group(2)) self._input_period = int(m.group(3)) self._stats_period = int(m.group(4)) self._right_context = int(m.group(5)) if self._output_stddev: output_dim = 2 * self.descriptors['input']['dim'] else: output_dim = self.descriptors['input']['dim'] if self._output_log_counts: output_dim = output_dim + 1 if self.config['dim'] > 0 and self.config['dim'] != output_dim: raise RuntimeError( "Invalid dim supplied {0:d} != " "actual output dim {1:d}".format( self.config['dim'], output_dim)) self.config['dim'] = output_dim def check_configs(self): if not (self._left_context >= 0 and self._right_context >= 0 and self._input_period > 0 and self._stats_period > 0 and self._left_context % self._stats_period == 0 and self._right_context % self._stats_period == 0 and self._stats_period % self._input_period == 0): raise RuntimeError( "Invalid configuration of statistics-extraction: {0}".format( self.config['config']), self) super(XconfigStatsLayer, self).check_configs() def _generate_config(self): input_desc = self.descriptors['input']['final-string'] input_dim = self.descriptors['input']['dim'] configs = [] configs.append( 'component name={name}-extraction-{lc}-{rc} ' 'type=StatisticsExtractionComponent input-dim={dim} ' 'input-period={input_period} output-period={output_period} ' 'include-variance={var} '.format( name=self.name, lc=self._left_context, rc=self._right_context, dim=input_dim, input_period=self._input_period, output_period=self._stats_period, var='true' if self._output_stddev else 'false')) configs.append( 'component-node name={name}-extraction-{lc}-{rc} ' 'component={name}-extraction-{lc}-{rc} input={input} '.format( name=self.name, lc=self._left_context, rc=self._right_context, input=input_desc)) stats_dim = 1 + input_dim * (2 if self._output_stddev else 1) configs.append( 'component name={name}-pooling-{lc}-{rc} ' 'type=StatisticsPoolingComponent input-dim={dim} ' 'input-period={input_period} left-context={lc} right-context={rc} ' 'num-log-count-features={count} output-stddevs={var} '.format( name=self.name, lc=self._left_context, rc=self._right_context, dim=stats_dim, input_period=self._stats_period, count=1 if self._output_log_counts else 0, var='true' if self._output_stddev else 'false')) configs.append( 'component-node name={name}-pooling-{lc}-{rc} ' 'component={name}-pooling-{lc}-{rc} ' 'input={name}-extraction-{lc}-{rc} '.format( name=self.name, lc=self._left_context, rc=self._right_context)) return configs def output_name(self, auxiliary_output=None): return 'Round({name}-pooling-{lc}-{rc}, {period})'.format( name=self.name, lc=self._left_context, rc=self._right_context, period=self._stats_period) def output_dim(self, auxiliary_outputs=None): return self.config['dim'] def get_full_config(self): ans = [] config_lines = self._generate_config() for line in config_lines: for config_name in ['ref', 'final']: ans.append((config_name, line)) return ans |