Blame view
egs/wsj/s5/steps/nnet3/get_saturation.pl
5.9 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
#!/usr/bin/env perl

# This program parses the output of nnet3-am-info or nnet3-info,
# and prints out a number between zero and one that reflects
# how saturated the (sigmoid and tanh) nonlinearities are, on average
# over the model.
#
# This is based on the 'deriv-avg' (average-derivative) values printed
# out for the sigmoid and tanh components.  The 'saturation' of such a
# nonlinearity is defined as
#   (1.0 - its deriv-avg mean / the maximum possible derivative of that nonlinearity),
# where the denominator is 1.0 for tanh and 0.25 for sigmoid.
# This program averages the saturation over all the sigmoid/tanh
# nonlinearities it finds; each one is counted once, regardless of its
# dimension.
#
# It parses the Info() output of components of type SigmoidComponent,
# TanhComponent and LstmNonlinearityComponent, plus lines on which
# 'GruNonlinearityComponent' follows 'type='.
#
# Output and exit status:
#   - If no such nonlinearity is found in the input stream, it prints "0.0"
#     and exits with status 0.
#   - If the averaged saturation is outside [0, 1], it prints an error
#     message to stderr and exits with status 1.
#   - A line of a recognized type whose deriv-avg mean cannot be parsed is
#     reported on stderr; whatever could be parsed from it is still counted.

# Usage:  nnet3-am-info 10.mdl | steps/nnet3/get_saturation.pl
#   or:   nnet3-info 10.raw | steps/nnet3/get_saturation.pl

use strict;
use warnings;
use Scalar::Util qw(looks_like_number);

# Maximum possible derivative of each kind of nonlinearity.
use constant SIGMOID_MAX_DERIV => 0.25;
use constant TANH_MAX_DERIV    => 1.0;

# Applies $pattern (which must capture the deriv-avg mean as $1) to $line and
# converts the captured mean into a saturation value.  Returns undef (in
# scalar context) if the pattern does not match or the capture is not a number.
sub _match_saturation {
  my ($line, $pattern, $max_deriv) = @_;
  return unless $line =~ $pattern;
  my $mean_deriv = $1;
  # Without this check a garbled value would be used as 0, i.e. it would be
  # counted as a fully saturated nonlinearity.
  return unless looks_like_number($mean_deriv);
  return 1.0 - ($mean_deriv / $max_deriv);
}

# Handles a line that carries a single deriv-avg entry.
# Returns (\@saturations, $error), see saturations_in_line().
sub _single_saturation {
  my ($line, $max_deriv) = @_;
  my $saturation =
    _match_saturation($line, qr/deriv-avg=[^m]+mean=([^,]+),/, $max_deriv);
  return ([$saturation], undef) if defined $saturation;
  return ([], "$0: could not make sense of line (no deriv-avg?): $line");
}

# Extracts the saturation values from one line of nnet3 info output.
#
# Returns a two-element list (\@saturations, $error):
#   \@saturations  the saturation of each nonlinearity found on the line, in
#                  the order they are accumulated (empty for lines of other
#                  types);
#   $error         undef on success, otherwise the complete message to print
#                  to stderr (it ends with the offending line itself).
sub saturations_in_line {
  my ($line) = @_;

  # a line like:
  # component name=Lstm1_f type=SigmoidComponent, dim=1280, count=5.02e+05,
  # value-avg=[percentiles(0,1,2,5 10,20,50,80,90
  # 95,98,99,100)=(0.06,0.17,0.19,0.24 0.28,0.33,0.44,0.62,0.79
  # 0.96,0.99,1.0,1.0), mean=0.482, stddev=0.198],
  # deriv-avg=[percentiles(0,1,2,5 10,20,50,80,90
  # 95,98,99,100)=(0.0001,0.003,0.004,0.03 0.12,0.18,0.22,0.24,0.25
  # 0.25,0.25,0.25,0.25), mean=0.198, stddev=0.0591]
  if ($line =~ m/type=SigmoidComponent/) {
    return _single_saturation($line, SIGMOID_MAX_DERIV);
  }

  if ($line =~ m/type=TanhComponent/) {
    return _single_saturation($line, TANH_MAX_DERIV);
  }

  if ($line =~ m/type=LstmNonlinearityComponent/) {
    # An example of a line like this is right at the bottom of this program,
    # it's extremely long.  Each gate has its own {...} group of statistics.
    my @gates = (
      [ i_t_sigmoid => SIGMOID_MAX_DERIV ],
      [ f_t_sigmoid => SIGMOID_MAX_DERIV ],
      [ o_t_sigmoid => SIGMOID_MAX_DERIV ],
      [ c_t_tanh    => TANH_MAX_DERIV ],
      [ m_t_tanh    => TANH_MAX_DERIV ],
    );
    my @saturations;
    my $all_parsed = 1;
    foreach my $gate (@gates) {
      my ($gate_name, $max_deriv) = @$gate;
      my $saturation = _match_saturation(
        $line,
        qr/${gate_name}=[{][^}]+deriv-avg=[^}]+mean=([^,]+),/,
        $max_deriv);
      if (defined $saturation) {
        push @saturations, $saturation;
      } else {
        # Keep going: the gates that do parse are still counted.
        $all_parsed = 0;
      }
    }
    my $error = $all_parsed ? undef :
      "Could not parse at least one of the avg-deriv values in the following info line: $line";
    return (\@saturations, $error);
  }

  if ($line =~ m/type=.*GruNonlinearityComponent/) {
    # NOTE(review): only the first deriv-avg on the line is used, and it is
    # scaled as a tanh.  Confirm whether the GRU components print per-gate
    # statistics that should be parsed the way the LSTM ones are.
    return _single_saturation($line, TANH_MAX_DERIV);
  }

  return ([], undef);  # not a line we are interested in.
}

# Reads info output from the filehandle $in, prints the average saturation
# to stdout and returns the exit status for the program (0 or 1).
sub main {
  my ($in) = @_;
  my $num_nonlinearities = 0;
  my $total_saturation = 0.0;

  while (my $line = <$in>) {
    my ($saturations, $error) = saturations_in_line($line);
    print STDERR $error if defined $error;
    foreach my $this_saturation (@$saturations) {
      $num_nonlinearities += 1;
      $total_saturation += $this_saturation;
    }
  }

  if ($num_nonlinearities == 0) {
    # E.g. a model without sigmoid/tanh units: report it as unsaturated.
    print "0.0\n";
    return 0;
  }

  my $saturation = $total_saturation / $num_nonlinearities;
  if ($saturation < 0.0 || $saturation > 1.0) {
    print STDERR "Bad saturation value: $saturation\n";
    return 1;
  }
  print "$saturation\n";
  return 0;
}

# Run only when executed as a program, so that the subroutines above can
# also be loaded (e.g. by a test) without consuming stdin.
exit(main(\*STDIN)) unless caller;

1;

# example line with LstmNonlinearityComponent that we parse:
# component name=lstm2.lstm_nonlin type=LstmNonlinearityComponent, input-dim=2560, output-dim=1024, learning-rate=0.002, max-change=0.75, cell-dim=512, w_ic-rms=0.9941, w_fc-rms=0.8901, w_oc-rms=0.9794, count=3.53e+05, i_t_sigmoid={ self-repair-lower-threshold=0.05, self-repair-scale=1e-05, self-repaired-proportion=0.0722299, value-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.04,0.08,0.09,0.12 0.17,0.25,0.46,0.76,0.87 0.91,0.96,0.96,1.0), mean=0.494, stddev=0.253], deriv-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.0007,0.03,0.04,0.06 0.09,0.12,0.19,0.23,0.24 0.25,0.25,0.25,0.25), mean=0.179, stddev=0.0595] }, f_t_sigmoid={ self-repair-lower-threshold=0.05, self-repair-scale=1e-05, self-repaired-proportion=0.0688061, value-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.06,0.11,0.13,0.17 0.22,0.30,0.51,0.70,0.82 0.90,0.96,0.98,1.0), mean=0.509, stddev=0.219], deriv-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.001,0.01,0.03,0.07 0.11,0.15,0.21,0.24,0.25 0.25,0.25,0.25,0.25), mean=0.194, stddev=0.0561] }, c_t_tanh={ self-repair-lower-threshold=0.2, self-repair-scale=1e-05, self-repaired-proportion=0.178459, value-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(-1.0,-0.98,-0.97,-0.92 -0.82,-0.65,-0.01,0.66,0.87 0.94,0.95,0.97,0.99), mean=0.00447, stddev=0.612], deriv-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.003,0.02,0.04,0.10 0.14,0.25,0.65,0.84,0.90 0.94,0.97,0.97,0.98), mean=0.58, stddev=0.281] }, o_t_sigmoid={ self-repair-lower-threshold=0.05, self-repair-scale=1e-05, self-repaired-proportion=0.0608838, value-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.02,0.07,0.09,0.12 0.17,0.25,0.52,0.77,0.86 0.90,0.94,0.96,0.99), mean=0.514, stddev=0.256], deriv-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.007,0.04,0.04,0.07 0.09,0.12,0.19,0.23,0.24 0.25,0.25,0.25,0.25), mean=0.175, stddev=0.0579] }, m_t_tanh={ self-repair-lower-threshold=0.2, self-repair-scale=1e-05, self-repaired-proportion=0.134653, value-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(-0.99,-0.95,-0.92,-0.85 -0.73,-0.51,0.02,0.48,0.73 0.86,0.96,0.98,1.0), mean=0.00581, stddev=0.522], deriv-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.002,0.03,0.04,0.13 0.26,0.41,0.75,0.93,0.97 0.99,1.0,1.0,1.0), mean=0.672, stddev=0.272] }