// File: nnet-test-utils.h (4.51 KB)
// nnet3/nnet-test-utils.h
// Copyright 2015 Johns Hopkins University (author: Daniel Povey)
// Copyright 2016 Daniel Galvez
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_NNET3_NNET_TEST_UTILS_H_
#define KALDI_NNET3_NNET_TEST_UTILS_H_
#include "nnet3/nnet-nnet.h"
#include "nnet3/nnet-utils.h"
#include "nnet3/nnet-example.h"
namespace kaldi {
namespace nnet3 {
/** @file
This file contains various routines that are useful in test code.
*/
/// Options consumed by the random-network config generators
/// GenerateConfigSequence() and GenerateConfigSequenceCompositeBlock().
///
/// Each allow_* flag presumably permits the generator to use the feature it
/// names; the precise effect of every flag is defined in the implementation
/// file (NOTE(review): confirm there). Default values are given inline.
struct NnetGenerationOptions {
  bool allow_context = true;
  bool allow_nonlinearity = true;
  bool allow_recursion = true;
  bool allow_clockwork = true;
  bool allow_multiple_inputs = true;
  bool allow_multiple_outputs = false;   // off by default
  bool allow_final_nonlinearity = true;
  bool allow_use_of_x_dim = true;
  bool allow_ivector = false;            // off by default
  bool allow_statistics_pooling = true;

  /// When > 0, the output-dim of the generated network is forced to this
  /// value. The default (-1) does not force any particular output-dim.
  int32 output_dim = -1;

  // User-provided (empty) constructor; all defaults come from the member
  // initializers above.
  NnetGenerationOptions() { }
};
/** Generates a sequence of one or more config files, output as strings, where
    the first in the sequence is the initial nnet, and the remaining ones may do
    things like add layers.

    @param opts     Options restricting which features the generated configs
                    may use (see NnetGenerationOptions).
    @param configs  Output: receives the config strings, initial nnet first.
                    NOTE(review): whether the vector is cleared before being
                    filled is not stated here; check the .cc. */
void GenerateConfigSequence(const NnetGenerationOptions &opts,
std::vector<std::string> *configs);
/// Generates config(s), written to *configs, describing a composite component
/// composed only of block affine, repeated affine, and natural gradient
/// repeated affine components.
///
/// @param opts     Generation options (see NnetGenerationOptions).
///                 NOTE(review): which fields are honored is not visible
///                 from this header; confirm in the .cc.
/// @param configs  Output: the generated config string(s). The original
///                 comment speaks of a single "config string". Presumably
///                 the vector holds one entry; confirm.
void GenerateConfigSequenceCompositeBlock(const NnetGenerationOptions &opts,
std::vector<std::string> *configs);
/** This function computes an example computation request, for testing purposes.
    The "Simple" in the name means that it currently only supports neural nets
    that satisfy IsSimple(nnet) (defined in nnet-utils.h).
    If there are 2 inputs, the "input" will be first, followed by "ivector".
    To widen the range of things you can test with this (mainly to avoid
    crashes with statistics-pooling/statistics-extraction components), this
    function always generates computation requests in which at least 3
    successive frames of input are requested.

    @param nnet     The network to build the request for; must satisfy
                    IsSimple(nnet).
    @param request  Output: the generated computation request.
    @param inputs   Output: input matrices for the request. Presumably there is
                    one per network input, in the order described above
                    ("input", then "ivector"); TODO confirm.
*/
void ComputeExampleComputationRequestSimple(
const Nnet &nnet,
ComputationRequest *request,
std::vector<Matrix<BaseFloat> > *inputs);
/// Test helper that, as its name suggests, generates a random Component.
/// NOTE(review): what "simple" means here and who owns the returned pointer
/// are not stated in this header. Presumably the result is heap-allocated
/// and the caller must delete it; confirm in the implementation.
Component *GenerateRandomSimpleComponent();
/** Used for testing that the updatable parameters in two networks are the same.
    May crash if the network structures differ. Prints a warning and returns
    false if the parameters differ. E.g. set threshold to 1.0e-05 (it is a
    relative threshold, applied per component).

    @param nnet1      First network to compare.
    @param nnet2      Second network; must have the same structure as nnet1.
    @param threshold  Relative tolerance, applied per component.
    @return           true if the updatable parameters match to within
                      threshold, false otherwise. */
bool NnetParametersAreIdentical(const Nnet &nnet1,
const Nnet &nnet2,
BaseFloat threshold);
/** Low-level function that generates an nnet training example. By "simple" we
    mean there is one output named "output", an input named "input", and
    possibly also an input named "ivector" (this will be assumed absent if
    ivector_dim <= 0). This function generates exactly "left_context" frames
    of context on the left and exactly "right_context" frames on the right.

    @param num_supervised_frames  Presumably the number of frames carrying
                                  supervision in "output" (TODO confirm).
    @param left_context           Frames of context generated on the left.
    @param right_context          Frames of context generated on the right.
    @param input_dim              Presumably the feature dim of "input".
    @param output_dim             Presumably the dim of "output".
    @param ivector_dim            Dim of "ivector"; <= 0 means no ivector input.
    @param example                Output: the generated example. */
void GenerateSimpleNnetTrainingExample(
int32 num_supervised_frames,
int32 left_context,
int32 right_context,
int32 input_dim,
int32 output_dim,
int32 ivector_dim,
NnetExample *example);
/// Returns true if the examples are approximately equal (only intended to be
/// used in testing).
///
/// @param eg1    First example.
/// @param eg2    Second example.
/// @param delta  Tolerance for the comparison. NOTE(review): whether it is
///               relative or absolute is not stated here; check the .cc.
bool ExampleApproxEqual(const NnetExample &eg1,
const NnetExample &eg2,
BaseFloat delta);
} // namespace nnet3
} // namespace kaldi
#endif