Blame view

egs/wsj/s5/steps/nnet3/get_saturation.pl 5.9 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
  #!/usr/bin/env perl

  # This program parses the output of nnet3-am-info or nnet3-info,
  # and prints out a number between zero and one that reflects
  # how saturated the (sigmoid and tanh) nonlinearities are, on average
  # over the model.
  #
  # This is based on the 'deriv-avg' (average-derivative) statistics printed
  # out for the sigmoid and tanh components.  The 'saturation' of such a
  # component is defined as
  #   1.0 - (its mean deriv-avg / the maximum possible derivative of that
  #          nonlinearity),
  # where the denominator is 1.0 for tanh and 0.25 for sigmoid.
  # This program averages the saturation over all the sigmoid/tanh
  # nonlinearities it finds in the network.  Each component (or each gate of an
  # LSTM nonlinearity) counts once, regardless of its dimension.
  #
  # It parses the Info() output of components of type SigmoidComponent,
  # TanhComponent, LstmNonlinearityComponent and *GruNonlinearityComponent.
  # Component lines whose statistics cannot be parsed are reported on stderr.
  # If no such components are found at all, it prints 0.0 and exits with
  # status 0.  It exits with status 1 only if the computed average saturation
  # falls outside [0, 1].

  # Usage: nnet3-am-info 10.mdl | steps/nnet3/get_saturation.pl
  # or: nnet3-info 10.raw | steps/nnet3/get_saturation.pl

  use strict;
  use warnings;

  # Running count of nonlinearities seen so far and the sum of their
  # saturations; the printed result is total / count.
  my $num_nonlinearities = 0;
  my $total_saturation = 0.0;
  
  # Maximum possible derivative of each kind of nonlinearity; a component's
  # saturation is measured relative to this.
  my $max_sigmoid_deriv = 0.25;
  my $max_tanh_deriv = 1.0;

  # add_saturation($avg_deriv, $max_deriv)
  # Counts one more nonlinearity and adds its saturation,
  # 1.0 - $avg_deriv / $max_deriv, to the running total.
  # The arguments are copied out of @_ first, because callers pass the regex
  # capture $1 (which @_ would otherwise alias).
  sub add_saturation {
    my ($avg_deriv, $max_deriv) = @_;
    $num_nonlinearities += 1;
    $total_saturation += 1.0 - ($avg_deriv / $max_deriv);
    return;
  }

  while (<STDIN>) {
    if (m/type=SigmoidComponent/) {
      # a line like:
      # component name=Lstm1_f type=SigmoidComponent, dim=1280, count=5.02e+05,
      # value-avg=[percentiles(0,1,2,5 10,20,50,80,90
      # 95,98,99,100)=(0.06,0.17,0.19,0.24 0.28,0.33,0.44,0.62,0.79
      # 0.96,0.99,1.0,1.0), mean=0.482, stddev=0.198],
      # deriv-avg=[percentiles(0,1,2,5 10,20,50,80,90
      # 95,98,99,100)=(0.0001,0.003,0.004,0.03 0.12,0.18,0.22,0.24,0.25
      # 0.25,0.25,0.25,0.25), mean=0.198, stddev=0.0591]
      # The [^m]+ keeps the match inside the deriv-avg block, so it picks up
      # the deriv-avg mean rather than the earlier value-avg mean.
      if (m/deriv-avg=[^m]+mean=([^,]+),/) {
        add_saturation($1, $max_sigmoid_deriv);
      } else {
        print STDERR "$0: could not make sense of line (no deriv-avg?): $_";
      }
    } elsif (m/type=TanhComponent/) {
      if (m/deriv-avg=[^m]+mean=([^,]+),/) {
        add_saturation($1, $max_tanh_deriv);
      } else {
        print STDERR "$0: could not make sense of line (no deriv-avg?): $_";
      }
    } elsif (m/type=LstmNonlinearityComponent/) {
      # An example of a line like this is right at the bottom of this program,
      # it's extremely long.  Each gate's stats are inside a "name={ ... }"
      # group; [^}]+ keeps each match inside that gate's group.  Gates that
      # parse are counted even if another gate on the same line does not.
      my $ok = 1;
      foreach my $sigmoid_name (qw(i_t f_t o_t)) {
        if (m/${sigmoid_name}_sigmoid=[{][^}]+deriv-avg=[^}]+mean=([^,]+),/) {
          add_saturation($1, $max_sigmoid_deriv);
        } else {
          $ok = 0;
        }
      }
      foreach my $tanh_name (qw(c_t m_t)) {
        if (m/${tanh_name}_tanh=[{][^}]+deriv-avg=[^}]+mean=([^,]+),/) {
          add_saturation($1, $max_tanh_deriv);
        } else {
          $ok = 0;
        }
      }
      if (!$ok) {
        print STDERR "$0: could not parse at least one of the deriv-avg values in the following info line: $_";
      }
    } elsif (m/type=.*GruNonlinearityComponent/) {
      # NOTE(review): the leading .* presumably lets this also match
      # OutputGruNonlinearityComponent.  The single deriv-avg found is treated
      # as a tanh-like nonlinearity (max derivative 1.0) -- TODO confirm
      # against that component's Info() format.
      if (m/deriv-avg=[^m]+mean=([^,]+),/) {
        add_saturation($1, $max_tanh_deriv);
      } else {
        print STDERR "$0: could not make sense of line (no deriv-avg?): $_";
      }
    }
  }
  
  
  # Report the average saturation over all nonlinearities parsed above.
  # If none were found, the result is defined to be 0.0 and this is not an error.
  if ($num_nonlinearities == 0) {
    print "0.0\n";
    exit(0);
  }

  my $saturation = $total_saturation / $num_nonlinearities;

  # Each per-nonlinearity saturation is in [0, 1] whenever its average
  # derivative lies between 0 and the nonlinearity's maximum derivative, so an
  # average outside [0, 1] indicates bad input or a parsing problem.
  if ($saturation < 0.0 || $saturation > 1.0) {
    print STDERR "$0: Bad saturation value: $saturation\n";
    exit(1);
  }

  print "$saturation\n";
  
  
  
  
  # example line with LstmNonlinearityComponent that we parse:
  # component name=lstm2.lstm_nonlin type=LstmNonlinearityComponent, input-dim=2560, output-dim=1024, learning-rate=0.002, max-change=0.75, cell-dim=512, w_ic-rms=0.9941, w_fc-rms=0.8901, w_oc-rms=0.9794, count=3.53e+05, i_t_sigmoid={ self-repair-lower-threshold=0.05, self-repair-scale=1e-05, self-repaired-proportion=0.0722299, value-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.04,0.08,0.09,0.12 0.17,0.25,0.46,0.76,0.87 0.91,0.96,0.96,1.0), mean=0.494, stddev=0.253], deriv-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.0007,0.03,0.04,0.06 0.09,0.12,0.19,0.23,0.24 0.25,0.25,0.25,0.25), mean=0.179, stddev=0.0595] }, f_t_sigmoid={ self-repair-lower-threshold=0.05, self-repair-scale=1e-05, self-repaired-proportion=0.0688061, value-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.06,0.11,0.13,0.17 0.22,0.30,0.51,0.70,0.82 0.90,0.96,0.98,1.0), mean=0.509, stddev=0.219], deriv-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.001,0.01,0.03,0.07 0.11,0.15,0.21,0.24,0.25 0.25,0.25,0.25,0.25), mean=0.194, stddev=0.0561] }, c_t_tanh={ self-repair-lower-threshold=0.2, self-repair-scale=1e-05, self-repaired-proportion=0.178459, value-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(-1.0,-0.98,-0.97,-0.92 -0.82,-0.65,-0.01,0.66,0.87 0.94,0.95,0.97,0.99), mean=0.00447, stddev=0.612], deriv-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.003,0.02,0.04,0.10 0.14,0.25,0.65,0.84,0.90 0.94,0.97,0.97,0.98), mean=0.58, stddev=0.281] }, o_t_sigmoid={ self-repair-lower-threshold=0.05, self-repair-scale=1e-05, self-repaired-proportion=0.0608838, value-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.02,0.07,0.09,0.12 0.17,0.25,0.52,0.77,0.86 0.90,0.94,0.96,0.99), mean=0.514, stddev=0.256], deriv-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.007,0.04,0.04,0.07 0.09,0.12,0.19,0.23,0.24 0.25,0.25,0.25,0.25), mean=0.175, stddev=0.0579] }, m_t_tanh={ self-repair-lower-threshold=0.2, self-repair-scale=1e-05, 
  # self-repaired-proportion=0.134653, value-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(-0.99,-0.95,-0.92,-0.85 -0.73,-0.51,0.02,0.48,0.73 0.86,0.96,0.98,1.0), mean=0.00581, stddev=0.522], deriv-avg=[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.002,0.03,0.04,0.13 0.26,0.41,0.75,0.93,0.97 0.99,1.0,1.0,1.0), mean=0.672, stddev=0.272] }