// nnet3/discriminative-training.h

// Copyright 2012-2015  Johns Hopkins University (author: Daniel Povey)
// Copyright 2014-2015  Vimal Manohar

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS
// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY
// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_NNET3_DISCRIMINATIVE_TRAINING_H_
#define KALDI_NNET3_DISCRIMINATIVE_TRAINING_H_

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fstext/fstext-lib.h"
#include "tree/context-dep.h"
#include "lat/kaldi-lattice.h"
#include "matrix/kaldi-matrix.h"
#include "hmm/transition-model.h"
#include "nnet3/discriminative-supervision.h"
#include "lat/lattice-functions.h"
#include "cudamatrix/cu-matrix-lib.h"

namespace kaldi {
namespace discriminative {

/* Options for discriminative training.
 *
 * Legend:
 *   mmi  - Maximum Mutual Information
 *   mpfe - Minimum Phone Frame Error
 *   smbr - State Minimum Bayes Risk
 */
struct DiscriminativeOptions {
  std::string criterion;  // one of {"mmi", "mpfe", "smbr"}.
                          // If the criterion does not match the supervision
                          // object, the derivatives may not be very accurate.
  BaseFloat acoustic_scale;  // e.g. 0.1
  bool drop_frames;  // for MMI, true if we ignore frames where the alignment
                     // pdf-id is not in the lattice.
  bool one_silence_class;  // Affects the MPFE and SMBR objectives.
  BaseFloat boost;  // for MMI, the boosting factor (nonzero gives boosted
                    // MMI); e.g. 0.1.
  std::string silence_phones_str;  // colon-separated list of integer ids of
                                   // silence phones, for the MPFE and SMBR
                                   // objectives.

  // Cross-entropy regularization constant (e.g. try 0.1).  If nonzero,
  // the network is expected to have an output named 'output-xent', which
  // should have a softmax as its final nonlinearity.
  BaseFloat xent_regularize;

  // l2 regularization constant on the 'chain' output; the actual term added
  // to the objf will be -0.5 times this constant times the squared l2 norm.
  // (Squared so it's additive across the dimensions.)  e.g. try 0.0005.
  BaseFloat l2_regularize;

  // Options for debugging discriminative training.
  // Accumulates gradients w.r.t. nnet outputs.
  bool accumulate_gradients;

  // Accumulates nnet output.
  bool accumulate_output;

  // Applicable for debugging discriminative training when
  // accumulate_gradients or accumulate_output is true.
  int32 num_pdfs;

  DiscriminativeOptions(): criterion("smbr"),
                           acoustic_scale(0.1),
                           drop_frames(false),
                           one_silence_class(false),
                           boost(0.0),
                           xent_regularize(0.0),
                           l2_regularize(0.0),
                           accumulate_gradients(false),
                           accumulate_output(false),
                           num_pdfs(0) { }

  void Register(OptionsItf *opts) {
    opts->Register("criterion", &criterion, "Criterion, 'mmi'|'mpfe'|'smbr', "
                   "determines the objective function to use.  Should match "
                   "the option used when we created the examples.");
    opts->Register("acoustic-scale", &acoustic_scale, "Weighting factor to "
                   "apply to acoustic likelihoods.");
    opts->Register("drop-frames", &drop_frames, "For MMI, if true we drop "
                   "frames with no overlap of num and den pdf-ids.");
    opts->Register("boost", &boost,
                   "Boosting factor for boosted MMI (e.g. 0.1)");
    opts->Register("one-silence-class", &one_silence_class, "If true, use "
                   "newer behavior which tends to reduce insertions "
                   "when using the MPFE or SMBR objective.");
    opts->Register("silence-phones", &silence_phones_str,
                   "For the MPFE or SMBR objectives, colon-separated list of "
                   "integer ids of silence phones, e.g. 1:2:3");
    opts->Register("l2-regularize", &l2_regularize, "l2 regularization "
                   "constant for the 'chain' output of the neural net.");
    opts->Register("xent-regularize", &xent_regularize, "Cross-entropy "
                   "regularization constant for sequence training.  If "
                   "nonzero, the network is expected to have an output "
                   "named 'output-xent', which should have a softmax as "
                   "its final nonlinearity.");
    opts->Register("accumulate-gradients", &accumulate_gradients,
                   "Accumulate gradients w.r.t. the nnet output, "
                   "for debugging discriminative training.");
    opts->Register("accumulate-output", &accumulate_output,
                   "Accumulate the nnet output, "
                   "for debugging discriminative training.");
    opts->Register("num-pdfs", &num_pdfs, "Number of pdfs; "
                   "applicable when accumulate-output or accumulate-gradients "
                   "is true for discriminative training.");
  }
};
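// A minimal sketch of how DiscriminativeOptions would typically be hooked up
// to a command-line program via ParseOptions (the usual pattern in Kaldi
// binaries).  The usage string and the option values shown are invented for
// illustration; this is not an actual Kaldi binary:
//
//   #include "util/common-utils.h"  // for ParseOptions
//
//   int main(int argc, char *argv[]) {
//     using namespace kaldi;
//     const char *usage = "Illustration of registering discriminative "
//         "training options.\n";
//     ParseOptions po(usage);
//     discriminative::DiscriminativeOptions opts;
//     opts.Register(&po);   // exposes --criterion, --acoustic-scale, etc.
//     po.Read(argc, argv);  // e.g. --criterion=mmi --drop-frames=true
//     // ... 'opts' can now be passed to ComputeDiscriminativeObjfAndDeriv().
//     return 0;
//   }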
struct DiscriminativeObjectiveInfo {
  double tot_t;           // total number of frames
  double tot_t_weighted;  // total number of frames times weight
  double tot_objf;        // for 'mmi', the (weighted) denominator likelihood;
                          // for everything else, the objective function.
  double tot_num_count;   // total count of numerator posterior
  double tot_den_count;   // total count of denominator posterior
  double tot_num_objf;    // for 'mmi', the (weighted) numerator likelihood;
                          // for everything else, 0.
  double tot_l2_term;     // l2 regularization objective term

  // Options for debugging discriminative training.
  // Accumulate gradients w.r.t. nnet outputs.
  bool accumulate_gradients;

  // Accumulate nnet output.
  bool accumulate_output;

  // Applicable for debugging discriminative training when
  // accumulate_gradients or accumulate_output is true.
  int32 num_pdfs;

  // Used to accumulate gradients w.r.t. nnet outputs
  // when accumulate_gradients is true.
  CuVector<double> gradients;

  // Used to accumulate the nnet output when accumulate_output is true.
  CuVector<double> output;

  // Print statistics for the criterion.
  void Print(const std::string &criterion,
             bool print_avg_gradients = false,
             bool print_avg_output = false) const;

  // Print all accumulated statistics, for debugging.
  void PrintAll(const std::string &criterion) const {
    Print(criterion, true, true);
  }

  // Print the gradient w.r.t. the nnet output accumulated for a pdf.
  void PrintAvgGradientForPdf(int32 pdf_id) const;

  // Add stats from another object.
  void Add(const DiscriminativeObjectiveInfo &other);

  // Returns the objective function value for the criterion.
  inline double TotalObjf(const std::string &criterion) const {
    if (criterion == "mmi") return (tot_num_objf - tot_objf);
    return tot_objf;
  }

  // Returns true if accumulate_gradients is true and the 'gradients' vector
  // has been resized to store the accumulated gradients.
  inline bool AccumulateGradients() const {
    return accumulate_gradients && gradients.Dim() > 0;
  }

  // Returns true if accumulate_output is true and the 'output' vector has
  // been resized to store the accumulated nnet output.
  inline bool AccumulateOutput() const {
    return accumulate_output && output.Dim() > 0;
  }

  // Empty constructor.
  DiscriminativeObjectiveInfo();

  // Constructor that prepares for gradients or output to be accumulated.
  DiscriminativeObjectiveInfo(int32 num_pdfs);

  // Constructor from config options.
  DiscriminativeObjectiveInfo(const DiscriminativeOptions &opts);

  // Reset statistics.
  void Reset();

  void Configure(const DiscriminativeOptions &opts);
};

/**
   This function does forward-backward on the numerator and denominator
   lattices and computes the derivatives w.r.t. the nnet output for the
   specified objective function.

   @param [in] opts         Struct containing options.
   @param [in] tmodel       Transition model.
   @param [in] log_priors   Vector of log-priors for pdfs.
   @param [in] supervision  The supervision object, containing the numerator
                            and denominator paths.  The denominator is always
                            a lattice; the numerator is an alignment.
   @param [in] nnet_output  The output of the neural net; its dimension must
                            equal (supervision.num_sequences *
                            supervision.frames_per_sequence) by
                            tmodel.NumPdfs().
   @param [out] stats       Statistics accumulated during training, such as
                            the objective function and the total weight.
   @param [out] nnet_output_deriv  If non-NULL, the derivative of the
                            objective function w.r.t. the nnet output is
                            written here; must have the same dimension as
                            nnet_output.
   @param [out] xent_output_deriv  If non-NULL, then the xent objective
                            derivative (which equals a posterior from the
                            numerator forward-backward, scaled by the
                            supervision weight) is written here.  This will
                            be used in the cross-entropy regularization code.
*/
void ComputeDiscriminativeObjfAndDeriv(
    const DiscriminativeOptions &opts,
    const TransitionModel &tmodel,
    const CuVectorBase<BaseFloat> &log_priors,
    const DiscriminativeSupervision &supervision,
    const CuMatrixBase<BaseFloat> &nnet_output,
    DiscriminativeObjectiveInfo *stats,
    CuMatrixBase<BaseFloat> *nnet_output_deriv,
    CuMatrixBase<BaseFloat> *xent_output_deriv);
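// A minimal sketch of one training step using the declarations above.  The
// way 'tmodel', 'log_priors', 'supervision' and 'nnet_output' are obtained is
// elided and the variable names are invented for illustration; in real
// training this is driven by the nnet3 example/training machinery:
//
//   using namespace kaldi;
//   using namespace kaldi::discriminative;
//
//   DiscriminativeOptions opts;             // e.g. opts.criterion = "smbr";
//   TransitionModel tmodel;                 // assumed already read from disk
//   CuVector<BaseFloat> log_priors;         // assumed already set up
//   DiscriminativeSupervision supervision;  // assumed already read from disk
//   CuMatrix<BaseFloat> nnet_output;        // assumed computed by the nnet
//
//   DiscriminativeObjectiveInfo stats(opts);
//   CuMatrix<BaseFloat> nnet_output_deriv(nnet_output.NumRows(),
//                                         nnet_output.NumCols());
//   ComputeDiscriminativeObjfAndDeriv(opts, tmodel, log_priors, supervision,
//                                     nnet_output, &stats,
//                                     &nnet_output_deriv,
//                                     NULL /* no xent regularization */);
//   KALDI_LOG << "Objf per frame is "
//             << (stats.TotalObjf(opts.criterion) / stats.tot_t_weighted);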
}  // namespace discriminative
}  // namespace kaldi

#endif  // KALDI_NNET3_DISCRIMINATIVE_TRAINING_H_