// nnet3/nnet-training.cc // Copyright 2015 Johns Hopkins University (author: Daniel Povey) // 2015 Xiaohui Zhang // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include "nnet3/nnet-training.h" #include "nnet3/nnet-utils.h" namespace kaldi { namespace nnet3 { NnetTrainer::NnetTrainer(const NnetTrainerOptions &config, Nnet *nnet): config_(config), nnet_(nnet), compiler_(*nnet, config_.optimize_config, config_.compiler_config), num_minibatches_processed_(0), max_change_stats_(*nnet), srand_seed_(RandInt(0, 100000)) { if (config.zero_component_stats) ZeroComponentStats(nnet); KALDI_ASSERT(config.momentum >= 0.0 && config.max_param_change >= 0.0 && config.backstitch_training_interval > 0); delta_nnet_ = nnet_->Copy(); ScaleNnet(0.0, delta_nnet_); if (config_.read_cache != "") { bool binary; Input ki; if (ki.Open(config_.read_cache, &binary)) { compiler_.ReadCache(ki.Stream(), binary); KALDI_LOG << "Read computation cache from " << config_.read_cache; } else { KALDI_WARN << "Could not open cached computation. " "Probably this is the first training iteration."; } } } void NnetTrainer::Train(const NnetExample &eg) { bool need_model_derivative = true; ComputationRequest request; GetComputationRequest(*nnet_, eg, need_model_derivative, config_.store_component_stats, &request); std::shared_ptr computation = compiler_.Compile(request); if (config_.backstitch_training_scale > 0.0 && num_minibatches_processed_ % config_.backstitch_training_interval == srand_seed_ % config_.backstitch_training_interval) { // backstitch training is incompatible with momentum > 0 KALDI_ASSERT(config_.momentum == 0.0); FreezeNaturalGradient(true, delta_nnet_); bool is_backstitch_step1 = true; srand(srand_seed_ + num_minibatches_processed_); ResetGenerators(nnet_); TrainInternalBackstitch(eg, *computation, is_backstitch_step1); FreezeNaturalGradient(false, delta_nnet_); // un-freeze natural gradient is_backstitch_step1 = false; srand(srand_seed_ + num_minibatches_processed_); ResetGenerators(nnet_); TrainInternalBackstitch(eg, *computation, is_backstitch_step1); } else { // conventional training TrainInternal(eg, *computation); } if (num_minibatches_processed_ == 0) { ConsolidateMemory(nnet_); ConsolidateMemory(delta_nnet_); } num_minibatches_processed_++; } void NnetTrainer::TrainInternal(const NnetExample &eg, const NnetComputation &computation) { // note: because we give the 1st arg (nnet_) as a pointer to the // constructor of 'computer', it will use that copy of the nnet to // store stats. NnetComputer computer(config_.compute_config, computation, nnet_, delta_nnet_); // give the inputs to the computer object. computer.AcceptInputs(*nnet_, eg.io); computer.Run(); this->ProcessOutputs(false, eg, &computer); computer.Run(); // If relevant, add in the part of the gradient that comes from L2 // regularization. ApplyL2Regularization(*nnet_, GetNumNvalues(eg.io, false) * config_.l2_regularize_factor, delta_nnet_); // Update the parameters of nnet bool success = UpdateNnetWithMaxChange( *delta_nnet_, config_.max_param_change, 1.0, 1.0 - config_.momentum, nnet_, &max_change_stats_); // Scale down the batchnorm stats (keeps them fresh... this affects what // happens when we use the model with batchnorm test-mode set). ScaleBatchnormStats(config_.batchnorm_stats_scale, nnet_); // The following will only do something if we have a LinearComponent // or AffineComponent with orthonormal-constraint set to a nonzero value. ConstrainOrthonormal(nnet_); // Scale deta_nnet if (success) ScaleNnet(config_.momentum, delta_nnet_); else ScaleNnet(0.0, delta_nnet_); } void NnetTrainer::TrainInternalBackstitch(const NnetExample &eg, const NnetComputation &computation, bool is_backstitch_step1) { // note: because we give the 1st arg (nnet_) as a pointer to the // constructor of 'computer', it will use that copy of the nnet to // store stats. NnetComputer computer(config_.compute_config, computation, nnet_, delta_nnet_); // give the inputs to the computer object. computer.AcceptInputs(*nnet_, eg.io); computer.Run(); bool is_backstitch_step2 = !is_backstitch_step1; this->ProcessOutputs(is_backstitch_step2, eg, &computer); computer.Run(); BaseFloat max_change_scale, scale_adding; if (is_backstitch_step1) { // max-change is scaled by backstitch_training_scale; // delta_nnet is scaled by -backstitch_training_scale when added to nnet; max_change_scale = config_.backstitch_training_scale; scale_adding = -config_.backstitch_training_scale; } else { // max-change is scaled by 1 + backstitch_training_scale; // delta_nnet is scaled by 1 + backstitch_training_scale when added to nnet; max_change_scale = 1.0 + config_.backstitch_training_scale; scale_adding = 1.0 + config_.backstitch_training_scale; // If relevant, add in the part of the gradient that comes from L2 // regularization. It may not be optimally inefficient to do it on both // passes of the backstitch, like we do here, but it probably minimizes // any harmful interactions with the max-change. ApplyL2Regularization(*nnet_, 1.0 / scale_adding * GetNumNvalues(eg.io, false) * config_.l2_regularize_factor, delta_nnet_); } // Updates the parameters of nnet UpdateNnetWithMaxChange( *delta_nnet_, config_.max_param_change, max_change_scale, scale_adding, nnet_, &max_change_stats_); if (is_backstitch_step1) { // The following will only do something if we have a LinearComponent or // AffineComponent with orthonormal-constraint set to a nonzero value. We // choose to do this only on the 1st backstitch step, for efficiency. ConstrainOrthonormal(nnet_); } if (!is_backstitch_step1) { // Scale down the batchnorm stats (keeps them fresh... this affects what // happens when we use the model with batchnorm test-mode set). Do this // after backstitch step 2 so that the stats are scaled down before we start // the next minibatch. ScaleBatchnormStats(config_.batchnorm_stats_scale, nnet_); } ScaleNnet(0.0, delta_nnet_); } void NnetTrainer::ProcessOutputs(bool is_backstitch_step2, const NnetExample &eg, NnetComputer *computer) { // normally the eg will have just one output named 'output', but // we don't assume this. // In backstitch training, the output-name with the "_backstitch" suffix is // the one computed after the first, backward step of backstitch. const std::string suffix = (is_backstitch_step2 ? "_backstitch" : ""); std::vector::const_iterator iter = eg.io.begin(), end = eg.io.end(); for (; iter != end; ++iter) { const NnetIo &io = *iter; int32 node_index = nnet_->GetNodeIndex(io.name); KALDI_ASSERT(node_index >= 0); if (nnet_->IsOutputNode(node_index)) { ObjectiveType obj_type = nnet_->GetNode(node_index).u.objective_type; BaseFloat tot_weight, tot_objf; bool supply_deriv = true; ComputeObjectiveFunction(io.features, obj_type, io.name, supply_deriv, computer, &tot_weight, &tot_objf); objf_info_[io.name + suffix].UpdateStats(io.name + suffix, config_.print_interval, num_minibatches_processed_, tot_weight, tot_objf); } } } bool NnetTrainer::PrintTotalStats() const { unordered_map::const_iterator iter = objf_info_.begin(), end = objf_info_.end(); std::vector > all_pairs; for (; iter != end; ++iter) all_pairs.push_back(std::pair( iter->first, &(iter->second))); // ensure deterministic order of these names (this will matter in situations // where a script greps for the objective from the log). std::sort(all_pairs.begin(), all_pairs.end()); bool ans = false; for (size_t i = 0; i < all_pairs.size(); i++) { const std::string &name = all_pairs[i].first; const ObjectiveFunctionInfo &info = *(all_pairs[i].second); bool ok = info.PrintTotalStats(name); ans = ans || ok; } max_change_stats_.Print(*nnet_); return ans; } void ObjectiveFunctionInfo::UpdateStats( const std::string &output_name, int32 minibatches_per_phase, int32 minibatch_counter, BaseFloat this_minibatch_weight, BaseFloat this_minibatch_tot_objf, BaseFloat this_minibatch_tot_aux_objf) { int32 phase = minibatch_counter / minibatches_per_phase; if (phase != current_phase) { KALDI_ASSERT(phase > current_phase); PrintStatsForThisPhase(output_name, minibatches_per_phase, phase); current_phase = phase; tot_weight_this_phase = 0.0; tot_objf_this_phase = 0.0; tot_aux_objf_this_phase = 0.0; minibatches_this_phase = 0; } minibatches_this_phase++; tot_weight_this_phase += this_minibatch_weight; tot_objf_this_phase += this_minibatch_tot_objf; tot_aux_objf_this_phase += this_minibatch_tot_aux_objf; tot_weight += this_minibatch_weight; tot_objf += this_minibatch_tot_objf; tot_aux_objf += this_minibatch_tot_aux_objf; } void ObjectiveFunctionInfo::PrintStatsForThisPhase( const std::string &output_name, int32 minibatches_per_phase, int32 phase) const { int32 start_minibatch = current_phase * minibatches_per_phase, end_minibatch = phase * minibatches_per_phase - 1; if (tot_aux_objf_this_phase == 0.0) { if (minibatches_per_phase == minibatches_this_phase) { KALDI_LOG << "Average objective function for '" << output_name << "' for minibatches " << start_minibatch << '-' << end_minibatch << " is " << (tot_objf_this_phase / tot_weight_this_phase) << " over " << tot_weight_this_phase << " frames."; } else { KALDI_LOG << "Average objective function for '" << output_name << " using " << minibatches_this_phase << " minibatches in minibatch range " << start_minibatch << '-' << end_minibatch << " is " << (tot_objf_this_phase / tot_weight_this_phase) << " over " << tot_weight_this_phase << " frames."; } } else { BaseFloat objf = (tot_objf_this_phase / tot_weight_this_phase), aux_objf = (tot_aux_objf_this_phase / tot_weight_this_phase), sum_objf = objf + aux_objf; if (minibatches_per_phase == minibatches_this_phase) { KALDI_LOG << "Average objective function for '" << output_name << "' for minibatches " << start_minibatch << '-' << end_minibatch << " is " << objf << " + " << aux_objf << " = " << sum_objf << " over " << tot_weight_this_phase << " frames."; } else { KALDI_LOG << "Average objective function for '" << output_name << "' using " << minibatches_this_phase << " minibatches in minibatch range " << start_minibatch << '-' << end_minibatch << " is " << objf << " + " << aux_objf << " = " << sum_objf << " over " << tot_weight_this_phase << " frames."; } } } bool ObjectiveFunctionInfo::PrintTotalStats(const std::string &name) const { BaseFloat objf = (tot_objf / tot_weight), aux_objf = (tot_aux_objf / tot_weight), sum_objf = objf + aux_objf; if (tot_aux_objf == 0.0) { KALDI_LOG << "Overall average objective function for '" << name << "' is " << (tot_objf / tot_weight) << " over " << tot_weight << " frames."; } else { KALDI_LOG << "Overall average objective function for '" << name << "' is " << objf << " + " << aux_objf << " = " << sum_objf << " over " << tot_weight << " frames."; } KALDI_LOG << "[this line is to be parsed by a script:] " << "log-prob-per-frame=" << objf; return (tot_weight != 0.0); } NnetTrainer::~NnetTrainer() { if (config_.write_cache != "") { Output ko(config_.write_cache, config_.binary_write_cache); compiler_.WriteCache(ko.Stream(), config_.binary_write_cache); KALDI_LOG << "Wrote computation cache to " << config_.write_cache; } delete delta_nnet_; } void ComputeObjectiveFunction(const GeneralMatrix &supervision, ObjectiveType objective_type, const std::string &output_name, bool supply_deriv, NnetComputer *computer, BaseFloat *tot_weight, BaseFloat *tot_objf) { const CuMatrixBase &output = computer->GetOutput(output_name); if (output.NumCols() != supervision.NumCols()) KALDI_ERR << "Nnet versus example output dimension (num-classes) " << "mismatch for '" << output_name << "': " << output.NumCols() << " (nnet) vs. " << supervision.NumCols() << " (egs)\n"; switch (objective_type) { case kLinear: { // objective is x * y. switch (supervision.Type()) { case kSparseMatrix: { const SparseMatrix &post = supervision.GetSparseMatrix(); CuSparseMatrix cu_post(post); // The cross-entropy objective is computed by a simple dot product, // because after the LogSoftmaxLayer, the output is already in the form // of log-likelihoods that are normalized to sum to one. *tot_weight = cu_post.Sum(); *tot_objf = TraceMatSmat(output, cu_post, kTrans); if (supply_deriv) { CuMatrix output_deriv(output.NumRows(), output.NumCols(), kUndefined); cu_post.CopyToMat(&output_deriv); computer->AcceptInput(output_name, &output_deriv); } break; } case kFullMatrix: { // there is a redundant matrix copy in here if we're not using a GPU // but we don't anticipate this code branch being used in many cases. CuMatrix cu_post(supervision.GetFullMatrix()); *tot_weight = cu_post.Sum(); *tot_objf = TraceMatMat(output, cu_post, kTrans); if (supply_deriv) computer->AcceptInput(output_name, &cu_post); break; } case kCompressedMatrix: { Matrix post; supervision.GetMatrix(&post); CuMatrix cu_post; cu_post.Swap(&post); *tot_weight = cu_post.Sum(); *tot_objf = TraceMatMat(output, cu_post, kTrans); if (supply_deriv) computer->AcceptInput(output_name, &cu_post); break; } } break; } case kQuadratic: { // objective is -0.5 (x - y)^2 CuMatrix diff(supervision.NumRows(), supervision.NumCols(), kUndefined); diff.CopyFromGeneralMat(supervision); diff.AddMat(-1.0, output); *tot_weight = diff.NumRows(); *tot_objf = -0.5 * TraceMatMat(diff, diff, kTrans); if (supply_deriv) computer->AcceptInput(output_name, &diff); break; } default: KALDI_ERR << "Objective function type " << objective_type << " not handled."; } } } // namespace nnet3 } // namespace kaldi