// nnet/nnet-loss.h // Copyright 2011-2015 Brno University of Technology (author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_NNET_NNET_LOSS_H_ #define KALDI_NNET_NNET_LOSS_H_ #include #include #include "base/kaldi-common.h" #include "base/timer.h" #include "util/kaldi-holder.h" #include "itf/options-itf.h" #include "cudamatrix/cu-matrix.h" #include "cudamatrix/cu-vector.h" #include "cudamatrix/cu-array.h" #include "hmm/posterior.h" namespace kaldi { namespace nnet1 { struct LossOptions { int32 loss_report_frames; ///< Report loss value every 'report_interval' frames, LossOptions(): loss_report_frames(5*3600*100) // 5h, { } void Register(OptionsItf *opts) { opts->Register("loss-report-frames", &loss_report_frames, "Report loss per blocks of N frames (0 = no reports)"); } }; class LossItf { public: LossItf(LossOptions& opts) { opts_ = opts; } virtual ~LossItf() { } /// Evaluate cross entropy using target-matrix (supports soft labels), virtual void Eval(const VectorBase &frame_weights, const CuMatrixBase &net_out, const CuMatrixBase &target, CuMatrix *diff) = 0; /// Evaluate cross entropy using target-posteriors (supports soft labels), virtual void Eval(const VectorBase &frame_weights, const CuMatrixBase &net_out, const Posterior &target, CuMatrix *diff) = 0; /// Generate string with error report, virtual std::string Report() = 0; /// Get loss value (frame average), virtual BaseFloat AvgLoss() = 0; protected: LossOptions opts_; Timer timer_; }; class Xent : public LossItf { public: Xent(LossOptions &opts): LossItf(opts), frames_progress_(0.0), xentropy_progress_(0.0), entropy_progress_(0.0), elapsed_seconds_(0.0) { } ~Xent() { } /// Evaluate cross entropy using target-matrix (supports soft labels), void Eval(const VectorBase &frame_weights, const CuMatrixBase &net_out, const CuMatrixBase &target, CuMatrix *diff); /// Evaluate cross entropy using target-posteriors (supports soft labels), void Eval(const VectorBase &frame_weights, const CuMatrixBase &net_out, const Posterior &target, CuMatrix *diff); /// Generate string with error report, std::string Report(); /// Generate string with per-class error report, std::string ReportPerClass(); /// Get loss value (frame average), BaseFloat AvgLoss() { if (frames_.Sum() == 0) return 0.0; return (xentropy_.Sum() - entropy_.Sum()) / frames_.Sum(); } private: // main stats collected per target-class, CuVector frames_; Vector correct_; CuVector xentropy_; CuVector entropy_; // partial results during training, double frames_progress_; double xentropy_progress_; double entropy_progress_; std::vector loss_vec_; double elapsed_seconds_; // weigting buffer, CuVector frame_weights_; CuVector target_sum_; // loss computation buffers, CuMatrix tgt_mat_; CuMatrix frames_aux_; CuMatrix xentropy_aux_; CuMatrix entropy_aux_; // frame classification buffers, CuArray max_id_out_; CuArray max_id_tgt_; }; class Mse : public LossItf { public: Mse(LossOptions &opts): LossItf(opts), frames_(0.0), loss_(0.0), frames_progress_(0.0), loss_progress_(0.0) { } ~Mse() { } /// Evaluate mean square error using target-matrix, void Eval(const VectorBase &frame_weights, const CuMatrixBase& net_out, const CuMatrixBase& target, CuMatrix* diff); /// Evaluate mean square error using target-posteior, void Eval(const VectorBase &frame_weights, const CuMatrixBase& net_out, const Posterior& target, CuMatrix* diff); /// Generate string with error report std::string Report(); /// Get loss value (frame average), BaseFloat AvgLoss() { if (frames_ == 0) return 0.0; return loss_ / frames_; } private: double frames_; double loss_; double frames_progress_; double loss_progress_; std::vector loss_vec_; CuVector frame_weights_; CuMatrix tgt_mat_; CuMatrix diff_pow_2_; }; class MultiTaskLoss : public LossItf { public: MultiTaskLoss(LossOptions &opts): LossItf(opts) { } ~MultiTaskLoss() { while (loss_vec_.size() > 0) { delete loss_vec_.back(); loss_vec_.pop_back(); } } /// Initialize from string, the format for string 's' is : /// 'multitask,,,,...,,,' /// /// Practically it can look like this : /// 'multitask,xent,2456,1.0,mse,440,0.001' void InitFromString(const std::string& s); /// Evaluate mean square error using target-matrix, void Eval(const VectorBase &frame_weights, const CuMatrixBase& net_out, const CuMatrixBase& target, CuMatrix* diff) { KALDI_ERR << "This is not supposed to be called!"; } /// Evaluate mean square error using target-posteior, void Eval(const VectorBase &frame_weights, const CuMatrixBase& net_out, const Posterior& target, CuMatrix* diff); /// Generate string with error report std::string Report(); /// Get loss value (frame average), BaseFloat AvgLoss(); private: std::vector loss_vec_; std::vector loss_dim_; std::vector loss_weights_; std::vector loss_dim_offset_; CuMatrix tgt_mat_; }; } // namespace nnet1 } // namespace kaldi #endif // KALDI_NNET_NNET_LOSS_H_