Blame view
src/nnet3/discriminative-supervision.h
8.97 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
// nnet3/discriminative-supervision.h // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey) // 2014-2015 Vimal Manohar // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_NNET3_DISCRIMINATIVE_SUPERVISION_H #define KALDI_NNET3_DISCRIMINATIVE_SUPERVISION_H #include "util/table-types.h" #include "hmm/posterior.h" #include "hmm/transition-model.h" #include "lat/kaldi-lattice.h" namespace kaldi { namespace discriminative { struct SplitDiscriminativeSupervisionOptions { int32 frame_subsampling_factor; bool remove_output_symbols; bool collapse_transition_ids; bool remove_epsilons; bool determinize; bool minimize; // we'll push and minimize if this is true. BaseFloat acoustic_scale; SplitDiscriminativeSupervisionOptions() : remove_output_symbols(true), collapse_transition_ids(true), remove_epsilons(true), determinize(true), minimize(true), acoustic_scale(0.1) { } void Register(OptionsItf *opts) { opts->Register("collapse-transition-ids", &collapse_transition_ids, "If true, modify the transition-ids on denominator lattice " "so that on each frame, there is just one with any given " "pdf-id. This allows us to determinize and minimize " "more completely."); opts->Register("remove-output-symbols", &remove_output_symbols, "Remove output symbols from lattice to convert it to an " "acceptor and make it more determinizable"); opts->Register("remove-epsilons", &remove_epsilons, "Remove epsilons from the split lattices"); opts->Register("determinize", &determinize, "If true, we determinize " "lattices (as Lattice) after splitting and possibly minimize"); opts->Register("minimize", &minimize, "If true, we push and " "minimize lattices (as Lattice) after splitting"); opts->Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods (should match the " "value used in discriminative-get-supervision)"); } }; /* This file contains some declarations relating to the object we use to encode the supervision information for sequence training */ // struct DiscriminativeSupervision is the fully-processed information for // a whole utterance or (after splitting) part of an utterance. struct DiscriminativeSupervision { // The weight we assign to this example; // this will typically be one, but we include it // for the sake of generality. BaseFloat weight; // num_sequences will be 1 if you create a DiscriminativeSupervision object from a single // lattice or alignment, but if you combine multiple DiscriminativeSupervision objects // the 'num_sequences' is the number of objects that were combined (the // lattices get appended). int32 num_sequences; // the number of frames in each sequence of appended objects. num_frames * // num_sequences must equal the path length of any path in the lattices. // Technically this information is redundant with the lattices, but it's convenient // to have it separately. int32 frames_per_sequence; // The numerator alignment // Usually obtained by aligning the reference text with the seed neural // network model; can be the best path of generated lattice in the case of // semi-supervised training. std::vector<int32> num_ali; // Note: any acoustic // likelihoods in the lattices will be // recomputed at the time we train. // The denominator lattice. Lattice den_lat; DiscriminativeSupervision(): weight(1.0), num_sequences(1), frames_per_sequence(-1) { } DiscriminativeSupervision(const DiscriminativeSupervision &other); // This function creates a supervision object from numerator alignment // and denominator lattice. The supervision object is used for sequence // discriminative training. // Topologically sorts the lattice after copying to the supervision object. // Returns false when alignment or lattice is empty bool Initialize(const std::vector<int32> &alignment, const Lattice &lat, BaseFloat weight); void Swap(DiscriminativeSupervision *other); bool operator == (const DiscriminativeSupervision &other) const; // This function checks that this supervision object satifsies some // of the properties we expect of it, and calls KALDI_ERR if not. void Check() const; inline int32 NumFrames() const { return num_sequences * frames_per_sequence; } void Write(std::ostream &os, bool binary) const; void Read(std::istream &is, bool binary); }; // This class is used for splitting something of type // DiscriminativeSupervision into // multiple pieces corresponding to different frame-ranges. class DiscriminativeSupervisionSplitter { public: typedef fst::ArcTpl<LatticeWeight> LatticeArc; typedef fst::VectorFst<LatticeArc> Lattice; DiscriminativeSupervisionSplitter( const SplitDiscriminativeSupervisionOptions &config, const TransitionModel &tmodel, const DiscriminativeSupervision &supervision); // A structure used to store the forward and backward scores // and state times of a lattice struct LatticeInfo { // These values are stored in log. std::vector<double> alpha; std::vector<double> beta; std::vector<int32> state_times; void Check() const; }; // Extracts a frame range of the supervision into 'supervision'. void GetFrameRange(int32 begin_frame, int32 frames_per_sequence, bool normalize, DiscriminativeSupervision *supervision) const; // Get the acoustic scaled denominator lattice out for debugging purposes inline const Lattice& DenLat() const { return den_lat_; } private: // Creates an output lattice covering frames begin_frame <= t < end_frame, // assuming that the corresponding state-range that we need to // include, begin_state <= s < end_state has been included. // (note: the output lattice will also have two special initial and final // states). // Also does post-processing (RmEpsilon, Determinize, // TopSort on the result). See code for details. void CreateRangeLattice(const Lattice &in_lat, const LatticeInfo &scores, int32 begin_frame, int32 end_frame, bool normalize, Lattice *out_lat) const; // Config options for splitting supervision object const SplitDiscriminativeSupervisionOptions &config_; // Transition model is used by the function // CollapseTransitionIds() const TransitionModel &tmodel_; // A reference to the supervision object that we will be splitting const DiscriminativeSupervision &supervision_; // LatticeInfo object for denominator lattice. // This will be computed when PrepareLattice function is called. LatticeInfo den_lat_scores_; // Copy of denominator lattice. This is required because the lattice states // need to be ordered in breadth-first search order. Lattice den_lat_; // Function to compute lattice scores for a lattice void ComputeLatticeScores(const Lattice &lat, LatticeInfo *scores) const; // Prepare lattice : // 1) Order states in breadth-first search order // 2) Compute states times, which must be a strictly non-decreasing vector // 3) Compute lattice alpha and beta scores void PrepareLattice(Lattice *lat, LatticeInfo *scores) const; // Modifies the transition-ids on lat_ so that on each frame, there is just // one with any given pdf-id. This allows us to determinize and minimize // more completely. void CollapseTransitionIds(const std::vector<int32> &state_times, Lattice *lat) const; }; /// This function appends a list of supervision objects to create what will /// usually be a single such object, but if the weights and num-frames are not /// all the same it will only append Supervision objects where successive ones /// have the same weight and num-frames, and if 'compactify' is true. The /// normal use-case for this is when you are combining neural-net examples for /// training; appending them like this helps to simplify the training process. void MergeSupervision(const std::vector<const DiscriminativeSupervision*> &input, DiscriminativeSupervision *output_supervision); } // namespace discriminative } // namespace kaldi #endif // KALDI_NNET3_DISCRIMINATIVE_SUPERVISION_H |