Blame view

src/nnet3/nnet-discriminative-training.h 4.62 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
  // nnet3/nnet-discriminative-training.h
  
  // Copyright 2012-2015   Johns Hopkins University (author: Daniel Povey)
  //           2014-2015   Vimal Manohar
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  #ifndef KALDI_NNET3_NNET_DISCRIMINATIVE_TRAINING_H_
  #define KALDI_NNET3_NNET_DISCRIMINATIVE_TRAINING_H_
  
  #include "nnet3/nnet-example.h"
  #include "nnet3/nnet-computation.h"
  #include "nnet3/nnet-compute.h"
  #include "nnet3/nnet-optimize.h"
  #include "nnet3/nnet-discriminative-example.h"
  #include "nnet3/nnet-training.h"
  #include "nnet3/discriminative-training.h"
  
  namespace kaldi {
  namespace nnet3 {
  
  struct NnetDiscriminativeOptions {
    NnetTrainerOptions nnet_config;
    discriminative::DiscriminativeOptions discriminative_config;
  
    bool apply_deriv_weights;
  
    NnetDiscriminativeOptions(): apply_deriv_weights(true) { }
  
    void Register(OptionsItf *opts) {
      nnet_config.Register(opts);
      discriminative_config.Register(opts);
      opts->Register("apply-deriv-weights", &apply_deriv_weights,
                     "If true, apply the per-frame derivative weights stored with "
                     "the example.");
    }
  };
  
  // Keeps track of discriminative objective-function values for one output.
  // Stats are held both in total (stats) and for the current logging phase
  // (stats_this_phase).  This struct is used in multiple nnet training classes.
  // Also see struct AccuracyInfo, in nnet-diagnostics.h.
  struct DiscriminativeObjectiveFunctionInfo {
    // Index of the logging phase we are in; the constructor starts it at 0.
    int32 current_phase;

    // Total stats (presumably what PrintTotalStats() reports -- confirm in
    // the .cc file).
    discriminative::DiscriminativeObjectiveInfo stats;
    // Stats restricted to the current phase (presumably what
    // PrintStatsForThisPhase() reports -- confirm in the .cc file).
    discriminative::DiscriminativeObjectiveInfo stats_this_phase;

    DiscriminativeObjectiveFunctionInfo()
        : current_phase(0) {}

    // Updates the stats.  If the phase has just changed, it also prints a
    // progress message.  The phase is minibatch_counter /
    // minibatches_per_phase; it exists only to control how often logging
    // messages are printed.
    // NOTE(review): 'stats' is taken by value (a copy per call) and its name
    // shadows the member 'stats'.  Switching to a const reference would also
    // require changing the out-of-line definition.
    void UpdateStats(const std::string &output_name,
                     const std::string &criterion,
                     int32 minibatches_per_phase,
                     int32 minibatch_counter,
                     discriminative::DiscriminativeObjectiveInfo stats);

    // Prints the stats accumulated for the current phase.
    void PrintStatsForThisPhase(const std::string &output_name,
                                const std::string &criterion,
                                int32 minibatches_per_phase) const;

    // Prints the total stats; returns true if their weight was nonzero.
    bool PrintTotalStats(const std::string &output_name,
                         const std::string &criterion) const;
  };
  
  
  /**
     This class is for single-threaded discriminative training of neural nets.

     The class declares its own destructor and holds raw pointers (delta_nnet_
     is presumably allocated and released by the trainer itself -- confirm in
     the .cc file).  Copying is therefore disabled explicitly; otherwise the
     implicitly-generated copy constructor would let two trainers share, and
     both try to release, the same pointer.
  */
  class NnetDiscriminativeTrainer {
   public:
    NnetDiscriminativeTrainer(const NnetDiscriminativeOptions &config,
                              const TransitionModel &tmodel,
                              const VectorBase<BaseFloat> &priors,
                              Nnet *nnet);

    // Train on one minibatch.
    void Train(const NnetDiscriminativeExample &eg);

    // Prints out the final stats, and returns true if there was a nonzero
    // count.
    bool PrintTotalStats() const;

    ~NnetDiscriminativeTrainer();

    // Non-copyable (see the class comment).  Copy assignment is already
    // implicitly deleted because of the const and reference members.  It is
    // declared deleted here too, for clarity.
    NnetDiscriminativeTrainer(const NnetDiscriminativeTrainer &) = delete;
    NnetDiscriminativeTrainer &operator=(
        const NnetDiscriminativeTrainer &) = delete;

   private:
    void ProcessOutputs(const NnetDiscriminativeExample &eg,
                        NnetComputer *computer);

    // A private copy of the options passed to the constructor.
    const NnetDiscriminativeOptions opts_;

    // Reference to the caller's transition model; it must outlive this object.
    const TransitionModel &tmodel_;
    // Presumably the log of the 'priors' constructor argument (the name
    // suggests log-domain) -- confirm in the .cc file.
    CuVector<BaseFloat> log_priors_;

    // The nnet being trained, as passed to the constructor.
    Nnet *nnet_;

    Nnet *delta_nnet_;  // Only used if momentum != 0.0.  nnet representing
                        // accumulated parameter-change (we'd call this
                        // gradient_nnet_, but due to natural-gradient update,
                        // it's better to consider it as a delta-parameter
                        // nnet).
                        // NOTE(review): confirm in the .cc whether other
                        // options (e.g. max-param-change) also enable it.
    CachingOptimizingCompiler compiler_;

    int32 num_minibatches_processed_;

    // This code supports multiple output layers, even though in the
    // normal case there will be just one output layer named "output".
    // So we store the objective functions per output layer.
    unordered_map<std::string, DiscriminativeObjectiveFunctionInfo,
                  StringHasher> objf_info_;
  };
  
  
  } // namespace nnet3
  } // namespace kaldi
  
  #endif // KALDI_NNET3_NNET_DISCRIMINATIVE_TRAINING_H_