Blame view

src/nnet3/discriminative-supervision.h 8.97 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
  // nnet3/discriminative-supervision.h
  
  // Copyright 2012-2015  Johns Hopkins University (author: Daniel Povey)
  //           2014-2015  Vimal Manohar
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  #ifndef KALDI_NNET3_DISCRIMINATIVE_SUPERVISION_H
  #define KALDI_NNET3_DISCRIMINATIVE_SUPERVISION_H
  
  #include "util/table-types.h"
  #include "hmm/posterior.h"
  #include "hmm/transition-model.h"
  #include "lat/kaldi-lattice.h"
  
  namespace kaldi {
  namespace discriminative {
  
  
  struct SplitDiscriminativeSupervisionOptions {
    int32 frame_subsampling_factor;
    bool remove_output_symbols;
    bool collapse_transition_ids;
    bool remove_epsilons;
    bool determinize;
    bool minimize; // we'll push and minimize if this is true.
    BaseFloat acoustic_scale;
  
    SplitDiscriminativeSupervisionOptions() :
        remove_output_symbols(true), collapse_transition_ids(true),
        remove_epsilons(true), determinize(true),
        minimize(true), acoustic_scale(0.1) { }
  
    void Register(OptionsItf *opts) {
      opts->Register("collapse-transition-ids", &collapse_transition_ids,
                     "If true, modify the transition-ids on denominator lattice "
                     "so that on each frame, there is just one with any given "
                     "pdf-id. This allows us to determinize and minimize "
                     "more completely.");
      opts->Register("remove-output-symbols", &remove_output_symbols,
                     "Remove output symbols from lattice to convert it to an "
                     "acceptor and make it more determinizable");
      opts->Register("remove-epsilons", &remove_epsilons,
                     "Remove epsilons from the split lattices");
      opts->Register("determinize", &determinize, "If true, we determinize "
                     "lattices (as Lattice) after splitting and possibly minimize");
      opts->Register("minimize", &minimize, "If true, we push and "
                     "minimize lattices (as Lattice) after splitting");
      opts->Register("acoustic-scale", &acoustic_scale,
                     "Scaling factor for acoustic likelihoods (should match the "
                     "value used in discriminative-get-supervision)");
    }
  };
  
  /*
    This file contains some declarations relating to the object we use to
    encode the supervision information for sequence training
  */
  
  // struct DiscriminativeSupervision holds the fully-processed supervision
  // information for a whole utterance or (after splitting) part of an
  // utterance: a numerator alignment plus a denominator lattice, together
  // with some bookkeeping about weight and sequence layout.
  struct DiscriminativeSupervision {
    // The weight we assign to this example;
    // this will typically be one, but we include it
    // for the sake of generality.
    BaseFloat weight;
  
    // num_sequences will be 1 if you create a DiscriminativeSupervision object from a single
    // lattice or alignment, but if you combine multiple DiscriminativeSupervision objects
    // the 'num_sequences' is the number of objects that were combined (the
    // lattices get appended).
    int32 num_sequences;
  
    // The number of frames in each sequence of appended objects.
    // frames_per_sequence * num_sequences (the value returned by NumFrames())
    // must equal the path length of any path in the lattices.
    // Technically this information is redundant with the lattices, but it's convenient
    // to have it separately.
    int32 frames_per_sequence;
  
    // The numerator alignment.
    // Usually obtained by aligning the reference text with the seed neural
    // network model; can be the best path of the generated lattice in the case
    // of semi-supervised training.
    std::vector<int32> num_ali;
  
    // Note: any acoustic
    // likelihoods in the lattices will be
    // recomputed at the time we train.
  
    // The denominator lattice.
    Lattice den_lat;
  
    // Default constructor: unit weight, a single sequence, and
    // frames_per_sequence set to -1.
    // NOTE(review): -1 presumably marks "not yet initialized" -- confirm
    // against Initialize() / Check() in the .cc file.
    DiscriminativeSupervision(): weight(1.0), num_sequences(1),
                                 frames_per_sequence(-1) { }
  
    // Copy constructor (defined out of line).
    DiscriminativeSupervision(const DiscriminativeSupervision &other);
  
  
    // This function creates a supervision object from a numerator alignment
    // and a denominator lattice.  The supervision object is used for sequence
    // discriminative training.
    // Topologically sorts the lattice after copying it to the supervision object.
    // Returns false when the alignment or the lattice is empty.
    bool Initialize(const std::vector<int32> &alignment,
                    const Lattice &lat,
                    BaseFloat weight);
  
    // Exchanges the contents of this object with those of '*other'.
    void Swap(DiscriminativeSupervision *other);
  
    // Equality comparison (defined out of line).
    bool operator == (const DiscriminativeSupervision &other) const;
  
    // This function checks that this supervision object satisfies some
    // of the properties we expect of it, and calls KALDI_ERR if not.
    void Check() const;
  
    // Total number of frames covered by this object, i.e. the number of
    // sequences times the number of frames in each sequence.
    inline int32 NumFrames() const {
      return num_sequences * frames_per_sequence;
    }
  
    // Serialization to / from a stream.
    // NOTE(review): 'binary' presumably selects binary vs. text format --
    // confirm against the implementation.
    void Write(std::ostream &os, bool binary) const;
    void Read(std::istream &is, bool binary);
  };
  
  // This class is used for splitting something of type
  // DiscriminativeSupervision into
  // multiple pieces corresponding to different frame-ranges.
  //
  // NOTE(review): the config, transition model and supervision are held as
  // reference members (see below).  If the constructor binds them to its
  // arguments, as the declarations suggest, those arguments must outlive the
  // splitter -- confirm against the .cc file.
  class DiscriminativeSupervisionSplitter {
   public:
    // Local names for the lattice arc and lattice types used by this class.
    typedef fst::ArcTpl<LatticeWeight> LatticeArc;
    typedef fst::VectorFst<LatticeArc> Lattice;
  
    // config:      options controlling the splitting and post-processing.
    // tmodel:      transition model (used by CollapseTransitionIds()).
    // supervision: the supervision object that will be split.
    DiscriminativeSupervisionSplitter(
        const SplitDiscriminativeSupervisionOptions &config,
        const TransitionModel &tmodel,
        const DiscriminativeSupervision &supervision);
  
    // A structure used to store the forward and backward scores
    // and state times of a lattice
    struct LatticeInfo {
      // These values are stored in log.
      std::vector<double> alpha;
      std::vector<double> beta;
      // One entry per lattice state.
      // NOTE(review): presumably the frame index at which each state is
      // reached -- confirm against ComputeLatticeScores()/PrepareLattice().
      std::vector<int32> state_times;
  
      // Consistency check of the stored vectors (defined out of line).
      void Check() const;
    };
  
    // Extracts a frame range of the supervision into 'supervision'.
    // The range starts at 'begin_frame' and its length is given by
    // 'frames_per_sequence'; 'normalize' is forwarded to the lattice-range
    // extraction (see CreateRangeLattice()).
    // NOTE(review): the exact effect of 'normalize' is not visible from this
    // header -- see the implementation.
    void GetFrameRange(int32 begin_frame, int32 frames_per_sequence,
                       bool normalize,
                       DiscriminativeSupervision *supervision) const;
  
    // Get the acoustic scaled denominator lattice out for debugging purposes
    inline const Lattice& DenLat() const { return den_lat_; }
  
   private:
  
    // Creates an output lattice covering frames begin_frame <= t < end_frame,
    // assuming that the corresponding state-range that we need to
    // include, begin_state <= s < end_state has been included.
    // (note: the output lattice will also have two special initial and final
    // states).
    // Also does post-processing (RmEpsilon, Determinize,
    // TopSort on the result).  See code for details.
    void CreateRangeLattice(const Lattice &in_lat,
                            const LatticeInfo &scores,
                            int32 begin_frame, int32 end_frame, bool normalize,
                            Lattice *out_lat) const;
  
    // Config options for splitting supervision object
    const SplitDiscriminativeSupervisionOptions &config_;
  
    // Transition model is used by the function
    // CollapseTransitionIds()
    const TransitionModel &tmodel_;
  
    // A reference to the supervision object that we will be splitting
    const DiscriminativeSupervision &supervision_;
  
    // LatticeInfo object for denominator lattice.
    // This will be computed when PrepareLattice function is called.
    LatticeInfo den_lat_scores_;
  
    // Copy of denominator lattice. This is required because the lattice states
    // need to be ordered in breadth-first search order.
    Lattice den_lat_;
  
    // Function to compute lattice scores for a lattice; the results are
    // written to '*scores'.
    void ComputeLatticeScores(const Lattice &lat, LatticeInfo *scores) const;
  
    // Prepare lattice :
    // 1) Order states in breadth-first search order
    // 2) Compute state times, which must form a non-decreasing vector
    // 3) Compute lattice alpha and beta scores
    void PrepareLattice(Lattice *lat, LatticeInfo *scores) const;
  
    // Modifies the transition-ids on 'lat' so that on each frame, there is just
    // one with any given pdf-id.  This allows us to determinize and minimize
    // more completely.
    void CollapseTransitionIds(const std::vector<int32> &state_times,
                               Lattice *lat) const;
  
  };
  
  /// This function appends a list of supervision objects ('input') and writes
  /// the result to the single object pointed to by 'output_supervision'.  The
  /// normal use-case for this is when you are combining neural-net examples for
  /// training; appending them like this helps to simplify the training process.
  /// NOTE(review): whether all inputs are required to share the same weight
  /// and number of frames is not visible from this declaration -- confirm
  /// against the implementation before relying on either behavior.
  
  void MergeSupervision(const std::vector<const DiscriminativeSupervision*> &input,
      DiscriminativeSupervision *output_supervision);
  
  
  } // namespace discriminative
  } // namespace kaldi
  
  #endif // KALDI_NNET3_DISCRIMINATIVE_SUPERVISION_H