// nnet2/nnet-fix.cc

// Copyright 2012  Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet2/nnet-fix.h"

namespace kaldi {
namespace nnet2 {
/* See the header for what we're doing.
   The pattern we're looking for is an AffineComponent followed by a
   NonlinearComponent of type SigmoidComponent, TanhComponent or
   RectifiedLinearComponent.
*/
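
// Strategy (an informal summary; see nnet-fix.h for the authoritative
// description): for each dimension of the nonlinearity we compare the average
// derivative accumulated in its statistics against the maximum possible
// derivative of that nonlinearity.  A very small ratio means the unit is
// saturated (sigmoid/tanh) or almost never active (ReLU); a very large ratio
// means the unit stays in its near-linear region (sigmoid/tanh) or is almost
// always active (ReLU).  The fix is to rescale the corresponding row of the
// preceding AffineComponent (sigmoid/tanh), or to shift its bias (ReLU).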
void FixNnet(const NnetFixConfig &config, Nnet *nnet) {
  for (int32 c = 0; c + 1 < nnet->NumComponents(); c++) {
    AffineComponent *ac = dynamic_cast<AffineComponent*>(
        &(nnet->GetComponent(c)));
    NonlinearComponent *nc = dynamic_cast<NonlinearComponent*>(
        &(nnet->GetComponent(c + 1)));
    if (ac == NULL || nc == NULL) continue;
    // We only want to process this if it's of type SigmoidComponent,
    // TanhComponent or RectifiedLinearComponent.
    BaseFloat max_deriv;  // The maximum derivative of this nonlinearity.
    bool is_relu = false;
    {
      SigmoidComponent *sc = dynamic_cast<SigmoidComponent*>(nc);
      TanhComponent *tc = dynamic_cast<TanhComponent*>(nc);
      RectifiedLinearComponent *rc = dynamic_cast<RectifiedLinearComponent*>(nc);
      if (sc != NULL) max_deriv = 0.25;
      else if (tc != NULL) max_deriv = 1.0;
      else if (rc != NULL) { max_deriv = 1.0; is_relu = true; }
      else continue;  // E.g. SoftmaxComponent; we don't handle this.
    }
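    // For reference: the sigmoid's derivative peaks at 0.25 (at input 0) and
    // the tanh's at 1.0; a ReLU's derivative is 1.0 whenever the unit is
    // active, so for a ReLU the ratio computed below is simply the fraction
    // of frames on which the unit fired.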
    double count = nc->Count();
    Vector<double> deriv_sum(nc->DerivSum());
    if (count == 0.0 || deriv_sum.Dim() == 0) {
      KALDI_WARN << "Cannot fix neural net because no statistics are stored.";
      continue;
    }
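    // Note: bias_params and linear_params below are copies of the affine
    // component's parameters; any changes are committed via SetParams() at
    // the end of this block.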
    Vector<BaseFloat> bias_params(ac->BiasParams());
    Matrix<BaseFloat> linear_params(ac->LinearParams());
    int32 dim = nc->InputDim(), num_small_deriv = 0, num_large_deriv = 0;
    for (int32 d = 0; d < dim; d++) {
      // deriv_ratio is the ratio of the computed average derivative to the
      // maximum derivative of that nonlinear function.
      BaseFloat deriv_ratio = deriv_sum(d) / (count * max_deriv);
      KALDI_ASSERT(deriv_ratio >= 0.0 && deriv_ratio < 1.01);  // Or there is
                                                                // an error in
                                                                // the math.
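      // Rationale (an informal summary): for sigmoid/tanh a very small
      // average derivative means the unit's inputs are typically large in
      // magnitude (it is saturated), so we shrink the incoming weights and
      // bias to bring it back into its responsive range; a very large average
      // derivative means its inputs are typically close to zero (the unit is
      // behaving almost linearly), so we scale them up.  For ReLU units we
      // instead raise or lower the bias, to make a dead unit fire more often
      // or an always-on unit fire less often.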
      if (deriv_ratio < config.min_average_deriv) {
        // The derivative is too small, meaning we've gone off into the "flat
        // part" of the sigmoid (or, for ReLU, we're always off).
        if (is_relu) {
          bias_params(d) += config.relu_bias_change;
        } else {
          BaseFloat parameter_factor = std::min(config.min_average_deriv /
                                                deriv_ratio,
                                                config.parameter_factor);
          // We need to reduce the parameters, so multiply by
          // 1.0 / parameter_factor.
          bias_params(d) *= 1.0 / parameter_factor;
          linear_params.Row(d).Scale(1.0 / parameter_factor);
        }
        num_small_deriv++;
      } else if (deriv_ratio > config.max_average_deriv) {
        // The derivative is too large, meaning we're only in the linear part
        // of the sigmoid, in the middle (or, for ReLU, we're always on).
        if (is_relu) {
          bias_params(d) -= config.relu_bias_change;
        } else {
          BaseFloat parameter_factor = std::min(deriv_ratio / config.max_average_deriv,
                                                config.parameter_factor);
          // We need to increase the parameters, so multiply by
          // parameter_factor.
          bias_params(d) *= parameter_factor;
          linear_params.Row(d).Scale(parameter_factor);
        }
        num_large_deriv++;
      }
    }
    if (is_relu) {
      KALDI_LOG << "For layer " << c << " (ReLU units), increased bias for "
                << num_small_deriv << " indexes and decreased it for "
                << num_large_deriv << ", out of a total of " << dim;
    } else {
      KALDI_LOG << "For layer " << c << ", decreased parameters for "
                << num_small_deriv << " indexes, and increased them for "
                << num_large_deriv << " out of a total of " << dim;
    }
    ac->SetParams(bias_params, linear_params);
  }
}

}  // namespace nnet2
}  // namespace kaldi