src/chain/chain-training.h
// chain/chain-training.h

// Copyright       2015  Johns Hopkins University (Author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.


#ifndef KALDI_CHAIN_CHAIN_TRAINING_H_
#define KALDI_CHAIN_CHAIN_TRAINING_H_

#include <vector>
#include <map>

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fstext/fstext-lib.h"
#include "tree/context-dep.h"
#include "lat/kaldi-lattice.h"
#include "matrix/kaldi-matrix.h"
#include "hmm/transition-model.h"
#include "chain/chain-den-graph.h"
#include "chain/chain-supervision.h"

namespace kaldi {
namespace chain {

struct ChainTrainingOptions {
  // l2 regularization constant on the 'chain' output; the actual term added to
  // the objf will be -0.5 times this constant times the squared l2 norm
  // (squared so it's additive across the dimensions).  e.g. try 0.0005.
  BaseFloat l2_regularize;

  // This is similar to an l2 regularization constant (like l2-regularize) but
  // applied on the part of the nnet output matrix that exceeds the range
  // [-30,30]... this is necessary to avoid things regularly going out of the
  // range that we can do exp() on, since the denominator computation is not in
  // log space and to avoid NaNs we limit the outputs to the range [-30,30].
  BaseFloat out_of_range_regularize;

  // Coefficient for 'leaky hmm'.  This means we have an epsilon-transition
  // from each state to a special state with probability one, and then another
  // epsilon-transition from that special state to each state, with probability
  // leaky_hmm_coefficient times [initial-prob of destination state].  Imagine
  // we make two copies of each state prior to doing this, version A and
  // version B, with a transition from A to B, so we don't have to consider
  // epsilon loops-- or just imagine the coefficient is small enough that we
  // can ignore the epsilon loops.
  // Note: we generally set leaky_hmm_coefficient to 0.1.
  BaseFloat leaky_hmm_coefficient;

  // Cross-entropy regularization constant.  (e.g. try 0.1).  If nonzero,
  // the network is expected to have an output named 'output-xent', which
  // should have a softmax as its final nonlinearity.
  BaseFloat xent_regularize;

  ChainTrainingOptions(): l2_regularize(0.0),
                          out_of_range_regularize(0.01),
                          leaky_hmm_coefficient(1.0e-05),
                          xent_regularize(0.0) { }

  void Register(OptionsItf *opts) {
    opts->Register("l2-regularize", &l2_regularize, "l2 regularization "
                   "constant for 'chain' training, applied to the output "
                   "of the neural net.");
    opts->Register("out-of-range-regularize", &out_of_range_regularize,
                   "Constant that controls how much we penalize the nnet "
                   "output being outside the range [-30,30].  This is needed "
                   "because we limit it to that range in the denominator "
                   "computation (which is to avoid NaNs because it is not "
                   "done in log space).");
    opts->Register("leaky-hmm-coefficient", &leaky_hmm_coefficient,
                   "Coefficient that allows transitions from each HMM state "
                   "to each other HMM state, to ensure gradual forgetting of "
                   "context (can improve generalization).  For numerical "
                   "reasons, may not be exactly zero.");
    opts->Register("xent-regularize", &xent_regularize, "Cross-entropy "
                   "regularization constant for 'chain' training.  If "
                   "nonzero, the network is expected to have an output "
                   "named 'output-xent', which should have a softmax as "
                   "its final nonlinearity.");
  }
};


/**
   This function does both the numerator and denominator parts of the 'chain'
   computation in one call.

   @param [in] opts         Struct containing options.
   @param [in] den_graph    The denominator graph, derived from the denominator
                            FST.
   @param [in] supervision  The supervision object, containing the supervision
                            paths and constraints on the alignment as an FST.
   @param [in] nnet_output  The output of the neural net; dimension must equal
                            ((supervision.num_sequences *
                            supervision.frames_per_sequence) by
                            den_graph.NumPdfs()).  The rows are ordered as: all
                            sequences for frame 0; all sequences for frame 1;
                            etc.
   @param [out] objf        The [num - den] objective function computed for
                            this example; you'll want to divide it by
                            'tot_weight' before displaying it.
   @param [out] l2_term     The l2 regularization term in the objective
                            function, if the --l2-regularize option is used.
                            To be added to 'objf'.
   @param [out] weight      The weight to normalize the objective function by;
                            equals supervision.weight *
                            supervision.num_sequences *
                            supervision.frames_per_sequence.
   @param [out] nnet_output_deriv  The derivative of the objective function
                            w.r.t. the neural-net output.  Only written to if
                            non-NULL.  You don't have to zero this before
                            passing it to this function; we zero it internally.
   @param [out] xent_output_deriv  If non-NULL, then the numerator part of the
                            derivative (which equals a posterior from the
                            numerator forward-backward, scaled by the
                            supervision weight) is written to here (this
                            function will set it to the correct size first;
                            doing it this way reduces the peak memory use).
                            xent_output_deriv will be used in the cross-entropy
                            regularization code; it is also used in computing
                            the cross-entropy objective value.
*/
void ComputeChainObjfAndDeriv(const ChainTrainingOptions &opts,
                              const DenominatorGraph &den_graph,
                              const Supervision &supervision,
                              const CuMatrixBase<BaseFloat> &nnet_output,
                              BaseFloat *objf,
                              BaseFloat *l2_term,
                              BaseFloat *weight,
                              CuMatrixBase<BaseFloat> *nnet_output_deriv,
                              CuMatrix<BaseFloat> *xent_output_deriv = NULL);


}  // namespace chain
}  // namespace kaldi

#endif  // KALDI_CHAIN_CHAIN_TRAINING_H_
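The options struct above is meant to be exposed on a command line through Kaldi's OptionsItf mechanism. The following minimal sketch is not part of the header: the binary and its usage string are invented for illustration, and it only shows how ChainTrainingOptions::Register() might be wired into a ParseOptions object so that --l2-regularize, --out-of-range-regularize, --leaky-hmm-coefficient and --xent-regularize become command-line flags.

// Illustrative only: a made-up main() that registers ChainTrainingOptions.
#include "util/parse-options.h"
#include "chain/chain-training.h"

int main(int argc, char *argv[]) {
  using namespace kaldi;
  const char *usage = "Example binary registering 'chain' training options.";
  ParseOptions po(usage);
  chain::ChainTrainingOptions chain_opts;
  chain_opts.Register(&po);  // adds --l2-regularize, --leaky-hmm-coefficient, etc.
  po.Read(argc, argv);       // e.g. --l2-regularize=0.0005 --xent-regularize=0.1
  // chain_opts now holds the parsed values and can be passed on to the
  // 'chain' objective computation.
  return 0;
}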
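To make the l2-regularize comment concrete: the term added to the objective is -0.5 * l2_regularize * (squared l2 norm of the nnet output), squared so it is additive across dimensions. The helper below is purely illustrative; it is not the library's implementation and ignores any extra weighting the real computation may apply, it just spells out that formula on an ordinary CPU matrix.

// Hypothetical helper, shown only to make the l2 formula concrete.
#include "matrix/kaldi-matrix.h"

kaldi::BaseFloat IllustrativeL2Term(
    const kaldi::MatrixBase<kaldi::BaseFloat> &nnet_output,
    kaldi::BaseFloat l2_regularize) {
  kaldi::BaseFloat sumsq = 0.0;
  for (kaldi::int32 r = 0; r < nnet_output.NumRows(); r++)
    for (kaldi::int32 c = 0; c < nnet_output.NumCols(); c++)
      sumsq += nnet_output(r, c) * nnet_output(r, c);  // squared l2 norm
  return -0.5 * l2_regularize * sumsq;  // the term added to the objf
}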
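Finally, a hedged sketch of how ComputeChainObjfAndDeriv() is typically called, assuming the DenominatorGraph, Supervision and network output have been prepared elsewhere (those steps are not shown). The row ordering and the normalization of the returned objective follow directly from the parameter documentation above; the function name ExampleChainComputation is invented.

// Sketch only; 'opts', 'den_graph', 'supervision' and 'nnet_output' come from elsewhere.
#include "chain/chain-training.h"

void ExampleChainComputation(
    const kaldi::chain::ChainTrainingOptions &opts,
    const kaldi::chain::DenominatorGraph &den_graph,
    const kaldi::chain::Supervision &supervision,
    const kaldi::CuMatrixBase<kaldi::BaseFloat> &nnet_output) {
  using namespace kaldi;
  // nnet_output must have
  //   num-rows = supervision.num_sequences * supervision.frames_per_sequence,
  //   num-cols = den_graph.NumPdfs(),
  // with rows ordered as all sequences for frame 0, then all sequences for
  // frame 1, etc., so the row for sequence s at frame t is
  //   t * supervision.num_sequences + s.
  BaseFloat objf, l2_term, weight;
  CuMatrix<BaseFloat> nnet_output_deriv(nnet_output.NumRows(),
                                        nnet_output.NumCols());
  CuMatrix<BaseFloat> xent_output_deriv;  // resized inside the call.
  chain::ComputeChainObjfAndDeriv(opts, den_graph, supervision, nnet_output,
                                  &objf, &l2_term, &weight,
                                  &nnet_output_deriv, &xent_output_deriv);
  // Per-frame objective for logging: the l2 term is added to 'objf' and the
  // total is divided by 'weight' (== supervision.weight *
  // supervision.num_sequences * supervision.frames_per_sequence).
  BaseFloat objf_per_frame = (objf + l2_term) / weight;
  KALDI_LOG << "'chain' objective per frame: " << objf_per_frame;
}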