Blame view

src/nnet3/nnet-diagnostics.h 8.71 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
  // nnet3/nnet-diagnostics.h
  
  // Copyright    2015  Johns Hopkins University (author: Daniel Povey)
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  #ifndef KALDI_NNET3_NNET_DIAGNOSTICS_H_
  #define KALDI_NNET3_NNET_DIAGNOSTICS_H_
  
  #include "nnet3/nnet-example.h"
  #include "nnet3/nnet-computation.h"
  #include "nnet3/nnet-compute.h"
  #include "nnet3/nnet-optimize.h"
  #include "nnet3/nnet-example-utils.h"
  #include "nnet3/nnet-training.h"
  
  namespace kaldi {
  namespace nnet3 {
  
  
  // Running totals for one objective: the accumulated weight (typically a
  // frame count) and the accumulated objective value.  Both start at zero.
  struct SimpleObjectiveInfo {
    double tot_weight = 0.0;
    double tot_objective = 0.0;

    SimpleObjectiveInfo() { }
  };
  
  /* Objective stats extended with per-dimension (per-class) vectors, from which
   * accuracy can be computed separately for each output dimension.  This is
   * probably only meaningful for classification tasks.
   */
  struct PerDimObjectiveInfo : SimpleObjectiveInfo {
    // Per-class counts for the output; in the simplest case, the number of
    // frames belonging to each class.
    Vector<BaseFloat> tot_weight_vec;

    // Per-class contribution to the objective.
    Vector<BaseFloat> tot_objective_vec;
  };
  
  
  struct NnetComputeProbOptions {
    bool debug_computation;
    bool compute_deriv;
    bool compute_accuracy;
    // note: the component stats, if stored, will be stored in the derivative nnet
    // (c.f. GetDeriv()) if compute_deriv is true; otherwise, you should use the
    // constructor of NnetComputeProb that takes a pointer to the nnet, and the
    // stats will be stored there.
    bool store_component_stats;
    
    bool compute_per_dim_accuracy;
  
    NnetOptimizeOptions optimize_config;
    NnetComputeOptions compute_config;
    CachingOptimizingCompilerOptions compiler_config;
    NnetComputeProbOptions():
        debug_computation(false),
        compute_deriv(false),
        compute_accuracy(true),
        store_component_stats(false),
        compute_per_dim_accuracy(false) { }
    void Register(OptionsItf *opts) {
      // compute_deriv is not included in the command line options
      // because it's not relevant for nnet3-compute-prob.
      // store_component_stats is not included in the command line
      // options because it's not relevant for nnet3-compute-prob.
      opts->Register("debug-computation", &debug_computation, "If true, turn on "
                     "debug for the actual computation (very verbose!)");
      opts->Register("compute-accuracy", &compute_accuracy, "If true, compute "
                     "accuracy values as well as objective functions");
      opts->Register("compute-per-dim-accuracy", &compute_per_dim_accuracy,
                     "If true, compute accuracy values per-dim");
  
      // register the optimization options with the prefix "optimization".
      ParseOptions optimization_opts("optimization", opts);
      optimize_config.Register(&optimization_opts);
      // register the compiler options with the prefix "compiler".
      ParseOptions compiler_opts("compiler", opts);
      compiler_config.Register(&compiler_opts);
      // register the compute options with the prefix "computation".
      ParseOptions compute_opts("computation", opts);
      compute_config.Register(&compute_opts);
    }
  };
  
  
  /** This class is for computing cross-entropy and accuracy values in a neural
      network, for diagnostics.
      Note: because we put a "logsoftmax" component in the nnet, the actual
      objective function becomes linear at the output, but the printed messages
      reflect the fact that it's the cross-entropy objective.

      Objects of this class are non-copyable.  An instance may own the
      derivative nnet it points to ('deriv_nnet_', with ownership tracked by
      'deriv_nnet_owned_') and it declares its own destructor.  A
      compiler-generated member-wise copy would therefore leave two objects
      aliasing the same 'deriv_nnet_', which becomes a double release if the
      destructor frees an owned pointer.
   */
  class NnetComputeProb {
   public:
    // does not store a reference to 'config' but does store one to 'nnet'.
    NnetComputeProb(const NnetComputeProbOptions &config,
                    const Nnet &nnet);

    // This version of the constructor may only be called if
    // config.store_component_stats == true and config.compute_deriv == false;
    // it means it will store the component stats in 'nnet'.  In this
    // case you should call ZeroComponentStats(nnet) first if you want
    // the stats to be zeroed first.
    NnetComputeProb(const NnetComputeProbOptions &config,
                    Nnet *nnet);

    // Non-copyable: see the class comment.  (Assignment was already
    // implicitly unusable because of the reference member 'nnet_'; the copy
    // constructor was not.)
    NnetComputeProb(const NnetComputeProb &other) = delete;
    NnetComputeProb &operator=(const NnetComputeProb &other) = delete;

    // Reset the likelihood stats, and the derivative stats (if computed).
    void Reset();

    // compute objective on one minibatch; stats accumulate until Reset().
    void Compute(const NnetExample &eg);

    // Prints out the final stats, and returns true if there was a nonzero
    // count.
    bool PrintTotalStats() const;

    // returns the objective-function info for this output name (e.g. "output"),
    // or NULL if there is no such info.
    const SimpleObjectiveInfo *GetObjective(const std::string &output_name) const;

    // This function returns the total objective over all output nodes recorded
    // here, and outputs to 'tot_weight' the total weight (typically the number
    // of frames) corresponding to it.
    double GetTotalObjective(double *tot_weight) const;

    // if config.compute_deriv == true, returns a reference to the
    // computed derivative.  Otherwise crashes.
    const Nnet &GetDeriv() const;

    ~NnetComputeProb();
   private:
    // Handles the outputs of 'eg' after 'computer' has run.  NOTE(review):
    // presumably accumulates objf_info_/accuracy_info_ (and derivatives when
    // config_.compute_deriv) -- confirm against the .cc file.
    void ProcessOutputs(const NnetExample &eg,
                        NnetComputer *computer);

    // A copy of the options passed to the constructor.
    NnetComputeProbOptions config_;
    // The network being evaluated; a reference is kept, so 'nnet' must
    // outlive this object.
    const Nnet &nnet_;

    // True if this object owns 'deriv_nnet_' (presumably released in the
    // destructor in that case -- the reason copying is disallowed).
    bool deriv_nnet_owned_;
    Nnet *deriv_nnet_;
    CachingOptimizingCompiler compiler_;

    // this is only for diagnostics.
    int32 num_minibatches_processed_;

    // Objective stats; presumably keyed by output-node name (cf. GetObjective).
    unordered_map<std::string, SimpleObjectiveInfo, StringHasher> objf_info_;

    // Accuracy stats (optionally per-dim); presumably keyed by output-node name.
    unordered_map<std::string, PerDimObjectiveInfo, StringHasher> accuracy_info_;
  };
  
  
  /**
     Computes the frame accuracy of a minibatch.

     Both matrices are reduced to one label per row by taking the column index
     of that row's maximum element.  For 'supervision' this is the reference
     label; the matrix may hold hard or soft labels, but where accuracy is of
     interest there is normally a single nonzero entry per row.  For
     'nnet_output' (supplied as a CuMatrix) it is the hypothesized label.  A
     row is counted as correct when the two indexes agree.  For the objective
     function itself, see ComputeObjectiveFunction(), declared in
     nnet-training.h.

     @param [in] supervision   Supervision matrix; no element may be negative.
                       Only the position of each row's maximum matters for the
                       label, while the row's sum acts as that row's weight
                       (usually expected to be one).
     @param [in] nnet_output   Network output; must have the same dimensions as
                       'supervision'.  Only the position of each row's maximum
                       matters, and ties are broken in an unspecified way.
     @param [out] tot_weight   Receives the sum of all values in 'supervision'.
     @param [out] tot_accuracy Receives the total weight of the rows whose
                       supervision maximum and output maximum lie in the same
                       column.
     @param [out] tot_weight_vec  Optional (may be null).  Receives the
                       per-class counts from the supervision matrix; expected
                       to have the same dimension as the corresponding output
                       of the network.
     @param [out] tot_accuracy_vec  Optional (may be null).  Receives the
                       per-class accuracy: element j is the total weight of the
                       rows whose supervision maximum and output maximum are
                       both at column j.
  */
  void ComputeAccuracy(const GeneralMatrix &supervision,
                       const CuMatrixBase<BaseFloat> &nnet_output,
                       BaseFloat *tot_weight,
                       BaseFloat *tot_accuracy,
                       VectorBase<BaseFloat> *tot_weight_vec = nullptr,
                       VectorBase<BaseFloat> *tot_accuracy_vec = nullptr);
  
  
  } // namespace nnet3
  } // namespace kaldi
  
  #endif // KALDI_NNET3_NNET_DIAGNOSTICS_H_