// nnet3/nnet-discriminative-training.h

// Copyright 2012-2015  Johns Hopkins University (author: Daniel Povey)
//           2014-2015  Vimal Manohar

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS
// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY
// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_NNET3_NNET_DISCRIMINATIVE_TRAINING_H_
#define KALDI_NNET3_NNET_DISCRIMINATIVE_TRAINING_H_

#include "nnet3/nnet-example.h"
#include "nnet3/nnet-computation.h"
#include "nnet3/nnet-compute.h"
#include "nnet3/nnet-optimize.h"
#include "nnet3/nnet-discriminative-example.h"
#include "nnet3/nnet-training.h"
#include "nnet3/discriminative-training.h"

namespace kaldi {
namespace nnet3 {

struct NnetDiscriminativeOptions {
  NnetTrainerOptions nnet_config;
  discriminative::DiscriminativeOptions discriminative_config;
  bool apply_deriv_weights;

  NnetDiscriminativeOptions(): apply_deriv_weights(true) { }

  void Register(OptionsItf *opts) {
    nnet_config.Register(opts);
    discriminative_config.Register(opts);
    opts->Register("apply-deriv-weights", &apply_deriv_weights,
                   "If true, apply the per-frame derivative weights stored "
                   "with the example.");
  }
};

// This struct is used in multiple nnet training classes for keeping
// track of objective function values.
// Also see struct AccuracyInfo, in nnet-diagnostics.h.
struct DiscriminativeObjectiveFunctionInfo {
  int32 current_phase;

  discriminative::DiscriminativeObjectiveInfo stats;
  discriminative::DiscriminativeObjectiveInfo stats_this_phase;

  DiscriminativeObjectiveFunctionInfo(): current_phase(0) { }

  // This function updates the stats and, if the phase has just changed,
  // prints a message indicating progress.  The phase equals
  // minibatch_counter / minibatches_per_phase; its only purpose is to
  // control how frequently we print logging messages.
  void UpdateStats(const std::string &output_name,
                   const std::string &criterion,
                   int32 minibatches_per_phase,
                   int32 minibatch_counter,
                   discriminative::DiscriminativeObjectiveInfo stats);

  // Prints stats for the current phase.
  void PrintStatsForThisPhase(const std::string &output_name,
                              const std::string &criterion,
                              int32 minibatches_per_phase) const;

  // Prints total stats, and returns true if the total stats' weight was
  // nonzero.
  bool PrintTotalStats(const std::string &output_name,
                       const std::string &criterion) const;
};
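
// A minimal usage sketch for the struct above (illustrative only: the
// criterion name "smbr", the value 100 for minibatches_per_phase, and the
// variables 'num_minibatches' and 'stats' are hypothetical, not taken from
// this file):
//
//   DiscriminativeObjectiveFunctionInfo info;
//   for (int32 n = 0; n < num_minibatches; n++) {
//     discriminative::DiscriminativeObjectiveInfo stats;  // filled in by
//                                                         // the training code.
//     // A progress message is logged each time (n / 100) increments.
//     info.UpdateStats("output", "smbr", 100, n, stats);
//   }
//   info.PrintTotalStats("output", "smbr");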
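
// A minimal sketch of how the NnetDiscriminativeTrainer class declared below
// is typically driven (hypothetical setup: 'opts', 'tmodel', 'priors', 'nnet'
// and 'examples_rspecifier' are assumed to have been read or parsed already,
// and the sketch assumes the sequential example reader typedef from
// nnet-discriminative-example.h; see the nnet3 discriminative training binary
// for the real driver code):
//
//   NnetDiscriminativeTrainer trainer(opts, tmodel, priors, &nnet);
//   SequentialNnetDiscriminativeExampleReader reader(examples_rspecifier);
//   for (; !reader.Done(); reader.Next())
//     trainer.Train(reader.Value());
//   bool ok = trainer.PrintTotalStats();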

/** This class is for single-threaded discriminative training of neural nets.
*/
class NnetDiscriminativeTrainer {
 public:
  NnetDiscriminativeTrainer(const NnetDiscriminativeOptions &config,
                            const TransitionModel &tmodel,
                            const VectorBase<BaseFloat> &priors,
                            Nnet *nnet);

  // Train on one minibatch.
  void Train(const NnetDiscriminativeExample &eg);

  // Prints out the final stats, and returns true if there was a nonzero count.
  bool PrintTotalStats() const;

  ~NnetDiscriminativeTrainer();
 private:
  void ProcessOutputs(const NnetDiscriminativeExample &eg,
                      NnetComputer *computer);

  const NnetDiscriminativeOptions opts_;

  const TransitionModel &tmodel_;
  CuVector<BaseFloat> log_priors_;

  Nnet *nnet_;
  Nnet *delta_nnet_;  // Only used if momentum != 0.0: an nnet representing the
                      // accumulated parameter change (we'd call this
                      // gradient_nnet_, but due to the natural-gradient update
                      // it's better to think of it as a delta-parameter nnet).
  CachingOptimizingCompiler compiler_;

  int32 num_minibatches_processed_;

  // This code supports multiple output layers, even though in the normal case
  // there will be just one output layer named "output".  So we store the
  // objective-function info per output layer.
  unordered_map<std::string, DiscriminativeObjectiveFunctionInfo,
                StringHasher> objf_info_;
};

} // namespace nnet3
} // namespace kaldi

#endif // KALDI_NNET3_NNET_DISCRIMINATIVE_TRAINING_H_