// nnet/nnet-sentence-averaging-component.h // Copyright 2013-2016 Brno University of Technology (Author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_NNET_NNET_SENTENCE_AVERAGING_COMPONENT_H_ #define KALDI_NNET_NNET_SENTENCE_AVERAGING_COMPONENT_H_ #include #include "nnet/nnet-component.h" #include "nnet/nnet-utils.h" #include "cudamatrix/cu-math.h" namespace kaldi { namespace nnet1 { /** * SimpleSentenceAveragingComponent does not have nested network, * it is intended to be used inside of a . * For training use 'nnet-train-perutt'. * * The sentence-averaging typically leads to small gradients, so we boost it 100x * by default (boost = multiply, it's equivalent to applying learning-rate factor). */ class SimpleSentenceAveragingComponent : public Component { public: SimpleSentenceAveragingComponent(int32 dim_in, int32 dim_out): Component(dim_in, dim_out), gradient_boost_(100.0), shrinkage_(0.0), only_summing_(false) { } ~SimpleSentenceAveragingComponent() { } Component* Copy() const { return new SimpleSentenceAveragingComponent(*this); } ComponentType GetType() const { return kSimpleSentenceAveragingComponent; } void InitData(std::istream &is) { // parse config std::string token; while (is >> std::ws, !is.eof()) { ReadToken(is, false, &token); if (token == "") ReadBasicType(is, false, &gradient_boost_); else if (token == "") ReadBasicType(is, false, &shrinkage_); else if (token == "") ReadBasicType(is, false, &only_summing_); else KALDI_ERR << "Unknown token " << token << ", a typo in config?" << " (GradientBoost|Shrinkage|OnlySumming)"; } } void ReadData(std::istream &is, bool binary) { bool end_loop = false; while (!end_loop && '<' == Peek(is, binary)) { int first_char = PeekToken(is, binary); switch (first_char) { case 'G': ExpectToken(is, binary, ""); ReadBasicType(is, binary, &gradient_boost_); break; case 'S': ExpectToken(is, binary, ""); ReadBasicType(is, binary, &shrinkage_); break; case 'O': ExpectToken(is, binary, ""); // compatibility trick, // in some models 'only_summing_' was float '0.0', // from now 'only_summing_' is 'bool': try { ReadBasicType(is, binary, &only_summing_); } catch(const std::exception &e) { KALDI_WARN << "ERROR was handled by exception!"; BaseFloat dummy_float; ReadBasicType(is, binary, &dummy_float); } break; case '!': ExpectToken(is, binary, ""); default: end_loop = true; } } } void WriteData(std::ostream &os, bool binary) const { WriteToken(os, binary, ""); WriteBasicType(os, binary, gradient_boost_); WriteToken(os, binary, ""); WriteBasicType(os, binary, shrinkage_); WriteToken(os, binary, ""); WriteBasicType(os, binary, only_summing_); } std::string Info() const { return std::string("\n gradient-boost ") + ToString(gradient_boost_) + ", shrinkage: " + ToString(shrinkage_) + ", only summing: " + ToString(only_summing_); } std::string InfoGradient() const { return Info(); } void PropagateFnc(const CuMatrixBase &in, CuMatrixBase *out) { // get the average row-vector, average_row_.Resize(InputDim()); if (only_summing_) { average_row_.AddRowSumMat(1.0, in, 0.0); } else { average_row_.AddRowSumMat(1.0/(in.NumRows()+shrinkage_), in, 0.0); } // copy it on the output, out->AddVecToRows(1.0, average_row_, 0.0); } void BackpropagateFnc(const CuMatrixBase &in, const CuMatrixBase &out, const CuMatrixBase &out_diff, CuMatrixBase *in_diff) { // When averaging, a single frame from input influenced all the frames // on the output. So the derivative w.r.t. single input frame is a sum // of the output derivatives, scaled by the averaging constant 1/K. // // In the same time all the input frames of the average influenced // all the output frames. So the loss derivarive is same for all // the input frames coming to the averaging. // // getting the average output diff, average_diff_.Resize(OutputDim()); if (only_summing_) { average_diff_.AddRowSumMat(1.0, out_diff, 0.0); } else { average_diff_.AddRowSumMat(1.0/(out_diff.NumRows()+shrinkage_), out_diff, 0.0); } // copy the derivative into the input diff, (applying gradient-boost!!) in_diff->AddVecToRows(gradient_boost_, average_diff_, 0.0); } private: /// Auxiliary buffer for forward propagation (for average vector), CuVector average_row_; /// Auxiliary buffer for back-propagation (for average vector), CuVector average_diff_; /// Scalar applied on gradient in backpropagation, BaseFloat gradient_boost_; /// Number of 'imaginary' zero-vectors in the average /// (shrinks the average vector for short sentences), BaseFloat shrinkage_; /// Removes normalization term from arithmetic mean (when true). bool only_summing_; }; /** Deprecated!!!, keeping it as Katka Zmolikova used it in JSALT 2015 */ class SentenceAveragingComponent : public UpdatableComponent { public: SentenceAveragingComponent(int32 dim_in, int32 dim_out): UpdatableComponent(dim_in, dim_out), learn_rate_factor_(100.0) { } ~SentenceAveragingComponent() { } Component* Copy() const { return new SentenceAveragingComponent(*this); } ComponentType GetType() const { return kSentenceAveragingComponent; } void InitData(std::istream &is) { // define options std::string nested_nnet_filename; std::string nested_nnet_proto; // parse config std::string token; while (is >> std::ws, !is.eof()) { ReadToken(is, false, &token); /**/ if (token == "") ReadToken(is, false, &nested_nnet_filename); else if (token == "") ReadToken(is, false, &nested_nnet_proto); else if (token == "") ReadBasicType(is, false, &learn_rate_factor_); else KALDI_ERR << "Unknown token " << token << " Typo in config?"; } // initialize (read already prepared nnet from file) KALDI_ASSERT((nested_nnet_proto != "") ^ (nested_nnet_filename != "")); // xor, if (nested_nnet_filename != "") nnet_.Read(nested_nnet_filename); if (nested_nnet_proto != "") nnet_.Init(nested_nnet_proto); // check dims of nested nnet KALDI_ASSERT(InputDim() == nnet_.InputDim()); KALDI_ASSERT(OutputDim() == nnet_.OutputDim() + InputDim()); } void ReadData(std::istream &is, bool binary) { nnet_.Read(is, binary); KALDI_ASSERT(nnet_.InputDim() == InputDim()); KALDI_ASSERT(nnet_.OutputDim() + InputDim() == OutputDim()); } void WriteData(std::ostream &os, bool binary) const { nnet_.Write(os, binary); } int32 NumParams() const { return nnet_.NumParams(); } void GetGradient(VectorBase* gradient) const { KALDI_ERR << "Unimplemented!"; } void GetParams(VectorBase* params) const { KALDI_ASSERT(params->Dim() == NumParams()); Vector params_aux; nnet_.GetParams(¶ms_aux); params->CopyFromVec(params_aux); } void SetParams(const VectorBase& params) { KALDI_ERR << "Unimplemented!"; } std::string Info() const { return std::string("nested_network {\n") + nnet_.Info() + "}\n"; } std::string InfoGradient() const { return std::string("nested_gradient {\n") + nnet_.InfoGradient() + "}\n"; } void PropagateFnc(const CuMatrixBase &in, CuMatrixBase *out) { // Get NN output CuMatrix out_nnet; nnet_.Propagate(in, &out_nnet); // Get the average row (averaging over the time axis): // averaging corresponds to extraction of a 'constant vector' // code for single sentence, int32 num_inputs = in.NumCols(), nnet_outputs = nnet_.OutputDim(), num_frames = out_nnet.NumRows(); CuVector average_row(nnet_outputs); average_row.AddRowSumMat(1.0/num_frames, out_nnet, 0.0); // Forwarding sentence codes along with input features out->ColRange(0, nnet_outputs).AddVecToRows(1.0, average_row, 0.0); out->ColRange(nnet_outputs, num_inputs).CopyFromMat(in); } void BackpropagateFnc(const CuMatrixBase &in, const CuMatrixBase &out, const CuMatrixBase &out_diff, CuMatrixBase *in_diff) { if (in_diff == NULL) return; int32 num_inputs = in.NumCols(), nnet_outputs = nnet_.OutputDim(); in_diff->CopyFromMat(out_diff.ColRange(nnet_outputs, num_inputs)); } void Update(const CuMatrixBase &input, const CuMatrixBase &diff) { // get useful dims, int32 nnet_outputs = nnet_.OutputDim(), num_frames = diff.NumRows(); // Passing the derivative into the nested network. The loss derivative is averaged: // single frame from nested network influenced all the frames in the main network, // so to get the derivative w.r.t. single frame from nested network we sum derivatives // of all frames from main network (and scale by 1/Nframes constant). // // In fact all the frames from nested network influenced all the input frames to main nnet, // so the loss derivarive w.r.t. nested network output is same for all frames in sentence. CuVector average_diff(nnet_outputs); average_diff.AddRowSumMat(1.0 / num_frames, diff.ColRange(0, nnet_outputs), 0.0); CuMatrix nnet_out_diff(num_frames, nnet_outputs); nnet_out_diff.AddVecToRows(1.0, average_diff, 0.0); // nnet_.Backpropagate(nnet_out_diff, NULL); } void SetTrainOptions(const NnetTrainOptions &opts) { UpdatableComponent::SetTrainOptions(opts_); // Pass the train options to the nnet NnetTrainOptions o(opts); o.learn_rate *= learn_rate_factor_; nnet_.SetTrainOptions(opts_); } private: Nnet nnet_; float learn_rate_factor_; }; /* Deprecated */ } // namespace nnet1 } // namespace kaldi #endif // KALDI_NNET_NNET_SENTENCE_AVERAGING_COMPONENT_H_