// nnet/nnet-lstm-projected-streams.h // Copyright 2016 Brno University of Technology (author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_NNET_NNET_RECURRENT_STREAMS_H_ #define KALDI_NNET_NNET_RECURRENT_STREAMS_H_ #include #include #include "nnet/nnet-component.h" #include "nnet/nnet-utils.h" #include "cudamatrix/cu-math.h" namespace kaldi { namespace nnet1 { /** * Component with recurrent connections, 'tanh' non-linearity. * No internal state preserved, starting each sequence from zero vector. * * Can be used in 'per-sentence' training and multi-stream training. */ class RecurrentComponent : public MultistreamComponent { public: RecurrentComponent(int32 input_dim, int32 output_dim): MultistreamComponent(input_dim, output_dim) { } ~RecurrentComponent() { } Component* Copy() const { return new RecurrentComponent(*this); } ComponentType GetType() const { return kRecurrentComponent; } void InitData(std::istream &is) { // define options, float param_scale = 0.02; // parse the line from prototype, std::string token; while (is >> std::ws, !is.eof()) { ReadToken(is, false, &token); /**/ if (token == "") ReadBasicType(is, false, &grad_clip_); else if (token == "") ReadBasicType(is, false, &diff_clip_); else if (token == "") ReadBasicType(is, false, &learn_rate_coef_); else if (token == "") ReadBasicType(is, false, &bias_learn_rate_coef_); else if (token == "") ReadBasicType(is, false, ¶m_scale); else KALDI_ERR << "Unknown token " << token << ", a typo in config?" << " (GradClip|DiffClip|LearnRateCoef|BiasLearnRateCoef|ParamScale)"; } // init the weights and biases (from uniform dist.), w_forward_.Resize(output_dim_, input_dim_); w_recurrent_.Resize(output_dim_, output_dim_); bias_.Resize(output_dim_); RandUniform(0.0, 2.0 * param_scale, &w_forward_); RandUniform(0.0, 2.0 * param_scale, &w_recurrent_); RandUniform(0.0, 2.0 * param_scale, &bias_); } void ReadData(std::istream &is, bool binary) { // Read all the '' in arbitrary order, while ('<' == Peek(is, binary)) { std::string token; int first_char = PeekToken(is, binary); switch (first_char) { case 'G': ExpectToken(is, binary, ""); ReadBasicType(is, binary, &grad_clip_); break; case 'D': ExpectToken(is, binary, ""); ReadBasicType(is, binary, &diff_clip_); break; case 'L': ExpectToken(is, binary, ""); ReadBasicType(is, binary, &learn_rate_coef_); break; case 'B': ExpectToken(is, binary, ""); ReadBasicType(is, binary, &bias_learn_rate_coef_); break; default: ReadToken(is, false, &token); KALDI_ERR << "Unknown token: " << token; } } // Read the data (data follow the tokens), w_forward_.Read(is, binary); w_recurrent_.Read(is, binary); bias_.Read(is, binary); } void WriteData(std::ostream &os, bool binary) const { WriteToken(os, binary, ""); WriteBasicType(os, binary, grad_clip_); WriteToken(os, binary, ""); WriteBasicType(os, binary, diff_clip_); WriteToken(os, binary, ""); WriteBasicType(os, binary, learn_rate_coef_); WriteToken(os, binary, ""); WriteBasicType(os, binary, bias_learn_rate_coef_); if (!binary) os << "\n"; w_forward_.Write(os, binary); w_recurrent_.Write(os, binary); bias_.Write(os, binary); } int32 NumParams() const { return w_forward_.NumRows() * w_forward_.NumCols() + w_recurrent_.NumRows() * w_recurrent_.NumCols() + bias_.Dim(); } void GetGradient(VectorBase* gradient) const { KALDI_ASSERT(gradient->Dim() == NumParams()); int32 offset, len; offset = 0; len = w_forward_corr_.NumRows() * w_forward_corr_.NumCols(); gradient->Range(offset, len).CopyRowsFromMat(w_forward_corr_); offset += len; len = w_recurrent_corr_.NumRows() * w_recurrent_corr_.NumCols(); gradient->Range(offset, len).CopyRowsFromMat(w_recurrent_corr_); offset += len; len = bias_corr_.Dim(); gradient->Range(offset, len).CopyFromVec(bias_corr_); offset += len; KALDI_ASSERT(offset == NumParams()); } void GetParams(VectorBase* params) const { KALDI_ASSERT(params->Dim() == NumParams()); int32 offset, len; offset = 0; len = w_forward_.NumRows() * w_forward_.NumCols(); params->Range(offset, len).CopyRowsFromMat(w_forward_); offset += len; len = w_recurrent_.NumRows() * w_recurrent_.NumCols(); params->Range(offset, len).CopyRowsFromMat(w_recurrent_); offset += len; len = bias_.Dim(); params->Range(offset, len).CopyFromVec(bias_); offset += len; KALDI_ASSERT(offset == NumParams()); } void SetParams(const VectorBase& params) { KALDI_ASSERT(params.Dim() == NumParams()); int32 offset, len; offset = 0; len = w_forward_.NumRows() * w_forward_.NumCols(); w_forward_.CopyRowsFromVec(params.Range(offset, len)); offset += len; len = w_recurrent_.NumRows() * w_recurrent_.NumCols(); w_recurrent_.CopyRowsFromVec(params.Range(offset, len)); offset += len; len = bias_.Dim(); bias_.CopyFromVec(params.Range(offset, len)); offset += len; KALDI_ASSERT(offset == NumParams()); } std::string Info() const { return std::string(" ") + "\n w_forward_ " + MomentStatistics(w_forward_) + "\n w_recurrent_ " + MomentStatistics(w_recurrent_) + "\n bias_ " + MomentStatistics(bias_); } std::string InfoGradient() const { return std::string("") + "( learn_rate_coef " + ToString(learn_rate_coef_) + ", bias_learn_rate_coef " + ToString(bias_learn_rate_coef_) + ", grad-clip " + ToString(grad_clip_) + ", diff-clip " + ToString(diff_clip_) + " )" + "\n Gradients:" + "\n w_forward_corr_ " + MomentStatistics(w_forward_corr_) + "\n w_recurrent_corr_ " + MomentStatistics(w_recurrent_corr_) + "\n bias_corr_ " + MomentStatistics(bias_corr_) + "\n Forward-pass:" + "\n out_ " + MomentStatistics(out_) + "\n Backward-pass:" + "\n out_diff_bptt_ " + MomentStatistics(out_diff_bptt_); } void PropagateFnc(const CuMatrixBase &in, CuMatrixBase *out) { KALDI_ASSERT(in.NumRows() % NumStreams() == 0); int32 T = in.NumRows() / NumStreams(); int32 S = NumStreams(); // Precopy bias, out->AddVecToRows(1.0, bias_, 0.0); // Apply 'forward' connections, out->AddMatMat(1.0, in, kNoTrans, w_forward_, kTrans, 1.0); // First line of 'out' w/o recurrent signal, apply 'tanh' directly, out->RowRange(0, S).Tanh(out->RowRange(0, S)); // Apply 'recurrent' connections, for (int32 t = 1; t < T; t++) { out->RowRange(t*S, S).AddMatMat(1.0, out->RowRange((t-1)*S, S), kNoTrans, w_recurrent_, kTrans, 1.0); out->RowRange(t*S, S).Tanh(out->RowRange(t*S, S)); // Zero output for padded frames, if (sequence_lengths_.size() == S) { for (int32 s = 0; s < S; s++) { if (t >= sequence_lengths_[s]) { out->Row(t*S + s).SetZero(); } } } // } out_ = (*out); // We'll need a copy for updating the recurrent weights! // We are DONE ;) } void BackpropagateFnc(const CuMatrixBase &in, const CuMatrixBase &out, const CuMatrixBase &out_diff, CuMatrixBase *in_diff) { int32 T = in.NumRows() / NumStreams(); int32 S = NumStreams(); // Apply BPTT on 'out_diff', out_diff_bptt_ = out_diff; for (int32 t = T-1; t >= 1; t--) { // buffers, CuSubMatrix d_t = out_diff_bptt_.RowRange(t*S, S); CuSubMatrix d_t1 = out_diff_bptt_.RowRange((t-1)*S, S); const CuSubMatrix y_t = out.RowRange(t*S, S); // BPTT, d_t.DiffTanh(y_t, d_t); d_t1.AddMatMat(1.0, d_t, kNoTrans, w_recurrent_, kNoTrans, 1.0); // clipping, if (diff_clip_ > 0.0) { d_t1.ApplyFloor(-diff_clip_); d_t1.ApplyCeiling(diff_clip_); } // Zero diff for padded frames, if (sequence_lengths_.size() == S) { for (int32 s = 0; s < S; s++) { if (t >= sequence_lengths_[s]) { out_diff_bptt_.Row(t*S + s).SetZero(); } } } } // Apply 'DiffTanh' on first block, CuSubMatrix d_t = out_diff_bptt_.RowRange(0, S); const CuSubMatrix y_t = out.RowRange(0, S); d_t.DiffTanh(y_t, d_t); // Transform diffs to 'in_diff', in_diff->AddMatMat(1.0, out_diff_bptt_, kNoTrans, w_forward_, kNoTrans, 0.0); // We are DONE ;) } void Update(const CuMatrixBase &input, const CuMatrixBase &diff) { int32 T = input.NumRows() / NumStreams(); int32 S = NumStreams(); // getting the learning rate, const BaseFloat lr = opts_.learn_rate; const BaseFloat mmt = opts_.momentum; if (bias_corr_.Dim() != OutputDim()) { w_forward_corr_.Resize(w_forward_.NumRows(), w_forward_.NumCols(), kSetZero); w_recurrent_corr_.Resize(w_recurrent_.NumRows(), w_recurrent_.NumCols(), kSetZero); bias_corr_.Resize(OutputDim(), kSetZero); } // getting the gradients, w_forward_corr_.AddMatMat(1.0, out_diff_bptt_, kTrans, input, kNoTrans, mmt); w_recurrent_corr_.AddMatMat(1.0, out_diff_bptt_.RowRange(S, (T-1)*S), kTrans, out_.RowRange(0, (T-1)*S), kNoTrans, mmt); bias_corr_.AddRowSumMat(1.0, out_diff_bptt_, mmt); // updating, w_forward_.AddMat(-lr * learn_rate_coef_, w_forward_corr_); w_recurrent_.AddMat(-lr * learn_rate_coef_, w_recurrent_corr_); bias_.AddVec(-lr * bias_learn_rate_coef_, bias_corr_); } private: BaseFloat grad_clip_; ///< Clipping of the update, BaseFloat diff_clip_; ///< Clipping in the BPTT loop, // trainable parameters, CuMatrix w_forward_; CuMatrix w_recurrent_; CuVector bias_; // udpate buffers, CuMatrix w_forward_corr_; CuMatrix w_recurrent_corr_; CuVector bias_corr_; // forward propagation buffer, CuMatrix out_; // back-propagate buffer, CuMatrix out_diff_bptt_; }; // class RecurrentComponent } // namespace nnet1 } // namespace kaldi #endif // KALDI_NNET_NNET_RECURRENT_STREAMS_H_