// rnnlm/rnnlm-embedding-training.h
// Copyright 2017 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS
// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY
// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_RNNLM_RNNLM_EMBEDDING_TRAINING_H_
#define KALDI_RNNLM_RNNLM_EMBEDDING_TRAINING_H_

#include "base/kaldi-common.h"
#include "matrix/matrix-lib.h"
#include "rnnlm/rnnlm-example.h"
#include "nnet3/natural-gradient-online.h"
#include "rnnlm/rnnlm-example-utils.h"

namespace kaldi {
namespace rnnlm {

// These are options relating to training the embedding matrix, i.e. the
// word-embedding or feature-embedding matrix that maps words (or their
// features) to the embedding space used by the RNNLM. This is analogous
// to NnetTrainerOptions in ../nnet3/nnet-training.h, except that with the
// RNNLM a few things in the training code are different, so we're using a
// totally separate training class.
// We'll add options as we need them.
struct RnnlmEmbeddingTrainerOptions {
int32 print_interval;
BaseFloat momentum;
BaseFloat max_param_change;
BaseFloat l2_regularize;
BaseFloat learning_rate; // Note: don't set the learning rate to 0.0 as a way
// of disabling training; instead, turn off training of
// the embedding matrix via the command-line options of
// the training program (e.g. by not providing a place
// to write the embedding matrix).
BaseFloat backstitch_training_scale;
int32 backstitch_training_interval;
// Natural-gradient related options
bool use_natural_gradient;
BaseFloat natural_gradient_alpha;
int32 natural_gradient_rank;
int32 natural_gradient_update_period;
int32 natural_gradient_num_minibatches_history;
RnnlmEmbeddingTrainerOptions():
print_interval(100),
momentum(0.0),
max_param_change(1.0),
l2_regularize(0.0),
learning_rate(0.01),
backstitch_training_scale(0.0),
backstitch_training_interval(1),
use_natural_gradient(true),
natural_gradient_alpha(4.0),
natural_gradient_rank(80),
natural_gradient_update_period(4),
natural_gradient_num_minibatches_history(10) { }
void Register(OptionsItf *opts) {
opts->Register("momentum", &momentum, "Momentum constant to apply during "
"training of embedding (e.g. 0.5 or 0.9). Note: we "
"automatically multiply the learning rate by (1-momenum) "
"so that the 'effective' learning rate is the same as "
"before (because momentum would normally increase the "
"effective learning rate by 1/(1-momentum))");
opts->Register("max-param-change", &max_param_change, "The maximum change in "
"parameters allowed per minibatch, measured in Euclidean norm, "
"for the embedding matrix (the matrix of num-features by "
"embedding-dim -- or num-words by embedding-dim, if we're not "
"using a feature-based representation.");
opts->Register("l2-regularize", &l2_regularize, "L2 regularize value that "
"affects the strength of l2 regularization on embedding "
"parameters.");
opts->Register("learning-rate", &learning_rate, "The learning rate used in "
"training the word-embedding matrix.");
opts->Register("backstitch-training-scale", &backstitch_training_scale,
"backstitch training factor. "
"if 0 then in the normal training mode. It is referred to as "
"'\\alpha' in our publications.");
opts->Register("backstitch-training-interval",
&backstitch_training_interval,
"do backstitch training with the specified interval of "
"minibatches. It is referred to as 'n' in our publications.");
opts->Register("use-natural-gradient", &use_natural_gradient,
"True if you want to use natural gradient to update the "
"embedding matrix");
opts->Register("natural-gradient-alpha", &natural_gradient_alpha,
"Smoothing constant alpha to use for natural gradient when "
"updating the embedding matrix");
opts->Register("natural-gradient-rank", &natural_gradient_rank,
"Rank of the Fisher matrix in natural gradient as applied to "
"learning the embedding matrix (this is in the embedding "
"space, so the rank should probably be less than the "
"embedding dimension");
opts->Register("natural-gradient-update-period",
&natural_gradient_update_period,
"Determines how often the Fisher matrix is updated for natural "
"gradient as applied to the embedding matrix");
opts->Register("natural-gradient-num-minibatches-history",
&natural_gradient_num_minibatches_history,
"Determines how quickly the Fisher estimate for the natural gradient "
"is updated, when training the word embedding.");
}
void Check() const;
};
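
/* Example of how these options are typically hooked up in a training binary
   (a minimal sketch; the surrounding main() and usage string are
   hypothetical, but ParseOptions and Register() are standard Kaldi usage):

     #include "util/parse-options.h"

     int main(int argc, char *argv[]) {
       using namespace kaldi;
       using namespace kaldi::rnnlm;
       ParseOptions po("Usage: rnnlm-train [options] ...");
       RnnlmEmbeddingTrainerOptions embedding_opts;
       embedding_opts.Register(&po);
       po.Read(argc, argv);
       embedding_opts.Check();
       // ... set up the trainer and run training ...
     }

   On the command line the options then appear as, e.g.:
     --learning-rate=0.005 --max-param-change=1.0 --use-natural-gradient=true
*/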
/** This class is responsible for training the word embedding matrix or
feature embedding matrix.
*/
class RnnlmEmbeddingTrainer {
public:
/** Constructor.
@param [in] config Structure that holds configuration options;
this class will keep a reference to it.
@param [in] embedding_mat The embedding matrix to be trained,
of dimension (num-words or num-features) by
embedding-dim (depending whether we are using a
feature representation of words, or not). This class
keeps the pointer and will modify that variable.
*/
RnnlmEmbeddingTrainer(const RnnlmEmbeddingTrainerOptions &config,
CuMatrix<BaseFloat> *embedding_mat);
/* Train on one minibatch-- this version is used either when there is no
subsampling, or when there is subsampling but we are using a feature
representation so the subsampling is handled outside of this code.
@param [in,out] embedding_deriv The derivative w.r.t. the (word or feature)
embedding matrix; it's provided as a non-const pointer for
convenience so that we can modify it in-place if needed
for the natural gradient update.
*/
void Train(CuMatrixBase<BaseFloat> *embedding_deriv);
// The backstitch version of the above function. Depending on whether
// is_backstitch_step1 is true, this performs either the first (backward)
// step or the second (forward) step of backstitch.
void TrainBackstitch(bool is_backstitch_step1,
CuMatrixBase<BaseFloat> *embedding_deriv);
/* Train on one minibatch-- this version is for when there is subsampling, and
the user is providing the derivative w.r.t. just the word-indexes that were
used in this minibatch. 'active_words' is a sorted, unique list of the
word-indexes that were used in this minibatch, and 'word_embedding_deriv'
is the derivative w.r.t. the embedding of that list of words.
@param [in] active_words A sorted, unique list of the word indexes
used, with Dim() equal to word_embedding_deriv->NumRows();
contains indexes 0 <= i < embedding_mat_->NumRows().
@param [in,out] word_embedding_deriv The derivative w.r.t. the
word embedding matrix; it's provided as a non-const
pointer for convenience so that we can modify
it in-place if needed for the natural gradient
update.
*/
void Train(const CuArrayBase<int32> &active_words,
CuMatrixBase<BaseFloat> *word_embedding_deriv);
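/* Illustration of the 'active_words' indexing (numbers are made up): if the
   vocabulary has 10000 words but the minibatch contains only words
   {4, 17, 321}, then active_words = [4, 17, 321] and word_embedding_deriv
   has 3 rows; row i of word_embedding_deriv is the derivative w.r.t. row
   active_words[i] of the full embedding matrix, so the update touches only
   those 3 rows. */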
// The backstitch version of the above function.
void TrainBackstitch(bool is_backstitch_step1,
const CuArrayBase<int32> &active_words,
CuMatrixBase<BaseFloat> *word_embedding_deriv);
~RnnlmEmbeddingTrainer();
private:
// Sets options in the object 'preconditioner_', based on the config
// (but not SetNumSamplesHistory(); we do that in the Train() functions
// because we don't have the right information at this point).
void SetNaturalGradientOptions();
// Called from the destructor, this prints some stats about how often the
// max-change constraint was applied, how much data we trained on, and how
// much the parameters changed during the lifetime of this object.
// TODO: implement this.
void PrintStats();
const RnnlmEmbeddingTrainerOptions &config_;
// Object that takes care of the natural-gradient update (this operates in
// a space whose dimension equals the embedding dim, which is the num-cols
// of embedding_mat_).
nnet3::OnlineNaturalGradient preconditioner_;
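// Conceptually, before each update the preconditioner multiplies the
// derivative by an online estimate of the inverse Fisher matrix in the
// embedding space (deriv <- deriv * F^{-1}, up to scaling), which usually
// permits larger stable learning rates; see nnet3/natural-gradient-online.h
// for details.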
// The matrix we are updating
CuMatrix<BaseFloat> *embedding_mat_;
// If momentum is to be used, this is sized the same as *embedding_mat_,
// and used for the decaying sum of deltas.
CuMatrix<BaseFloat> embedding_mat_momentum_;
// This is a copy of the 'embedding_mat' that we were initialized with,
// which we keep around for purposes of printing stats at the end about how
// much the matrix changed; we keep it in CPU memory in case GPU memory is a
// limiting factor.
Matrix<BaseFloat> initial_embedding_mat_;
// A count of the number of times we have updated the matrix.
int32 num_minibatches_;
// A count of the number of times the max-change constraint was applied.
int32 max_change_count_;
};
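
/* Example of a typical per-minibatch training loop using this class (a
   minimal sketch; variable names and the source of 'embedding_deriv' are
   hypothetical -- in practice the derivative comes from backprop through
   the RNNLM core computation):

     RnnlmEmbeddingTrainerOptions opts;
     CuMatrix<BaseFloat> embedding_mat;  // (num-words or num-features) x dim
     // ... read or initialize embedding_mat ...
     RnnlmEmbeddingTrainer trainer(opts, &embedding_mat);
     for (int32 i = 0; i < num_minibatches; i++) {
       CuMatrix<BaseFloat> embedding_deriv(embedding_mat.NumRows(),
                                           embedding_mat.NumCols());
       // ... fill embedding_deriv with the objective-function derivative ...
       if (opts.backstitch_training_scale > 0.0 &&
           i % opts.backstitch_training_interval == 0) {
         trainer.TrainBackstitch(true, &embedding_deriv);   // backward step
         // ... recompute embedding_deriv at the updated parameters ...
         trainer.TrainBackstitch(false, &embedding_deriv);  // forward step
       } else {
         trainer.Train(&embedding_deriv);
       }
     }
     // Stats (max-change events, total parameter change) are printed when
     // 'trainer' goes out of scope.
*/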
} // namespace rnnlm
} // namespace kaldi

#endif // KALDI_RNNLM_RNNLM_EMBEDDING_TRAINING_H_