Blame view

src/nnet/nnet-component.h 10.5 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
  // nnet/nnet-component.h
  
  // Copyright 2011-2016  Brno University of Technology (Author: Karel Vesely)
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  
  
  #ifndef KALDI_NNET_NNET_COMPONENT_H_
  #define KALDI_NNET_NNET_COMPONENT_H_
  
  #include <algorithm>
  #include <iostream>
  #include <string>
  #include <vector>
  
  #include "base/kaldi-common.h"
  #include "matrix/matrix-lib.h"
  #include "cudamatrix/cu-matrix.h"
  #include "cudamatrix/cu-vector.h"
  #include "nnet/nnet-trnopts.h"
  
  namespace kaldi {
  namespace nnet1 {
  
  /**
   * Abstract class, the building block of the network.
   *
   * A Component can propagate (PropagateFnc: compute the output from its
   * input) and backpropagate (BackpropagateFnc: transform the loss derivative
   * w.r.t. the output into the derivative w.r.t. the input). The formulas are
   * implemented in derived classes (AffineTransform, Sigmoid, Softmax, ...).
   *
   * The public non-virtual wrappers Propagate() and Backpropagate() check the
   * dimensions and allocate the output buffer, then dispatch to the protected
   * pure-virtual PropagateFnc() / BackpropagateFnc().
   */
  class Component {
   /// Component type identification mechanism,
   public:
    /// Types of Components.
    /// Each group starts at an explicit base value (0x0100, 0x0200, 0x0400,
    /// 0x0800); the enumerators following a base take consecutive values,
    /// so inserting or reordering entries changes their numeric values.
    typedef enum {
      kUnknown = 0x0,
  
      kUpdatableComponent = 0x0100,
      kAffineTransform,
      kLinearTransform,
      kConvolutionalComponent,
      kLstmProjected,
      kBlstmProjected,
      kRecurrentComponent,
  
      kActivationFunction = 0x0200,
      kSoftmax,
      kHiddenSoftmax,
      kBlockSoftmax,
      kSigmoid,
      kTanh,
      kParametricRelu,
      kDropout,
      kLengthNormComponent,
  
      // NOTE: 'kTranform' is misspelled (sic); renaming it would break any
      // code that refers to this enumerator.
      kTranform = 0x0400,
      kRbm,
      kSplice,
      kCopy,
      kTranspose,
      kBlockLinearity,
      kAddShift,
      kRescale,
  
      kKlHmm = 0x0800,
      kSentenceAveragingComponent, /* deprecated */
      kSimpleSentenceAveragingComponent,
      kAveragePoolingComponent,
      kMaxPoolingComponent,
      kFramePoolingComponent,
      kParallelComponent,
      kMultiBasisComponent
    } ComponentType;
  
    /// A pair of Component type and its marker string,
    struct key_value {
      const Component::ComponentType key;
      const char *value;
    };
  
    /// The table with pairs of Component types and markers
    /// (defined in nnet-component.cc),
    static const struct key_value kMarkerMap[];
  
    /// Converts component type to marker,
    static const char* TypeToMarker(ComponentType t);
  
    /// Converts marker to component type (case insensitive),
    static ComponentType MarkerToType(const std::string &s);
  
   /// Generic interface of a component,
   public:
    /// Stores the input and output dimensions of the component,
    Component(int32 input_dim, int32 output_dim):
      input_dim_(input_dim),
      output_dim_(output_dim)
    { }
  
    /// Virtual, so a derived component can be deleted via a Component pointer,
    virtual ~Component()
    { }
  
    /// Copy component (deep copy),
    virtual Component* Copy() const = 0;
  
    /// Get Type Identification of the component,
    virtual ComponentType GetType() const = 0;
  
    /// Check if component has the 'Updatable' interface (trainable
    /// parameters); false here, overridden by UpdatableComponent,
    virtual bool IsUpdatable() const {
      return false;
    }
  
    /// Check if component has the 'Multistream' interface (recurrent
    /// components trained with parallel sequences); false here,
    /// overridden by MultistreamComponent,
    virtual bool IsMultistream() const {
      return false;
    }
  
    /// Get the dimension of the input,
    int32 InputDim() const {
      return input_dim_;
    }
  
    /// Get the dimension of the output,
    int32 OutputDim() const {
      return output_dim_;
    }
  
    /// Perform forward-pass propagation 'in' -> 'out'.
    /// Checks in.NumCols() against InputDim(), resizes 'out' to
    /// in.NumRows() x OutputDim() (zero-filled), then calls PropagateFnc(),
    void Propagate(const CuMatrixBase<BaseFloat> &in, CuMatrix<BaseFloat> *out);
  
    /// Perform backward-pass propagation 'out_diff' -> 'in_diff'.
    /// Checks the dims of all arguments, resizes 'in_diff' to
    /// out_diff.NumRows() x InputDim() (zero-filled), then calls
    /// BackpropagateFnc().
    /// Note: 'in' and 'out' will be used only sometimes...
    void Backpropagate(const CuMatrixBase<BaseFloat> &in,
                       const CuMatrixBase<BaseFloat> &out,
                       const CuMatrixBase<BaseFloat> &out_diff,
                       CuMatrix<BaseFloat> *in_diff);
  
    /// Initialize component from a line in config file,
    static Component* Init(const std::string &conf_line);
  
    /// Read the component from a stream (static method),
    static Component* Read(std::istream &is, bool binary);
  
    /// Write the component to a stream,
    void Write(std::ostream &os, bool binary) const;
  
    /// Print some additional info (after <ComponentName> and the dims),
    virtual std::string Info() const { return ""; }
  
    /// Print some additional info about gradient (after <...> and dims),
    virtual std::string InfoGradient() const { return ""; }
  
  
   /// Abstract interface for propagation/backpropagation
   protected:
    /// Forward pass transformation (to be implemented by the derived class).
    /// 'out' is already allocated with the proper size when this is called,
    virtual void PropagateFnc(const CuMatrixBase<BaseFloat> &in,
                              CuMatrixBase<BaseFloat> *out) = 0;
  
    /// Backward pass transformation (to be implemented by the derived class).
    /// 'in_diff' is already allocated with the proper size when this is called,
    virtual void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in,
                                  const CuMatrixBase<BaseFloat> &out,
                                  const CuMatrixBase<BaseFloat> &out_diff,
                                  CuMatrixBase<BaseFloat> *in_diff) = 0;
  
   /// Virtual interface for initialization and I/O,
   protected:
    /// Initialize internal data of a component (default: does nothing),
    virtual void InitData(std::istream &is) { }
  
    /// Reads the component content (default: does nothing),
    virtual void ReadData(std::istream &is, bool binary) { }
  
    /// Writes the component content (default: writes nothing),
    virtual void WriteData(std::ostream &os, bool binary) const { }
  
   /// Data members,
   protected:
    int32 input_dim_;  ///< Dimension of the input of the Component,
    int32 output_dim_;  ///< Dimension of the output of the Component,
  
   /// Private members (derived classes cannot call this),
   private:
    /// Create a new instance of component,
    static Component* NewComponentOfType(
      ComponentType t, int32 input_dim, int32 output_dim
    );
  };
  
  
  /**
   * UpdatableComponent: a Component with trainable parameters.
   *
   * Besides the SGD hyper-parameters (NnetTrainOptions, shared by the
   * whole network), it carries two per-component scalars:
   * 'learn_rate_coef_' (scales the learning rate of the weights) and
   * 'bias_learn_rate_coef_' (scales the learning rate of the bias).
   * Both start at 1.0. They are not part of NnetTrainOptions, so derived
   * classes should store them themselves in ::WriteData(...).
   */
  class UpdatableComponent : public Component {
   public:
    UpdatableComponent(int32 input_dim, int32 output_dim)
      : Component(input_dim, output_dim),
        learn_rate_coef_(1.0),
        bias_learn_rate_coef_(1.0) {
    }
  
    virtual ~UpdatableComponent() {}
  
    /// Always true: this component owns trainable parameters,
    bool IsUpdatable() const { return true; }
  
    /// Total count of trainable parameters,
    virtual int32 NumParams() const = 0;
  
    /// Write the gradient, flattened into a single vector, to 'gradient',
    virtual void GetGradient(VectorBase<BaseFloat> *gradient) const = 0;
  
    /// Write the trainable parameters, flattened into a vector, to 'params',
    virtual void GetParams(VectorBase<BaseFloat> *params) const = 0;
  
    /// Overwrite the trainable parameters from the flattened vector 'params',
    virtual void SetParams(const VectorBase<BaseFloat> &params) = 0;
  
    /// Compute the gradient from 'input' and 'diff', then update parameters,
    virtual void Update(const CuMatrixBase<BaseFloat> &input,
                        const CuMatrixBase<BaseFloat> &diff) = 0;
  
    /// Store a copy of the training options in the component,
    virtual void SetTrainOptions(const NnetTrainOptions &opts) { opts_ = opts; }
  
    /// Read-only access to the stored training options,
    const NnetTrainOptions& GetTrainOptions() const { return opts_; }
  
    /// Set the learning-rate scale used for the weights,
    virtual void SetLearnRateCoef(BaseFloat val) { learn_rate_coef_ = val; }
  
    /// Set the learning-rate scale used for the bias,
    virtual void SetBiasLearnRateCoef(BaseFloat val) {
      bias_learn_rate_coef_ = val;
    }
  
    /// Initialize the content from the 'line' of the prototype (each
    /// trainable component must provide this),
    virtual void InitData(std::istream &is) = 0;
  
   protected:
    /// Training hyper-parameters, as last passed to SetTrainOptions(),
    NnetTrainOptions opts_;
  
    /// Learning-rate scale for the weight matrices
    /// (intended for use in ::Update),
    BaseFloat learn_rate_coef_;
  
    /// Learning-rate scale for the bias
    /// (intended for use in ::Update),
    BaseFloat bias_learn_rate_coef_;
  };
  
  
  /**
   * Class MultistreamComponent is an extension of UpdatableComponent
   * for recurrent networks, which are trained with parallel sequences.
   *
   * The number of parallel streams is the number of sequence lengths
   * passed to SetSeqLengths() (at least 1).
   */
  class MultistreamComponent : public UpdatableComponent {
   public:
    MultistreamComponent(int32 input_dim, int32 output_dim):
      UpdatableComponent(input_dim, output_dim)
    { }
  
    /// Always true: this component has the 'Multistream' interface,
    bool IsMultistream() const {
      return true;
    }
  
    /// Set the lengths of the parallel sequences, one entry per stream,
    virtual void SetSeqLengths(const std::vector<int32>& sequence_lengths) {
      sequence_lengths_ = sequence_lengths;
    }
  
    /// Number of parallel streams: the number of stored sequence lengths,
    /// or 1 when none have been set,
    int32 NumStreams() const {
      // size() is a size_t; cast explicitly instead of narrowing implicitly,
      return std::max<int32>(1, static_cast<int32>(sequence_lengths_.size()));
    }
  
    /// Optional function to reset the transfer of context (not used for
    /// BLSTMs). The default implementation does nothing.
    /// NOTE(review): 'stream_reset_flag' presumably holds one flag per
    /// stream -- confirm against the overriding components,
    virtual void ResetStreams(const std::vector<int32>& stream_reset_flag)
    { }
  
   protected:
    /// Lengths of the parallel sequences, one entry per stream,
    std::vector<int32> sequence_lengths_;
  };
  
  
  /*
   * Inline methods for ::Component,
   */
  inline void Component::Propagate(const CuMatrixBase<BaseFloat> &in,
                                   CuMatrix<BaseFloat> *out) {
    // Check the dims
    if (input_dim_ != in.NumCols()) {
      KALDI_ERR << "Non-matching dims on the input of " << TypeToMarker(GetType())
                << " component. The input-dim is " << input_dim_
                << ", the data had " << in.NumCols() << " dims.";
    }
    // Allocate target buffer
    out->Resize(in.NumRows(), output_dim_, kSetZero);  // reset
    // Call the propagation implementation of the component
    PropagateFnc(in, out);
  }
  
  inline void Component::Backpropagate(const CuMatrixBase<BaseFloat> &in,
                                       const CuMatrixBase<BaseFloat> &out,
                                       const CuMatrixBase<BaseFloat> &out_diff,
                                       CuMatrix<BaseFloat> *in_diff) {
    // Check the dims,
    if (OutputDim() != out_diff.NumCols()) {
      KALDI_ERR << "Non-matching dims! Component output dim " << OutputDim()
                << ", the dim of output derivatives " << out_diff.NumCols();
    }
  
    int32 num_frames = out_diff.NumRows();
    KALDI_ASSERT(num_frames == in.NumRows());
    KALDI_ASSERT(num_frames == out.NumRows());
  
    KALDI_ASSERT(InputDim() == in.NumCols());
    KALDI_ASSERT(OutputDim() == out.NumCols());
  
    // Allocate target buffer,
    KALDI_ASSERT(in_diff != NULL);
    in_diff->Resize(num_frames, InputDim(), kSetZero);  // reset,
  
    // Call the 'virtual' backprop function,
    BackpropagateFnc(in, out, out_diff, in_diff);
  }
  
  
  }  // namespace nnet1
  }  // namespace kaldi
  
  
  #endif  // KALDI_NNET_NNET_COMPONENT_H_