// nnet3/nnet-discriminative-training.cc

// Copyright 2012-2015  Johns Hopkins University (author: Daniel Povey)
// Copyright 2014-2015  Vimal Manohar

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet-discriminative-training.h"
#include "nnet3/nnet-utils.h"

namespace kaldi {
namespace nnet3 {

NnetDiscriminativeTrainer::NnetDiscriminativeTrainer(
    const NnetDiscriminativeOptions &opts,
    const TransitionModel &tmodel,
    const VectorBase<BaseFloat> &priors,
    Nnet *nnet):
    opts_(opts), tmodel_(tmodel), log_priors_(priors), nnet_(nnet),
    compiler_(*nnet, opts_.nnet_config.optimize_config),
    num_minibatches_processed_(0) {
  if (opts.nnet_config.zero_component_stats)
    ZeroComponentStats(nnet);
  if (opts.nnet_config.momentum == 0.0 &&
      opts.nnet_config.max_param_change == 0.0) {
    delta_nnet_ = NULL;
  } else {
    KALDI_ASSERT(opts.nnet_config.momentum >= 0.0 &&
                 opts.nnet_config.max_param_change >= 0.0);
    // delta_nnet_ accumulates the raw parameter change for one minibatch; it
    // is applied to nnet_ (possibly rescaled) at the end of Train().
    delta_nnet_ = nnet_->Copy();
    ScaleNnet(0.0, delta_nnet_);
  }
  if (opts.nnet_config.read_cache != "") {
    bool binary;
    Input ki;
    if (ki.Open(opts.nnet_config.read_cache, &binary)) {
      compiler_.ReadCache(ki.Stream(), binary);
      KALDI_LOG << "Read computation cache from "
                << opts.nnet_config.read_cache;
    } else {
      KALDI_WARN << "Could not open cached computation. "
          "Probably this is the first training iteration.";
    }
  }
  // log_priors_ was initialized from the raw priors; convert to log domain.
  log_priors_.ApplyLog();
}


void NnetDiscriminativeTrainer::Train(const NnetDiscriminativeExample &eg) {
  bool need_model_derivative = true;
  const NnetTrainerOptions &nnet_config = opts_.nnet_config;
  bool use_xent_regularization =
      (opts_.discriminative_config.xent_regularize != 0.0);
  ComputationRequest request;
  GetDiscriminativeComputationRequest(*nnet_, eg, need_model_derivative,
                                      nnet_config.store_component_stats,
                                      use_xent_regularization,
                                      need_model_derivative,
                                      &request);
  std::shared_ptr<const NnetComputation> computation =
      compiler_.Compile(request);
  NnetComputer computer(nnet_config.compute_config, *computation,
                        *nnet_, (delta_nnet_ == NULL ? nnet_ : delta_nnet_));
  // give the inputs to the computer object.
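  // The first Run() call below does the forward computation; ProcessOutputs()
  // then supplies the derivatives at the output nodes, and the second Run()
  // call does the backprop.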
  computer.AcceptInputs(*nnet_, eg.inputs);
  computer.Run();

  this->ProcessOutputs(eg, &computer);
  computer.Run();

  if (delta_nnet_ != NULL) {
    BaseFloat scale = (1.0 - nnet_config.momentum);
    if (nnet_config.max_param_change != 0.0) {
      BaseFloat param_delta =
          std::sqrt(DotProduct(*delta_nnet_, *delta_nnet_)) * scale;
      if (param_delta > nnet_config.max_param_change) {
        // (param_delta - param_delta) is nonzero only if param_delta is
        // infinite or NaN.
        if (param_delta - param_delta != 0.0) {
          KALDI_WARN << "Infinite parameter change, will not apply.";
          ScaleNnet(0.0, delta_nnet_);
        } else {
          scale *= nnet_config.max_param_change / param_delta;
          KALDI_LOG << "Parameter change too big: " << param_delta << " > "
                    << "--max-param-change=" << nnet_config.max_param_change
                    << ", scaling by "
                    << nnet_config.max_param_change / param_delta;
        }
      }
    }
    AddNnet(*delta_nnet_, scale, nnet_);
    ScaleNnet(nnet_config.momentum, delta_nnet_);
  }
}


void NnetDiscriminativeTrainer::ProcessOutputs(
    const NnetDiscriminativeExample &eg,
    NnetComputer *computer) {
  // normally the eg will have just one output named 'output', but
  // we don't assume this.
  std::vector<NnetDiscriminativeSupervision>::const_iterator
      iter = eg.outputs.begin(),
      end = eg.outputs.end();
  for (; iter != end; ++iter) {
    const NnetDiscriminativeSupervision &sup = *iter;
    int32 node_index = nnet_->GetNodeIndex(sup.name);
    if (node_index < 0 || !nnet_->IsOutputNode(node_index))
      KALDI_ERR << "Network has no output named " << sup.name;

    const CuMatrixBase<BaseFloat> &nnet_output = computer->GetOutput(sup.name);

    CuMatrix<BaseFloat> nnet_output_deriv(nnet_output.NumRows(),
                                          nnet_output.NumCols(),
                                          kUndefined);

    bool use_xent = (opts_.discriminative_config.xent_regularize != 0.0);
    std::string xent_name = sup.name + "-xent";  // typically "output-xent".
    CuMatrix<BaseFloat> xent_deriv;
    if (use_xent)
      xent_deriv.Resize(nnet_output.NumRows(), nnet_output.NumCols(),
                        kUndefined);

    discriminative::DiscriminativeObjectiveInfo stats(
        opts_.discriminative_config);

    if (objf_info_.count(sup.name) == 0) {
      objf_info_[sup.name].stats.Configure(opts_.discriminative_config);
      objf_info_[sup.name].stats.Reset();
    }

    ComputeDiscriminativeObjfAndDeriv(opts_.discriminative_config,
                                      tmodel_, log_priors_,
                                      sup.supervision, nnet_output,
                                      &stats,
                                      &nnet_output_deriv,
                                      (use_xent ? &xent_deriv : NULL));

    if (use_xent) {
      // this block computes the cross-entropy objective.
      const CuMatrixBase<BaseFloat> &xent_output = computer->GetOutput(
          xent_name);
      // at this point, xent_deriv is posteriors derived from the numerator
      // computation.
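      // (TraceMatMat(A, B, kTrans) returns sum_{i,j} A(i,j) * B(i,j), so the
      // statement below accumulates the numerator posteriors times the
      // log-probabilities in xent_output.)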
      // note: xent_objf has a factor of '.supervision.weight'.
      BaseFloat xent_objf = TraceMatMat(xent_output, xent_deriv, kTrans);
      if (xent_objf != xent_objf) {  // true only if xent_objf is NaN.
        BaseFloat default_objf = -10;
        xent_objf = default_objf;
      }

      discriminative::DiscriminativeObjectiveInfo xent_stats;
      xent_stats.tot_t_weighted = stats.tot_t_weighted;
      xent_stats.tot_objf = xent_objf;

      objf_info_[xent_name].UpdateStats(xent_name, "xent",
                                        opts_.nnet_config.print_interval,
                                        num_minibatches_processed_,
                                        xent_stats);
    }

    if (opts_.apply_deriv_weights && sup.deriv_weights.Dim() != 0) {
      CuVector<BaseFloat> cu_deriv_weights(sup.deriv_weights);
      nnet_output_deriv.MulRowsVec(cu_deriv_weights);
      if (use_xent)
        xent_deriv.MulRowsVec(cu_deriv_weights);
    }

    computer->AcceptInput(sup.name, &nnet_output_deriv);

    objf_info_[sup.name].UpdateStats(sup.name,
                                     opts_.discriminative_config.criterion,
                                     opts_.nnet_config.print_interval,
                                     num_minibatches_processed_++,
                                     stats);

    if (use_xent) {
      xent_deriv.Scale(opts_.discriminative_config.xent_regularize);
      computer->AcceptInput(xent_name, &xent_deriv);
    }
  }
}


bool NnetDiscriminativeTrainer::PrintTotalStats() const {
  unordered_map<std::string, DiscriminativeObjectiveFunctionInfo,
                StringHasher>::const_iterator
      iter = objf_info_.begin(),
      end = objf_info_.end();
  bool ans = false;
  for (; iter != end; ++iter) {
    const std::string &name = iter->first;
    const DiscriminativeObjectiveFunctionInfo &info = iter->second;
    bool ret = info.PrintTotalStats(name,
                                    opts_.discriminative_config.criterion);
    ans = ans || ret;
  }
  return ans;
}


void DiscriminativeObjectiveFunctionInfo::UpdateStats(
    const std::string &output_name,
    const std::string &criterion,
    int32 minibatches_per_phase,
    int32 minibatch_counter,
    discriminative::DiscriminativeObjectiveInfo this_minibatch_stats) {
  int32 phase = minibatch_counter / minibatches_per_phase;
  if (phase != current_phase) {
    KALDI_ASSERT(phase == current_phase + 1);  // or doesn't really make sense.
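    // We have just crossed a phase boundary: print the stats accumulated over
    // the phase that has ended, then reset them for the new phase.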
    PrintStatsForThisPhase(output_name, criterion, minibatches_per_phase);
    current_phase = phase;
    stats_this_phase.Reset();
  }
  stats_this_phase.Add(this_minibatch_stats);
  stats.Add(this_minibatch_stats);
}


void DiscriminativeObjectiveFunctionInfo::PrintStatsForThisPhase(
    const std::string &output_name,
    const std::string &criterion,
    int32 minibatches_per_phase) const {
  int32 start_minibatch = current_phase * minibatches_per_phase,
      end_minibatch = start_minibatch + minibatches_per_phase - 1;
  BaseFloat objf = (stats_this_phase.TotalObjf(criterion) /
                    stats_this_phase.tot_t_weighted);
  KALDI_LOG << "Average objective function for '" << output_name
            << "' for minibatches " << start_minibatch
            << '-' << end_minibatch << " is " << objf
            << " over " << stats_this_phase.tot_t_weighted << " frames.";
}

bool DiscriminativeObjectiveFunctionInfo::PrintTotalStats(
    const std::string &name,
    const std::string &criterion) const {
  BaseFloat objf = stats.TotalObjf(criterion) / stats.tot_t_weighted;
  double avg_gradients = (stats.tot_num_count + stats.tot_den_count) /
                         stats.tot_t_weighted;
  KALDI_LOG << "Average num+den count of stats is " << avg_gradients
            << " per frame, over " << stats.tot_t_weighted << " frames.";
  if (stats.tot_l2_term != 0.0) {
    KALDI_LOG << "Average l2 norm of output per frame is "
              << (stats.tot_l2_term / stats.tot_t_weighted)
              << " over " << stats.tot_t_weighted << " frames.";
  }
  KALDI_LOG << "Overall average objective function for '" << name << "' is "
            << objf << " over " << stats.tot_t_weighted << " frames.";
  KALDI_LOG << "[this line is to be parsed by a script:] "
            << criterion << "-per-frame=" << objf;
  return (stats.tot_t_weighted != 0.0);
}


NnetDiscriminativeTrainer::~NnetDiscriminativeTrainer() {
  delete delta_nnet_;
  if (opts_.nnet_config.write_cache != "") {
    Output ko(opts_.nnet_config.write_cache,
              opts_.nnet_config.binary_write_cache);
    compiler_.WriteCache(ko.Stream(), opts_.nnet_config.binary_write_cache);
  }
}


}  // namespace nnet3
}  // namespace kaldi
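
// A minimal driver sketch (not part of this file), modeled on
// nnet3bin/nnet3-discriminative-train.cc; it assumes the reader typedef
// SequentialNnetDiscriminativeExampleReader from nnet-discriminative-example.h
// and that 'opts', 'tmodel', 'priors', 'nnet' and 'examples_rspecifier' have
// already been set up:
//
//   kaldi::nnet3::NnetDiscriminativeTrainer trainer(opts, tmodel,
//                                                   priors, &nnet);
//   SequentialNnetDiscriminativeExampleReader example_reader(
//       examples_rspecifier);
//   for (; !example_reader.Done(); example_reader.Next())
//     trainer.Train(example_reader.Value());
//   bool ok = trainer.PrintTotalStats();  // true if any data was processed.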