nnet-example.h
6.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
// nnet3/nnet-example.h
// Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
// 2014 Vimal Manohar
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_NNET3_NNET_EXAMPLE_H_
#define KALDI_NNET3_NNET_EXAMPLE_H_
#include "nnet3/nnet-nnet.h"
#include "hmm/posterior.h"
#include "util/table-types.h"
#include "hmm/posterior.h"
namespace kaldi {
namespace nnet3 {
struct NnetIo {
  /// Name of the neural-net input (or output) that this object supplies; in
  /// simple setups the input is just called "input".
  std::string name;

  /// One Index per row of "features" (this vector has the same length as
  /// features.NumRows()), explaining what each row means.  Within an
  /// individual example the "n" values are always zero; after examples have
  /// been aggregated to the minibatch level they are in general nonzero.
  std::vector<Index> indexes;

  /// The features or labels themselves.  A GeneralMatrix may hold a
  /// CompressedMatrix, a Matrix or a SparseMatrix (SparseMatrix being the
  /// natural format for posteriors).
  GeneralMatrix features;

  /// Default constructor; does nothing beyond default-constructing the members.
  NnetIo() { }

  // The copy constructor and the assignment operator are the
  // compiler-generated defaults.

  /// Creates an NnetIo called "name" that holds the provided features, with
  /// indexes that have n=0, x=0 and t values ranging from t_begin to
  /// (t_begin + t_stride * feats.NumRows() - 1) with a stride of t_stride.
  /// t_begin should be the frame that feats.Row(0) represents.
  NnetIo(const std::string &name,
         int32 t_begin,
         const MatrixBase<BaseFloat> &feats,
         int32 t_stride = 1);

  /// As the constructor above, but taking the features as a GeneralMatrix:
  /// the indexes have n=0, x=0 and t values ranging from t_begin to
  /// (t_begin + t_stride * feats.NumRows() - 1) with a stride of t_stride.
  /// t_begin should be the frame that the first row of 'feats' represents.
  NnetIo(const std::string &name,
         int32 t_begin,
         const GeneralMatrix &feats,
         int32 t_stride = 1);

  /// Creates an NnetIo called "name" from the provided labels, with indexes
  /// that have n=0, x=0 and t values from t_begin to
  /// (t_begin + t_stride * labels.size() - 1) with a stride of t_stride.
  /// t_begin should be the frame to which labels[0] corresponds.
  /// NOTE(review): 'dim' is not described in this header; presumably it is
  /// the dimension of the label matrix that gets built -- confirm against the
  /// definition.
  NnetIo(const std::string &name,
         int32 dim,
         int32 t_begin,
         const Posterior &labels,
         int32 t_stride = 1);

  /// Swap with *other (defined outside this header; presumably exchanges all
  /// three members).
  void Swap(NnetIo *other);

  /// Stream I/O; 'binary' presumably selects binary versus text format, as
  /// elsewhere in Kaldi.
  void Write(std::ostream &os, bool binary) const;
  void Read(std::istream &is, bool binary);

  /// Equality test.  This is not very efficient, especially for sparse
  /// supervision; it is only used in testing code.
  bool operator==(const NnetIo &other) const;
};
/// Hash functor that looks only at the structure of an NnetIo -- its name,
/// its indexes and its feature dimension -- and never at the feature values.
/// It is meant for combining egs into batches that all share the same
/// structure.
struct NnetIoStructureHasher {
  size_t operator()(const NnetIo &a) const noexcept;
};
/// Comparison functor that looks only at the structure of two NnetIo objects
/// -- name, indexes and feature dimension -- and never at the feature values.
/// It is meant for combining egs into batches that all share the same
/// structure.
/// NOTE(review): presumably returns true when the two structures match (it is
/// the companion of NnetIoStructureHasher) -- confirm against the definition.
struct NnetIoStructureCompare {
  bool operator()(const NnetIo &a, const NnetIo &b) const;
};
/// NnetExample is the input data and corresponding label (or labels) for one or
/// more frames of input, used for standard cross-entropy training of neural
/// nets (and possibly for other objective functions).
struct NnetExample {
  /// "io" contains the input and output.  In principle there can be multiple
  /// types of both input and output, with different names.  The order is
  /// irrelevant.
  std::vector<NnetIo> io;

  /// Stream I/O; 'binary' presumably selects binary versus text format, as
  /// elsewhere in Kaldi.
  void Write(std::ostream &os, bool binary) const;
  void Read(std::istream &is, bool binary);

  /// Creates an example with an empty "io" vector.  Kept user-provided (rather
  /// than "= default") so that the struct does not become an aggregate.
  NnetExample() { }

  // Copy and move operations are all the compiler-generated member-wise ones.
  // They are declared explicitly as a set: a hand-written copy constructor
  // would make the implicit copy assignment a deprecated feature and would
  // suppress the move operations, forcing a deep copy of every NnetIo
  // whenever an example is moved or a std::vector<NnetExample> reallocates.
  NnetExample(const NnetExample &other) = default;
  NnetExample &operator=(const NnetExample &other) = default;
  NnetExample(NnetExample &&other) = default;
  NnetExample &operator=(NnetExample &&other) = default;

  /// Exchanges the contents of this example with *other without copying any
  /// NnetIo.
  void Swap(NnetExample *other) { io.swap(other->io); }

  /// Compresses any (input) features that are not sparse.
  void Compress();

  /// Caution: this operator == is not very efficient.  It's only used in
  /// testing code.
  bool operator==(const NnetExample &other) const { return io == other.io; }
};
/// Hash functor that looks only at the structure of an NnetExample, never at
/// the feature values; it is meant for combining egs into batches that all
/// share the same structure.
///
/// Note: the hash value depends on the order in which the NnetIo elements
/// (inputs and outputs) appear, even though the merging code can cope with
/// differently-ordered inputs and outputs (e.g. "input" appearing before
/// vs. after "ivector" or "output").  We don't think anyone would ever have to
/// deal with differently-ordered but otherwise identical egs in practice, so
/// we don't bother making the hash function independent of that order.
struct NnetExampleStructureHasher {
  size_t operator()(const NnetExample &eg) const noexcept;

  /// Pointer version: hashes the example that 'eg' points to.
  size_t operator()(const NnetExample *eg) const noexcept {
    return operator()(*eg);
  }
};
/// Comparison functor that looks only at the structure of two NnetExample
/// objects, never at the feature values.  Like NnetExampleStructureHasher, it
/// is sensitive to the order in which the differently-named NnetIo elements
/// appear.  It is meant for combining egs into batches that all share the
/// same structure.
struct NnetExampleStructureCompare {
  bool operator()(const NnetExample &a, const NnetExample &b) const;

  /// Pointer version: compares the examples that 'a' and 'b' point to.
  bool operator()(const NnetExample *a, const NnetExample *b) const {
    return operator()(*a, *b);
  }
};
// Kaldi table types (writer, sequential reader, random-access reader) whose
// held object is an NnetExample.
using NnetExampleWriter = TableWriter<KaldiObjectHolder<NnetExample> >;
using SequentialNnetExampleReader =
    SequentialTableReader<KaldiObjectHolder<NnetExample> >;
using RandomAccessNnetExampleReader =
    RandomAccessTableReader<KaldiObjectHolder<NnetExample> >;
} // namespace nnet3
} // namespace kaldi
#endif // KALDI_NNET3_NNET_EXAMPLE_H_