  // chain/chain-numerator.h
  
  // Copyright       2015  Johns Hopkins University (Author: Daniel Povey)
  
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABILITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  
  #ifndef KALDI_CHAIN_CHAIN_NUMERATOR_H_
  #define KALDI_CHAIN_CHAIN_NUMERATOR_H_
  
  #include <vector>
  #include <map>
  
  #include "base/kaldi-common.h"
  #include "util/common-utils.h"
  #include "fstext/fstext-lib.h"
  #include "tree/context-dep.h"
  #include "lat/kaldi-lattice.h"
  #include "matrix/kaldi-matrix.h"
  #include "hmm/transition-model.h"
  #include "chain/chain-supervision.h"
  #include "cudamatrix/cu-matrix.h"
  #include "cudamatrix/cu-array.h"
  
  namespace kaldi {
  namespace chain {
  
  
  // This class is responsible for the forward-backward of the 'supervision'
  // (numerator) FST.
  //
  // note: the supervision.weight is ignored by this class; you have to apply
  // it externally.
  // Because the supervision FSTs are quite skinny, i.e. have very few paths for
  // each frame, it's feasible to do this computation on the CPU, and that's what
  // we do.  We transfer from/to the GPU only the things that we need.
  
  class NumeratorComputation {
  
   public:
  
    /// Initialize the object.  Note: we expect the 'nnet_output' to have the
    /// same number of rows as supervision.num_frames * supervision.num_sequences,
    /// and the same number of columns as the 'label-dim' of the supervision
    /// object (which will be the NumPdfs() of the transition model); but the
    /// ordering of the rows of 'nnet_output' is not the same as the ordering of
    /// frames in paths in the 'supervision' object (which has all frames of the
    /// 1st sequence first, then the 2nd sequence, and so on).  Instead, the
    /// frames in 'nnet_output' are ordered as: first the first frame of each
    /// sequence, then the second frame of each sequence, and so on.  This is more
    /// convenient both because the nnet3 code internally orders them that way,
    /// and because this makes it easier to order things in the way that class
    /// SingleHmmForwardBackward needs (we can just transpose, instead of doing a
    /// 3d tensor rearrangement).
    NumeratorComputation(const Supervision &supervision,
                         const CuMatrixBase<BaseFloat> &nnet_output);
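
    // To illustrate the row ordering described above (this is only an
    // example, not part of the interface): with 2 sequences of 3 frames each,
    // 'nnet_output' would have 6 rows, laid out as
    //   row 0: (sequence 0, frame 0)    row 3: (sequence 1, frame 1)
    //   row 1: (sequence 1, frame 0)    row 4: (sequence 0, frame 2)
    //   row 2: (sequence 0, frame 1)    row 5: (sequence 1, frame 2)
    // whereas the 'supervision' object orders the same frames as
    // (sequence 0: frames 0,1,2) followed by (sequence 1: frames 0,1,2).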
  
    // TODO: we could enable a Viterbi mode.
  
    // Does the forward computation.  Returns the total log-prob multiplied
    // by supervision_.weight.
    BaseFloat Forward();
  
    // Does the backward computation and (efficiently) adds the derivative of
    // the (log-prob times supervision_.weight times deriv_weight) w.r.t. the
    // nnet output to 'nnet_output_deriv'.
    void Backward(CuMatrixBase<BaseFloat> *nnet_output_deriv);
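
    // Typical usage looks roughly like the following (a minimal sketch; the
    // variable names 'supervision', 'nnet_output' and 'nnet_output_deriv' are
    // placeholders for whatever the calling training code provides):
    //
    //   NumeratorComputation numerator(supervision, nnet_output);
    //   BaseFloat num_logprob_weighted = numerator.Forward();
    //   if (nnet_output_deriv != NULL)
    //     numerator.Backward(nnet_output_deriv);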
  
   private:
  
    const Supervision &supervision_;
  
    // state times of supervision_.fst.
    std::vector<int32> fst_state_times_;
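
    // (A note on fst_state_times_: it is indexed by FST state and gives the
    // frame index at which that state occurs; this assumes the supervision
    // FST is laid out so that each state corresponds to a single frame.)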
  
  
    // the neural net output.
    const CuMatrixBase<BaseFloat> &nnet_output_;
  
  
    // fst_output_indexes_ contains an entry for each arc in the supervision
    // FST, in the order you'd get them if you visit each arc of each state in
    // order.  The contents of fst_output_indexes_ are indexes into
    // nnet_output_indexes_ and nnet_logprobs_.
    std::vector<int32> fst_output_indexes_;
  
    // nnet_output_indexes_ is a list of (row, column) indexes that we need to
    // look up in nnet_output_ for the forward-backward computation.  The order
    // is arbitrary, but indexes into this vector appear in fst_output_indexes_;
    // and it's important that each pair appears only once (so that the
    // derivatives are summed properly).
    CuArray<Int32Pair> nnet_output_indexes_;
  
    // the log-probs obtained from lookup in the nnet output, on the CPU.  This
    // vector has the same size as nnet_output_indexes_.  In the backward
    // computation, the storage is re-used for derivatives.
    Vector<BaseFloat> nnet_logprobs_;
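
    // (To sketch roughly how the three members above fit together, ignoring
    // which of them live on the CPU vs. the GPU: for the i'th arc visited,
    // fst_output_indexes_[i] gives an index k such that nnet_logprobs_(k) is
    // the log-prob used for that arc, looked up from nnet_output_ at the
    // (row, column) position given by the k'th element of
    // nnet_output_indexes_.)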
  
    // derivatives w.r.t. the nnet logprobs.  These can be interpreted as
    // occupation probabilities.
    Vector<BaseFloat> nnet_logprob_derivs_;
  
    // The log-alpha value (forward probability) for each state of the
    // supervision FST.
    Vector<double> log_alpha_;
  
    // The total pseudo-log-likelihood from the forward-backward.
    double tot_log_prob_;
  
    // The log-beta value (backward probability) for each state of the
    // supervision FST.
    Vector<double> log_beta_;
  
    // This function creates fst_output_indexes_ and nnet_output_indexes_.
    void ComputeLookupIndexes();
  
    // Converts a time-index in the supervision FST to a row-index in the
    // nnet-output (to account for the fact that the sequences are interleaved
    // in the nnet-output).
    inline int32 ComputeRowIndex(int32 t, int32 frames_per_sequence,
                                 int32 num_sequences) {
      return t / frames_per_sequence +
          num_sequences * (t % frames_per_sequence);
    }
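
    // For example (illustrative only, using 3 frames per sequence and 2
    // sequences as above): t = 4 is frame 1 of sequence 1 in the supervision
    // ordering, and ComputeRowIndex(4, 3, 2) = 4 / 3 + 2 * (4 % 3) = 1 + 2
    // = 3, which is the nnet-output row holding frame 1 of sequence 1.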
  
  };
  
  
  
  
  }  // namespace chain
  }  // namespace kaldi
  
  #endif  // KALDI_CHAIN_CHAIN_NUMERATOR_H_