nnet-utils.h
25.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
// nnet3/nnet-utils.h
// Copyright 2015 Johns Hopkins University (author: Daniel Povey)
// 2016 Daniel Galvez
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_NNET3_NNET_UTILS_H_
#define KALDI_NNET3_NNET_UTILS_H_
#include "base/kaldi-common.h"
#include "util/kaldi-io.h"
#include "matrix/matrix-lib.h"
#include "nnet3/nnet-common.h"
#include "nnet3/nnet-component-itf.h"
#include "nnet3/nnet-descriptor.h"
#include "nnet3/nnet-computation.h"
#include "nnet3/nnet-example.h"
namespace kaldi {
namespace nnet3 {
/// \file nnet3/nnet-utils.h
/// This file contains some miscellaneous functions dealing with class Nnet.
/// Given an nnet and a computation request, this function works out which
/// requested outputs in the computation request are computable; it outputs this
/// information in "is_computable", whose outer index corresponds to the
/// indexes of request.outputs.
/// It does this by executing some of the early stages of compilation.
///
/// @param [in] nnet  The neural network.
/// @param [in] request  The computation request whose outputs are checked.
/// @param [out] is_computable  Outer vector is indexed like request.outputs.
///              NOTE(review): presumably each inner vector has one entry per
///              requested Index of that output -- confirm in nnet-utils.cc.
void EvaluateComputationRequest(
const Nnet &nnet,
const ComputationRequest &request,
std::vector<std::vector<bool> > *is_computable);
/// Returns the number of output nodes of this nnet.
int32 NumOutputNodes(const Nnet &nnet);
/// Returns the number of input nodes of this nnet.
int32 NumInputNodes(const Nnet &nnet);
/// Calls PerturbParams (with the given stddev) on all updatable components of
/// the nnet.
void PerturbParams(BaseFloat stddev,
Nnet *nnet);
/// Returns the dot product between two networks of the same structure (calls
/// the DotProduct functions of the updatable components and sums up the
/// return values).
BaseFloat DotProduct(const Nnet &nnet1,
const Nnet &nnet2);
/// Returns dot products between two networks of the same structure: calls the
/// DotProduct functions of the updatable components and fills in one element
/// of the output vector per updatable component.
/// NOTE(review): presumably *dot_prod must already have dimension
/// NumUpdatableComponents(nnet1) -- confirm in nnet-utils.cc.
void ComponentDotProducts(const Nnet &nnet1,
const Nnet &nnet2,
VectorBase<BaseFloat> *dot_prod);
/// This function is for printing, to a string, a vector with one element per
/// updatable component of the nnet (e.g. the output of ComponentDotProducts),
/// in a human-readable way, as
/// [ component-name1:number1 component-name2:number2 ... ].
std::string PrintVectorPerUpdatableComponent(const Nnet &nnet,
const VectorBase<BaseFloat> &vec);
/// This function returns true if the nnet has the following properties:
///  - It has an output called "output" (other outputs are allowed but may be
///    ignored).
///  - It has an input called "input", and possibly an extra input called
///    "ivector", but no other inputs.
/// There are probably some other properties that we really ought to
/// be checking, and we may add more later on.
bool IsSimpleNnet(const Nnet &nnet);
/// Zeroes the component stats in all nonlinear components in the nnet.
void ZeroComponentStats(Nnet *nnet);
/// ComputeSimpleNnetContext computes the left-context and right-context of a
/// nnet.  The nnet must satisfy IsSimpleNnet(nnet).
///
/// It does this by constructing a ComputationRequest with a certain number of
/// inputs available and working out which outputs can be computed from them.
/// It does the same after shifting the time index of the output to all values
/// 0, 1, ... n-1, where n is the output of nnet.Modulus().  Then it returns the
/// largest left context and the largest right context that it infers from any
/// of these computation requests.
void ComputeSimpleNnetContext(const Nnet &nnet,
int32 *left_context,
int32 *right_context);
/// Sets the underlying learning rate for all the components in the nnet to this
/// value.  This will get multiplied by the individual learning-rate-factors to
/// produce the actual learning rates.
void SetLearningRate(BaseFloat learning_rate,
Nnet *nnet);
/// Scales the nnet parameters and stats by this scale.
void ScaleNnet(BaseFloat scale, Nnet *nnet);
/// Sets nnet as a gradient by setting is_gradient_ to true and
/// learning_rate_ to 1 for each UpdatableComponent in nnet.
void SetNnetAsGradient(Nnet *nnet);
/// Does *dest += alpha * src (affects nnet parameters and
/// stored stats).
void AddNnet(const Nnet &src, BaseFloat alpha, Nnet *dest);
/// Does *dest += alpha * src for updatable components (affects nnet parameters),
/// and *dest += scale * src for other components (affects stored stats).
/// Here, alphas is a vector of size equal to the number of updatable components.
/// NOTE(review): presumably alphas(i) is the alpha used for the i-th updatable
/// component -- confirm in nnet-utils.cc.
void AddNnetComponents(const Nnet &src, const Vector<BaseFloat> &alphas,
BaseFloat scale, Nnet *dest);
/// Returns true if 'nnet' has some kind of recurrency.
bool NnetIsRecurrent(const Nnet &nnet);
/// Returns the total number of parameters in the updatable components of
/// the nnet.
int32 NumParameters(const Nnet &src);
/// Copies the nnet parameters to *params, whose dimension must
/// be equal to NumParameters(src).  See UnVectorizeNnet() for the reverse
/// operation.
void VectorizeNnet(const Nnet &src,
VectorBase<BaseFloat> *params);
/// Copies the parameters from params to *dest; this is the reverse of
/// VectorizeNnet().  The dimension of params must be equal to
/// NumParameters(*dest).
///
/// @param [in] params  Flattened parameter vector, laid out as produced by
///                     VectorizeNnet().
/// @param [out] dest   The nnet whose updatable-component parameters are
///                     overwritten.
void UnVectorizeNnet(const VectorBase<BaseFloat> &params,
                     Nnet *dest);
/// Returns the number of updatable components in the nnet.
int32 NumUpdatableComponents(const Nnet &dest);
/// Controls whether natural gradient will be updated.
/// NOTE(review): presumably freeze == true stops updates to the natural-gradient
/// state of the components and freeze == false re-enables them -- confirm.
void FreezeNaturalGradient(bool freeze, Nnet *nnet);
/// Converts all components of type RepeatedAffineComponent or
/// NaturalGradientRepeatedAffineComponent in nnet to BlockAffineComponent.
void ConvertRepeatedToBlockAffine(Nnet *nnet);
/// This function returns various info about the neural net.
/// If the nnet satisfies IsSimpleNnet(nnet), the info includes lines like
/// "left-context=5\nright-context=3\n...".  The info also includes
/// the output of nnet.Info().
/// This is modeled after the info that AmNnetSimple returns in its
/// Info() function (we need this in the CTC code).
std::string NnetInfo(const Nnet &nnet);
/// This function sets the dropout proportion in all dropout components to
/// the value dropout_proportion.
void SetDropoutProportion(BaseFloat dropout_proportion, Nnet *nnet);
/// Returns true if nnet has at least one component of type BatchNormComponent.
bool HasBatchnorm(const Nnet &nnet);
/// This function affects only components of type BatchNormComponent.
/// It sets "test mode" on such components (if you call it with test_mode =
/// true; otherwise it sets normal mode, but this wouldn't be needed
/// often).  "Test mode" means that instead of using statistics from the batch,
/// it does a deterministic normalization based on statistics stored at training
/// time.
void SetBatchnormTestMode(bool test_mode, Nnet *nnet);
/// This function zeroes the stored component-level stats in the nnet using
/// ZeroComponentStats(), then recomputes them with the supplied egs.  It
/// affects batch-norm, for instance.  See also the version of RecomputeStats
/// declared in nnet-chain-diagnostics.h.
void RecomputeStats(const std::vector<NnetExample> &egs, Nnet *nnet);
/// This function affects components of child-classes of
/// RandomComponent.
/// It sets "test mode" on such components (if you call it with test_mode =
/// true; otherwise it sets normal mode, but this wouldn't be needed often).
/// "Test mode" means using a mask containing (1-dropout_prob) in all
/// elements, i.e. a fixed scaling instead of a random mask.
void SetDropoutTestMode(bool test_mode, Nnet *nnet);
/**
   \brief This function calls 'ResetGenerator()' on all components in 'nnet'
   that inherit from class RandomComponent.  It's used when you need
   to ensure consistency in things like dropout masks across subsequent
   neural net evaluations.  You will likely want to call srand() before calling
   this.
*/
void ResetGenerators(Nnet *nnet);
/// This function finds a list of components that are never used, and outputs
/// the integer component indexes (you can use these to index
/// nnet.GetComponentNames() to get their names).
void FindOrphanComponents(const Nnet &nnet, std::vector<int32> *components);
/// This function finds a list of nodes that are never used to compute any
/// output, and outputs the integer node indexes (you can use these to index
/// nnet.GetNodeNames() to get their names).
void FindOrphanNodes(const Nnet &nnet, std::vector<int32> *nodes);
/**
   Options for CollapseModel().  CollapseModel() is responsible for merging
   sequential components where doing so makes test-time evaluation more
   efficient.  For example, dropout components and test-mode batch-norm
   components can be folded into the following layer, and successive affine
   components can sometimes be combined.

   CollapseModel() expects batch-norm components to already be in test mode,
   so you should probably call SetBatchnormTestMode() and SetDropoutTestMode()
   before calling it.

   Every option defaults to true.
*/
struct CollapseModelConfig {
  bool collapse_dropout = true;    // dropout followed by affine/conv.
  bool collapse_batchnorm = true;  // batchnorm followed by affine.
  bool collapse_affine = true;     // affine or fixed-affine followed by affine.
  bool collapse_scale = true;      // affine followed by fixed-scale.

  CollapseModelConfig() { }
};
/**
   This function modifies the neural net for efficiency, in a way that is
   suitable for test time.  For example, it tries to get rid of dropout,
   batchnorm and fixed-scale components, and to collapse subsequent affine
   components if doing so won't hurt speed.  Which kinds of collapsing are
   attempted is controlled by 'config' (see struct CollapseModelConfig).
*/
void CollapseModel(const CollapseModelConfig &config,
Nnet *nnet);
/**
   ReadEditConfig() reads a file with a similar-looking format to the config file
   read by Nnet::ReadConfig(), but this consists of a sequence of operations to
   perform on an existing network, mostly modifying components.  It's one
   "directive" (i.e. command) per line, but if supplying the options via
   the --edits option to programs like nnet3-am-copy, you can use a semicolon
   in place of the newline to separate commands.

   The following describes the allowed commands.  Note: all patterns are like
   UNIX globbing patterns where the only metacharacter is '*', representing zero
   or more characters.
\verbatim
convert-to-fixed-affine [name=<name-pattern>]
   Converts the given affine components to FixedAffineComponent, which is not
   updatable.
remove-orphan-nodes [remove-orphan-inputs=(true|false)]
   Removes orphan nodes (those that are never used to compute anything).  Note:
   remove-orphan-inputs defaults to false.
remove-orphan-components
   Removes orphan components (those that are never used by any node).
remove-orphans [remove-orphan-inputs=(true|false)]
   The same as calling remove-orphan-nodes and then remove-orphan-components.
set-learning-rate [name=<name-pattern>] learning-rate=<learning-rate>
   Sets the learning rate for any updatable components matching the name pattern.
   Note: this sets the 'underlying' learning rate, i.e. it will get
   multiplied by any 'learning-rate-factor' set in the components.
set-learning-rate-factor [name=<name-pattern>] learning-rate-factor=<learning-rate-factor>
   Sets the learning rate factor for any updatable components matching the name
   pattern.
rename-node old-name=<old-name> new-name=<new-name>
   Renames a node; this is a surface renaming that does not affect the structure
   (for structural changes, use the regular config file format, not the
   edits-config).  This is mostly useful for outputs, e.g. when doing
   multilingual experiments.
remove-output-nodes name=<name-pattern>
   Removes a subset of output nodes, those matching the pattern.  You cannot
   remove internal nodes directly; instead you should use the command
   'remove-orphans'.
set-dropout-proportion [name=<name-pattern>] proportion=<dropout-proportion>
   Sets the dropout rates for any components of type DropoutComponent,
   DropoutMaskComponent or GeneralDropoutComponent whose
   names match the given <name-pattern> (e.g. lstm*).  <name-pattern> defaults
   to "*".
apply-svd name=<name-pattern> bottleneck-dim=<dim> energy-threshold=<threshold> shrinkage-threshold=<s>
   Locates all components with names matching <name-pattern>, which are of
   type AffineComponent or child classes thereof.  If <dim> is
   less than the minimum of the (input or output) dimension of the component,
   it does SVD on the components' parameters, retaining only the largest
   <dim> singular values, replacing these components with sequences of two
   components, of types LinearComponent and NaturalGradientAffineComponent.
   Alternatively, energy-threshold can be used as the criterion for selecting
   singular values: we retain those values which contribute energy-threshold
   times the total energy of the original singular values.  A particular
   component is left as it was (not shrunk by the SVD-based factorization) if
   the shrinkage ratio of the total number of its parameters, after the
   SVD-based refactoring, is greater than shrinkage-threshold.
   See also 'reduce-rank'.
reduce-rank name=<name-pattern> rank=<dim>
   Locates all components with names matching <name-pattern>, which are of
   type AffineComponent or child classes thereof.  Does SVD on the
   components' parameters, retaining only the largest <dim> singular values,
   and writes the reconstructed matrix back to the component.  See also
   'apply-svd', which structurally breaks the component into two pieces.
\endverbatim
*/
void ReadEditConfig(std::istream &config_file, Nnet *nnet);
/**
   This function does the operation '*nnet += scale * delta_nnet', while
   respecting any max-parameter-change (max-param-change) specified in the
   updatable components, and also the global max-param-change specified as
   'max_param_change'.

   With max-changes taken into account, the operation of this function is
   equivalent to the following, although it's done more efficiently:
\code
   Nnet temp_nnet(delta_nnet);
   ScaleNnet(1.0 / max_change_scale, &temp_nnet);
   [ Scale down parameters for each component of temp_nnet as needed so
     their Euclidean norms do not exceed their per-component max-changes ]
   [ Scale down temp_nnet as needed so its Euclidean norm does not exceed
     the global max-change ]
   ScaleNnet(max_change_scale, &temp_nnet);  // undo the previous scaling.
   AddNnet(temp_nnet, scale, nnet);
\endcode

   @param [in] delta_nnet  The copy of '*nnet' neural network that contains
                the proposed change in parameters.  Normally this will previously
                have been set to: (delta_nnet =
                parameter-derivative-on-current-minibatch *
                learning-rate per parameter), with any natural gradient applied
                as specified in the components; but this may be different if
                momentum or backstitch are used.
   @param [in] max_param_change  The global max-param-change specified on the
                command line (e.g. 2.0), which specifies the largest change
                allowed to '*nnet' in Euclidean norm.  If <= 0, no global
                max-param-change will be enforced, but any max-change values
                specified in the components will still be enforced; see
                UpdatableComponent::MaxChange(), and search for 'max-change' in
                the configs or nnet3-info output.
   @param [in] max_change_scale  This value, which will normally be 1.0, is used
                to scale all per-component max-change values and the global
                'max_param_change', before applying them (so we use
                'max_change_scale * uc->MaxChange()' as the per-component
                max-change, and 'max_change_scale * max_param_change' as the
                global max-change).
   @param [in] scale  This value, which will normally be 1.0, is a scaling
                factor used when adding to 'nnet', applied after any max-changes.
                It is provided for backstitch-related purposes.
   @param [in,out] nnet  The nnet which we add to.
   @param [in,out] num_max_change_per_component_applied  We add to the
                elements of this the count for each per-component max-change.
   @param [in,out] num_max_change_global_applied  We add to this the count for
                the global max-change.
   @return  NOTE(review): the meaning of the bool return value is not
                documented in this header; presumably false means the update
                was not applied (e.g. a non-finite parameter change) -- confirm
                against nnet-utils.cc.
*/
bool UpdateNnetWithMaxChange(const Nnet &delta_nnet,
BaseFloat max_param_change,
BaseFloat max_change_scale,
BaseFloat scale, Nnet *nnet,
std::vector<int32> *
num_max_change_per_component_applied,
int32 *num_max_change_global_applied);
// Forward declaration; struct MaxChangeStats is defined near the end of this
// file, after NumUpdatableComponents() (which its constructor uses).
struct MaxChangeStats;
// This overloaded version of UpdateNnetWithMaxChange() is a convenience
// wrapper for when you have a MaxChangeStats object to keep track of how many
// times the max-change was applied; otherwise it behaves like the version
// taking explicit counters, declared just above.
// NOTE(review): presumably it accumulates into stats' counters in place of the
// num_max_change_* pointer arguments -- confirm in nnet-utils.cc.
bool UpdateNnetWithMaxChange(const Nnet &delta_nnet,
BaseFloat max_param_change,
BaseFloat max_change_scale,
BaseFloat scale, Nnet *nnet,
MaxChangeStats *stats);
/**
   This function is used as part of the regular training workflow, prior to
   UpdateNnetWithMaxChange().

   For each updatable component c in the neural net, suppose it has an
   l2-regularization constant alpha set at the component level (see
   UpdatableComponent::L2Regularization()), and a learning-rate
   eta; then this function does (and this is not real code):

      delta_nnet->c -= 2.0 * l2_regularize_scale * alpha * eta * nnet.c

   The factor of -1.0 (the minus sign) comes from the fact that we are
   maximizing, and we'd add the l2 regularization term (of the form
   ||\theta||_2^2, i.e. squared l2 norm) in the objective function with
   negative sign; the factor of 2.0 comes from the derivative of the squared
   parameters.  The factor 'l2_regularize_scale' is provided to this function;
   see below for an explanation.

   Note: the way we do it is a little bit approximate, due to the interaction
   with natural gradient.  The issue is that the regular gradients are
   multiplied by the inverse of the approximated, smoothed and factored inverse
   Fisher matrix, but the l2 gradients are not.  This means that what we're
   optimizing is not exactly the (regular objective plus the L2 term)-- we
   could view it as optimizing (regular objective plus the l2 term times the
   Fisher matrix)-- with the proviso that the Fisher matrix has been scaled in
   such a way that the amount of parameter change is not affected, so this is
   not an issue of affecting the overall strength of l2, just an issue of the
   direction-wise weighting.  In effect, the l2 term will be larger, relative
   to the gradient contribution, in directions where the Fisher matrix is
   large.  This is probably not ideal-- but it's hard to judge without
   experiments.  Anyway the l2 effect is small enough, and the Fisher matrix
   sufficiently smoothed with the identity, that I doubt this makes much of a
   difference.

   @param [in] nnet  The neural net that is being trained; expected
                     to be different from delta_nnet.
   @param [in] l2_regularize_scale  A scale on the l2 regularization.
                     Usually this will be equal to the number of
                     distinct examples (e.g. the number of chunks of
                     speech-- more precisely, the number of distinct
                     'n' values) in the minibatch, but this is
                     multiplied by a configuration value
                     --l2-regularize-factor passed in from the command
                     line.  The reason for making l2 proportional to
                     the number of elements in the minibatch is that
                     we add the parameter gradients over the minibatch
                     (we don't average), so multiplying the l2 factor by the
                     number of elements in the minibatch is necessary to
                     make the amount of l2 vs. gradient contribution stay
                     the same when we vary the minibatch size.
                     The --l2-regularize-factor option is provided so that the
                     calling script can correct for the effects of
                     parallelization via model-averaging (we'd normally set
                     this to 1/num-parallel-jobs).
   @param [in,out] delta_nnet  The neural net containing the parameter
                     updates; this is a copy of 'nnet' that is used
                     for purposes of momentum and applying max-change
                     values.  This is what this code adds to.
*/
void ApplyL2Regularization(const Nnet &nnet,
BaseFloat l2_regularize_scale,
Nnet *delta_nnet);
/**
   This function scales the batchnorm stats of any batchnorm components
   (components of type BatchNormComponent) in 'nnet' by the scale
   'batchnorm_stats_scale'.
*/
void ScaleBatchnormStats(BaseFloat batchnorm_stats_scale,
Nnet *nnet);
/**
   This function, to be called after processing every minibatch, is responsible
   for enforcing the orthogonality constraint for any components of type
   LinearComponent or inheriting from AffineComponent that have the
   "orthonormal-constraint" value set to a nonzero value.

   Technically what we are doing is constraining the parameter matrix M to be a
   "semi-orthogonal" matrix times a constant alpha.  That is: if num-rows >
   num-cols, this amounts to asserting that M^T M == alpha^2 I; otherwise, that
   M M^T == alpha^2 I.  (For num-rows > num-cols, M M^T has rank at most
   num-cols and so can never equal alpha^2 I; there it is the columns of M
   that are constrained to be orthogonal.)

   If, for a particular component, orthonormal-constraint > 0.0, then that value
   becomes the "alpha" mentioned above.  If orthonormal-constraint == 0.0, then
   nothing is done.  If orthonormal-constraint < 0.0, then it's like letting alpha
   "float", i.e. we try to make M closer to (any constant alpha) times a
   semi-orthogonal matrix.

   In order to make it efficient on GPU, it doesn't make it completely
   orthonormal; it just makes it closer to being orthonormal (times the
   'orthonormal_constraint' value).  Over multiple iterations this rapidly
   makes it almost exactly orthonormal.

   See http://www.danielpovey.com/files/2018_interspeech_tdnnf.pdf
*/
void ConstrainOrthonormal(Nnet *nnet);
/**
   This just calls ConsolidateMemory() on all the components of the nnet.  This
   is called by the training code after processing the first minibatch.  On some
   components this will do nothing; on others it will reallocate certain
   quantities that have been allocated during training (mostly the contents of
   NaturalGradientOnline objects, and stats for NonlinearComponents) so that
   they can be put into low memory.  This will tend to minimize memory
   fragmentation.  Read the comments in ../cudamatrix/cu-allocator.h for more
   explanation.
*/
void ConsolidateMemory(Nnet *nnet);
/** This utility function can be used to obtain the number of distinct 'n'
    values in a training example.  This is the number of examples
    (e.g. sequences) that have been combined into a single example.  (Actually
    it returns the (largest - smallest + 1) of 'n' values, and assumes they are
    consecutive.)

    @param [in] io_vec  The vector of NnetIo objects from the training example
                     (NnetExample or NnetChainExample) for which we need the
                     number of 'n' values.
    @param [in] exhaustive  If true, it will check exhaustively what the largest
                     and smallest 'n' values are.  If false, it does it in a
                     fast way which will return the same answer as if
                     exhaustive == true for all the types of eg we currently
                     create (basically: correct if the last row of the input
                     or supervision matrices has the last-numbered 'n'
                     value), and will occasionally (randomly) do a test to
                     check that this is the same as if we called it with
                     'exhaustive=true'.
    @return  The number of distinct 'n' values, computed as
                     (largest - smallest + 1).
*/
int32 GetNumNvalues(const std::vector<NnetIo> &io_vec,
bool exhaustive);
/**
   Counters recording how often max-change was applied during training; used
   with the overload of UpdateNnetWithMaxChange() that takes a MaxChangeStats
   pointer.
*/
struct MaxChangeStats {
  // Count of times the global max-change was applied.
  int32 num_max_change_global_applied;
  // Count of minibatches processed.  NOTE(review): presumably incremented once
  // per UpdateNnetWithMaxChange() call -- confirm in nnet-utils.cc.
  int32 num_minibatches_processed;
  // One counter per updatable component: how often that component's
  // max-change was applied.  Sized to NumUpdatableComponents(nnet), all zero.
  std::vector<int32> num_max_change_per_component_applied;

  // 'explicit' so that an Nnet is never silently converted into a
  // MaxChangeStats object; the nnet is only used to size the per-component
  // counters.
  explicit MaxChangeStats(const Nnet &nnet):
      num_max_change_global_applied(0),
      num_minibatches_processed(0),
      num_max_change_per_component_applied(NumUpdatableComponents(nnet), 0) { }

  // Prints the max-change stats.  Usually will be called at the end
  // of the program.  The nnet is only needed for structural information,
  // to work out the component names.
  void Print(const Nnet &nnet) const;
};
} // namespace nnet3
} // namespace kaldi
#endif