Blame view
src/rnnlm/rnnlm-training.h
8.44 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
// rnnlm/rnnlm-training.h

// Copyright 2017  Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_RNNLM_RNNLM_TRAINING_H_
#define KALDI_RNNLM_RNNLM_TRAINING_H_

#include "rnnlm/rnnlm-core-training.h"
#include "rnnlm/rnnlm-embedding-training.h"
#include "rnnlm/rnnlm-utils.h"
#include "rnnlm/rnnlm-example-utils.h"
#include "util/kaldi-semaphore.h"

namespace kaldi {
namespace rnnlm {

/*
   The class RnnlmTrainer is for training an RNNLM (one individual training
   job, not the top-level logic about learning rate schedules, parameter
   averaging, and the like); it contains most of the logic that the
   command-line program rnnlm-train implements.
*/
class RnnlmTrainer {
 public:
  /**
     Constructor

      @param [in] train_embedding  True if the user wants us to train the
                              embedding matrix.
      @param [in] core_config  Options for training the core RNNLM.
      @param [in] embedding_config  Options for training the embedding matrix
                              (only relevant if train_embedding is true).
      @param [in] objective_config  Options relating to the objective function
                              used for training.
      @param [in] word_feature_mat  Either NULL, or a pointer to a sparse
                              word-feature matrix of dimension vocab-size by
                              feature-dim, where vocab-size is the
                              highest-numbered word plus one.
      @param [in,out] embedding_mat  Pointer to the embedding matrix; this is
                              trained if train_embedding is true, and in either
                              case this class retains the pointer to
                              'embedding_mat' during its lifetime.  If
                              word_feature_mat is NULL, this is the
                              word-embedding matrix of dimension vocab-size by
                              embedding-dim; otherwise it is the
                              feature-embedding matrix of dimension feature-dim
                              by embedding-dim, and we have to multiply it by
                              word_feature_mat to get the word embedding
                              matrix.
      @param [in,out] rnnlm   The RNNLM to be trained.  The class will retain
                              this pointer and modify the neural net in-place.
  */
  RnnlmTrainer(bool train_embedding,
               const RnnlmCoreTrainerOptions &core_config,
               const RnnlmEmbeddingTrainerOptions &embedding_config,
               const RnnlmObjectiveOptions &objective_config,
               const CuSparseMatrix<BaseFloat> *word_feature_mat,
               CuMatrix<BaseFloat> *embedding_mat,
               nnet3::Nnet *rnnlm);

  // Train on one example.  The example is provided as a pointer because we
  // acquire it destructively, via Swap().
  void Train(RnnlmExample *minibatch);

  // The destructor writes out any files that we need to write out.
  ~RnnlmTrainer();

  // Returns the number of minibatches trained on so far.
  int32 NumMinibatchesProcessed() { return num_minibatches_processed_; }

 private:
  // Returns the vocabulary size (highest-numbered word plus one).
  int32 VocabSize();

  /// This function contains the actual training code, it's called from
  /// Train(); it trains on current_minibatch_.
  void TrainInternal();

  /// This function works out the word-embedding matrix for the minibatch we're
  /// training on (current_minibatch_).  The word-embedding matrix for this
  /// minibatch is a matrix of dimension current_minibatch_.vocab_size by
  /// embedding_mat_.NumCols().  This function sets '*word_embedding' to be a
  /// pointer to the embedding matrix, which will either be '&embedding_mat_'
  /// (in the case where there is no sampling and no sparse feature
  /// representation), or 'word_embedding_storage' otherwise.  In the latter
  /// case, 'word_embedding_storage' will be resized and written to
  /// appropriately.
  void GetWordEmbedding(CuMatrix<BaseFloat> *word_embedding_storage,
                        CuMatrix<BaseFloat> **word_embedding);

  /// This function trains the word-embedding matrix for the minibatch we're
  /// training on (in current_minibatch_).  'embedding_deriv' is the derivative
  /// w.r.t. the word-embedding for this minibatch (of dimension
  /// current_minibatch_.vocab_size by embedding_mat_.NumCols()).
  /// You can think of it as the backprop for the function 'GetWordEmbedding()'.
  ///    @param [in] word_embedding_deriv  The derivative w.r.t. the embeddings
  ///                     of just the words used in this minibatch (i.e. the
  ///                     minibatch-level word-embedding matrix, possibly using
  ///                     a subset of words).  This is an input but this
  ///                     function consumes it destructively.
  void TrainWordEmbedding(CuMatrixBase<BaseFloat> *word_embedding_deriv);

  /// The backstitch version of the above function.
  void TrainBackstitchWordEmbedding(
      bool is_backstitch_step1,
      CuMatrixBase<BaseFloat> *word_embedding_deriv);

  bool train_embedding_;  // true if we are training the embedding.
  const RnnlmCoreTrainerOptions &core_config_;
  const RnnlmEmbeddingTrainerOptions &embedding_config_;
  const RnnlmObjectiveOptions &objective_config_;

  // The neural net we are training (not owned here).
  nnet3::Nnet *rnnlm_;
  // Pointer to the object that trains 'rnnlm_' (owned here).
  RnnlmCoreTrainer *core_trainer_;

  // The (word or feature) embedding matrix; it's the word embedding matrix if
  // word_feature_mat_.NumRows() == 0, else it's the feature embedding matrix.
  // The dimension is (num-words or num-features) by embedding-dim.
  // It's owned outside this class.
  CuMatrix<BaseFloat> *embedding_mat_;

  // Pointer to the object that trains 'embedding_mat_', or NULL if we are not
  // training it.  Owned here.
  RnnlmEmbeddingTrainer *embedding_trainer_;

  // If the --read-sparse-word-features options is provided, then
  // word_feature_mat_ will contain the matrix of sparse word features, of
  // dimension num-words by num-features.  In this case, the word embedding
  // matrix is the product of this matrix times 'embedding_mat_'.
  // It's owned outside this class.
  const CuSparseMatrix<BaseFloat> *word_feature_mat_;

  // This is the transpose of word_feature_mat_, which is needed only if we
  // train on egs without sampling.  This is only computed once, if and when
  // it's needed.
  CuSparseMatrix<BaseFloat> word_feature_mat_transpose_;

  // Number of minibatches trained on so far.
  int32 num_minibatches_processed_;

  // The minibatch we are training on.
  RnnlmExample current_minibatch_;
  // The variables derived_ and active_words_ correspond to the same minibatch
  // as current_minibatch_.
  RnnlmExampleDerived derived_;
  // Only if we are doing subsampling (depends on the eg), active_words_
  // contains the list of active words for the minibatch 'current_minibatch_';
  // it is a CUDA version of the 'active_words' output by
  // RenumberRnnlmExample().  Otherwise it is empty.
  CuArray<int32> active_words_;
  // Only if we are doing subsampling AND we have sparse word features
  // (i.e. word_feature_mat_ is nonempty), active_word_features_ contains
  // just the rows of word_feature_mat_ which correspond to active_words_.
  // This is a derived quantity computed by the background thread.
  CuSparseMatrix<BaseFloat> active_word_features_;
  // Only if we are doing subsampling AND we have sparse word features,
  // active_word_features_trans_ is the transpose of active_word_features_;
  // This is a derived quantity computed by the background thread.
  CuSparseMatrix<BaseFloat> active_word_features_trans_;

  // This value is used in backstitch training when we need to ensure
  // consistent dropout masks.  It's set to a value derived from rand()
  // when the class is initialized.
  int32 srand_seed_;

  KALDI_DISALLOW_COPY_AND_ASSIGN(RnnlmTrainer);
};

}  // namespace rnnlm
}  // namespace kaldi

#endif //KALDI_RNNLM_RNNLM_TRAINING_H_