Blame view
src/online2/online-ivector-feature.h
25.2 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 |
// online2/online-ivector-feature.h // Copyright 2013-2014 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_ONLINE2_ONLINE_IVECTOR_FEATURE_H_ #define KALDI_ONLINE2_ONLINE_IVECTOR_FEATURE_H_ #include <string> #include <vector> #include <deque> #include "matrix/matrix-lib.h" #include "util/common-utils.h" #include "base/kaldi-error.h" #include "itf/online-feature-itf.h" #include "gmm/diag-gmm.h" #include "feat/online-feature.h" #include "ivector/ivector-extractor.h" #include "decoder/lattice-faster-online-decoder.h" namespace kaldi { /// @addtogroup onlinefeat OnlineFeatureExtraction /// @{ /// @file /// This file contains code for online iVector extraction in a form compatible /// with OnlineFeatureInterface. It's used in online-nnet2-feature-pipeline.h. /// This class includes configuration variables relating to the online iVector /// extraction, but not including configuration for the "base feature", /// i.e. MFCC/PLP/filterbank, which is an input to this feature. This /// configuration class can be used from the command line, but before giving it /// to the code we create a config class called /// OnlineIvectorExtractionInfo which contains the actual configuration /// classes as well as various objects that are needed. The principle is that /// any code should be callable from other code, so we didn't want to force /// configuration classes to be read from disk. struct OnlineIvectorExtractionConfig { std::string lda_mat_rxfilename; // to read the LDA+MLLT matrix std::string global_cmvn_stats_rxfilename; // to read matrix of global CMVN // stats std::string splice_config_rxfilename; // to read OnlineSpliceOptions std::string cmvn_config_rxfilename; // to read in OnlineCmvnOptions std::string diag_ubm_rxfilename; // reads type DiagGmm. std::string ivector_extractor_rxfilename; // reads type IvectorExtractor // the following four configuration values should in principle match those // given to the script extract_ivectors_online.sh, although none of them are // super-critical. int32 ivector_period; // How frequently we re-estimate iVectors. int32 num_gselect; // maximum number of posteriors to use per frame for // iVector extractor. BaseFloat min_post; // pruning threshold for posteriors for the iVector // extractor. BaseFloat posterior_scale; // Scale on posteriors used for iVector // extraction; can be interpreted as the inverse // of a scale on the log-prior. BaseFloat max_count; // Maximum stats count we allow before we start scaling // down stats (if nonzero).. this prevents us getting // atypical-looking iVectors for very long utterances. // Interpret this as a number of frames times // posterior_scale, typically 1/10 of a frame count. int32 num_cg_iters; // set to 15. I don't believe this is very important, so it's // not configurable from the command line for now. // If use_most_recent_ivector is true, we always return the most recent // available iVector rather than the one for the current frame. This means // that if audio is coming in faster than we can process it, we will return a // more accurate iVector. bool use_most_recent_ivector; // If true, always read ahead to NumFramesReady() when getting iVector stats. bool greedy_ivector_extractor; // max_remembered_frames is the largest number of frames it will remember // between utterances of the same speaker; this affects the output of // GetAdaptationState(), and has the effect of limiting the number of frames // of both the CMVN stats and the iVector stats. Setting this to a smaller // value means the adaptation is less constrained by previous utterances // (assuming you provided info from a previous utterance of the same speaker // by calling SetAdaptationState()). BaseFloat max_remembered_frames; OnlineIvectorExtractionConfig(): ivector_period(10), num_gselect(5), min_post(0.025), posterior_scale(0.1), max_count(0.0), num_cg_iters(15), use_most_recent_ivector(true), greedy_ivector_extractor(false), max_remembered_frames(1000) { } void Register(OptionsItf *opts) { opts->Register("lda-matrix", &lda_mat_rxfilename, "Filename of LDA matrix, " "e.g. final.mat; used for iVector extraction. "); opts->Register("global-cmvn-stats", &global_cmvn_stats_rxfilename, "(Extended) filename for global CMVN stats, used in iVector " "extraction, obtained for example from " "'matrix-sum scp:data/train/cmvn.scp -', only used for " "iVector extraction"); opts->Register("cmvn-config", &cmvn_config_rxfilename, "Configuration " "file for online CMVN features (e.g. conf/online_cmvn.conf)," "only used for iVector extraction. Contains options " "as for the program 'apply-cmvn-online'"); opts->Register("splice-config", &splice_config_rxfilename, "Configuration file " "for frame splicing (--left-context and --right-context " "options); used for iVector extraction."); opts->Register("diag-ubm", &diag_ubm_rxfilename, "Filename of diagonal UBM " "used to obtain posteriors for iVector extraction, e.g. " "final.dubm"); opts->Register("ivector-extractor", &ivector_extractor_rxfilename, "Filename of iVector extractor, e.g. final.ie"); opts->Register("ivector-period", &ivector_period, "Frequency with which " "we extract iVectors for neural network adaptation"); opts->Register("num-gselect", &num_gselect, "Number of Gaussians to select " "for iVector extraction"); opts->Register("min-post", &min_post, "Threshold for posterior pruning in " "iVector extraction"); opts->Register("posterior-scale", &posterior_scale, "Scale for posteriors in " "iVector extraction (may be viewed as inverse of prior scale)"); opts->Register("max-count", &max_count, "Maximum data count we allow before " "we start scaling the stats down (if nonzero)... helps to make " "iVectors from long utterances look more typical. Interpret " "as a frame-count times --posterior-scale, typically 1/10 of " "a number of frames. Suggest 100."); opts->Register("use-most-recent-ivector", &use_most_recent_ivector, "If true, " "always use most recent available iVector, rather than the " "one for the designated frame."); opts->Register("greedy-ivector-extractor", &greedy_ivector_extractor, "If " "true, 'read ahead' as many frames as we currently have available " "when extracting the iVector. May improve iVector quality."); opts->Register("max-remembered-frames", &max_remembered_frames, "The maximum " "number of frames of adaptation history that we carry through " "to later utterances of the same speaker (having a finite " "number allows the speaker adaptation state to change over " "time). Interpret as a real frame count, i.e. not a count " "scaled by --posterior-scale."); } }; /// This struct contains various things that are needed (as const references) /// by class OnlineIvectorExtractor. struct OnlineIvectorExtractionInfo { Matrix<BaseFloat> lda_mat; // LDA+MLLT matrix. Matrix<double> global_cmvn_stats; // Global CMVN stats. OnlineCmvnOptions cmvn_opts; // Options for online CMN/CMVN computation. OnlineSpliceOptions splice_opts; // Options for frame splicing // (--left-context,--right-context) DiagGmm diag_ubm; IvectorExtractor extractor; // the following configuration variables are copied from // OnlineIvectorExtractionConfig, see comments there. int32 ivector_period; int32 num_gselect; BaseFloat min_post; BaseFloat posterior_scale; BaseFloat max_count; int32 num_cg_iters; bool use_most_recent_ivector; bool greedy_ivector_extractor; BaseFloat max_remembered_frames; OnlineIvectorExtractionInfo(const OnlineIvectorExtractionConfig &config); void Init(const OnlineIvectorExtractionConfig &config); // This constructor creates a version of this object where everything // is empty or zero. OnlineIvectorExtractionInfo(); void Check() const; private: KALDI_DISALLOW_COPY_AND_ASSIGN(OnlineIvectorExtractionInfo); }; /// This class stores the adaptation state from the online iVector extractor, /// which can help you to initialize the adaptation state for the next utterance /// of the same speaker in a more informed way. struct OnlineIvectorExtractorAdaptationState { // CMVN state for the features used to get posteriors for iVector extraction; // online CMVN is not used for the features supplied to the neural net, // instead the iVector is used. // Adaptation state for online CMVN (used for getting posteriors for iVector) OnlineCmvnState cmvn_state; /// Stats for online iVector estimation. OnlineIvectorEstimationStats ivector_stats; /// This constructor initializes adaptation-state with no prior speaker history. OnlineIvectorExtractorAdaptationState(const OnlineIvectorExtractionInfo &info): cmvn_state(info.global_cmvn_stats), ivector_stats(info.extractor.IvectorDim(), info.extractor.PriorOffset(), info.max_count) { } /// Copy constructor OnlineIvectorExtractorAdaptationState( const OnlineIvectorExtractorAdaptationState &other); /// Scales down the stats if needed to ensure the number of frames in the /// speaker-specific CMVN stats does not exceed max_remembered_frames /// and the data-count in the iVector stats does not exceed /// max_remembered_frames * posterior_scale. [the posterior_scale /// factor is necessary because those stats have already been scaled /// by that factor.] void LimitFrames(BaseFloat max_remembered_frames, BaseFloat posterior_scale); void Write(std::ostream &os, bool binary) const; void Read(std::istream &is, bool binary); }; /// OnlineIvectorFeature is an online feature-extraction class that's responsible /// for extracting iVectors from raw features such as MFCC, PLP or filterbank. /// Internally it processes the raw features using two different pipelines, one /// online-CMVN+splice+LDA, and one just splice+LDA. It gets GMM posteriors from /// the CMVN-normalized features, and with those and the unnormalized features /// it obtains iVectors. class OnlineIvectorFeature: public OnlineFeatureInterface { public: /// Constructor. base_feature is for example raw MFCC or PLP or filterbank /// features, whatever was used to train the iVector extractor. /// "info" contains all the configuration information as well as /// things like the iVector extractor that we won't be modifying. /// Caution: the class keeps a const reference to "info", so don't /// delete it while this class or others copied from it still exist. explicit OnlineIvectorFeature(const OnlineIvectorExtractionInfo &info, OnlineFeatureInterface *base_feature); // This version of the constructor accepts per-frame weights (relates to // downweighting silence). This is intended for use in offline operation, // i.e. during training. [will implement this when needed.] //explicit OnlineIvectorFeature(const OnlineIvectorExtractionInfo &info, // std::vector<BaseFloat> frame_weights, //OnlineFeatureInterface *base_feature); // Member functions from OnlineFeatureInterface: /// Dim() will return the iVector dimension. virtual int32 Dim() const; virtual bool IsLastFrame(int32 frame) const; virtual int32 NumFramesReady() const; virtual BaseFloat FrameShiftInSeconds() const; virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat); /// Set the adaptation state to a particular value, e.g. reflecting previous /// utterances of the same speaker; this will generally be called after /// constructing a new instance of this class. void SetAdaptationState( const OnlineIvectorExtractorAdaptationState &adaptation_state); /// Get the adaptation state; you may want to call this before destroying this /// object, to get adaptation state that can be used to improve decoding of /// later utterances of this speaker. void GetAdaptationState( OnlineIvectorExtractorAdaptationState *adaptation_state) const; virtual ~OnlineIvectorFeature(); // Some diagnostics (not present in generic interface): // UBM log-like per frame: BaseFloat UbmLogLikePerFrame() const; // Objective improvement per frame from iVector estimation, versus default iVector // value, measured at utterance end. BaseFloat ObjfImprPerFrame() const; // returns number of frames seen (but not counting the posterior-scale). BaseFloat NumFrames() const { return ivector_stats_.NumFrames() / info_.posterior_scale; } // If you are downweighting silence, you can call // OnlineSilenceWeighting::GetDeltaWeights and supply the output to this class // using UpdateFrameWeights(). The reason why this call happens outside this // class, rather than this class pulling in the data weights, relates to // multi-threaded operation and also from not wanting this class to have // excessive dependencies. // // You must either always call this as soon as new data becomes available // (ideally just after calling AcceptWaveform), or never call it for the // lifetime of this object. void UpdateFrameWeights( const std::vector<std::pair<int32, BaseFloat> > &delta_weights); private: // This accumulates i-vector stats for a set of frames, specified as pairs // (t, weight). The weights do not have to be positive. (In the online // silence-weighting that we do, negative weights can occur if we change our // minds about the assignment of a frame as silence vs. non-silence). void UpdateStatsForFrames( const std::vector<std::pair<int32, BaseFloat> > &frame_weights); // Returns a modified version of info_.min_post, which is opts_.min_post if // weight is 1.0 or -1.0, but gets larger if fabs(weight) is small... but no // larger than 0.99. (This is an efficiency thing, to not bother processing // very small counts). BaseFloat GetMinPost(BaseFloat weight) const; // This is the original UpdateStatsUntilFrame that is called when there is // no data-weighting involved. void UpdateStatsUntilFrame(int32 frame); // This is the new UpdateStatsUntilFrame that is called when there is // data-weighting (i.e. when the user has been calling UpdateFrameWeights()). void UpdateStatsUntilFrameWeighted(int32 frame); void PrintDiagnostics() const; const OnlineIvectorExtractionInfo &info_; OnlineFeatureInterface *base_; // The feature this is built on top of // (e.g. MFCC); not owned here OnlineFeatureInterface *lda_; // LDA on top of raw+splice features. OnlineCmvn *cmvn_; // the CMVN that we give to the lda_normalized_. OnlineFeatureInterface *lda_normalized_; // LDA on top of CMVN+splice // the following is the pointers to OnlineFeatureInterface objects that are // owned here and which we need to delete. std::vector<OnlineFeatureInterface*> to_delete_; /// the iVector estimation stats OnlineIvectorEstimationStats ivector_stats_; /// num_frames_stats_ is the number of frames of data we have already /// accumulated from this utterance and put in ivector_stats_. Each frame t < /// num_frames_stats_ is in the stats. In case you are doing the /// silence-weighted iVector estimation, with UpdateFrameWeights() being /// called, this variable is still used but you may later have to revisit /// earlier frames to adjust their weights... see the code. int32 num_frames_stats_; /// delta_weights_ is written to by UpdateFrameWeights, /// in the case where the iVector estimation is silence-weighted using the decoder /// traceback. Its elements are consumed by UpdateStatsUntilFrameWeighted(). /// We provide std::greater<std::pair<int32, BaseFloat> > > as the comparison type /// (default is std::less) so that the lowest-numbered frame, not the highest-numbered /// one, will be returned by top(). std::priority_queue<std::pair<int32, BaseFloat>, std::vector<std::pair<int32, BaseFloat> >, std::greater<std::pair<int32, BaseFloat> > > delta_weights_; /// this is only used for validating that the frame-weighting code is not buggy. std::vector<BaseFloat> current_frame_weight_debug_; /// delta_weights_provided_ is set to true if UpdateFrameWeights was ever called; it's /// used to detect wrong usage of this class. bool delta_weights_provided_; /// The following is also used to detect wrong usage of this class; it's set /// to true if UpdateStatsUntilFrame() was ever called. bool updated_with_no_delta_weights_; /// if delta_weights_ was ever called, this keeps track of the most recent /// frame that ever had a weight. It's mostly for detecting errors. int32 most_recent_frame_with_weight_; /// The following is only needed for diagnostics. double tot_ubm_loglike_; /// Most recently estimated iVector, will have been /// estimated at the greatest time t where t <= num_frames_stats_ and /// t % info_.ivector_period == 0. Vector<double> current_ivector_; /// if info_.use_most_recent_ivector == false, we need to store /// the iVector we estimated each info_.ivector_period frames so that /// GetFrame() can return the iVector that was active on that frame. /// ivectors_history_[i] contains the iVector we estimated on /// frame t = i * info_.ivector_period. std::vector<Vector<BaseFloat>* > ivectors_history_; }; struct OnlineSilenceWeightingConfig { std::string silence_phones_str; // The weighting factor that we apply to silence phones in the iVector // extraction. This option is only relevant if the --silence-phones option is // set. BaseFloat silence_weight; // Transition-ids that get repeated at least this many times (if // max_state_duration > 0) are treated as silence. BaseFloat max_state_duration; // This is the scale that we apply to data that we don't yet have a decoder // traceback for, in the online silence BaseFloat new_data_weight; bool Active() const { return !silence_phones_str.empty() && silence_weight != 1.0; } OnlineSilenceWeightingConfig(): silence_weight(1.0), max_state_duration(-1) { } void Register(OptionsItf *opts) { opts->Register("silence-phones", &silence_phones_str, "(RE weighting in " "iVector estimation for online decoding) List of integer ids of " "silence phones, separated by colons (or commas). Data that " "(according to the traceback of the decoder) corresponds to " "these phones will be downweighted by --silence-weight."); opts->Register("silence-weight", &silence_weight, "(RE weighting in " "iVector estimation for online decoding) Weighting factor for frames " "that the decoder trace-back identifies as silence; only " "relevant if the --silence-phones option is set."); opts->Register("max-state-duration", &max_state_duration, "(RE weighting in " "iVector estimation for online decoding) Maximum allowed " "duration of a single transition-id; runs with durations longer " "than this will be weighted down to the silence-weight."); } // e.g. prefix = "ivector-silence-weighting" void RegisterWithPrefix(std::string prefix, OptionsItf *opts) { ParseOptions po_prefix(prefix, opts); this->Register(&po_prefix); } }; // This class is responsible for keeping track of the best-path traceback from // the decoder (efficiently) and computing a weighting of the data based on the // classification of frames as silence (or not silence)... also with a duration // limitation, so data from a very long run of the same transition-id will get // weighted down. (this is often associated with misrecognition or silence). class OnlineSilenceWeighting { public: // Note: you would initialize a new copy of this object for each new // utterance. // The frame-subsampling-factor is used for newer nnet3 models, especially // chain models, when the frame-rate of the decoder is different from the // frame-rate of the input features. E.g. you might set it to 3 for such // models. OnlineSilenceWeighting(const TransitionModel &trans_model, const OnlineSilenceWeightingConfig &config, int32 frame_subsampling_factor = 1); bool Active() const { return config_.Active(); } // This should be called before GetDeltaWeights, so this class knows about the // traceback info from the decoder. It records the traceback information from // the decoder using its BestPathEnd() and related functions. // It will be instantiated for FST == fst::Fst<fst::StdArc> and fst::GrammarFst. template <typename FST> void ComputeCurrentTraceback(const LatticeFasterOnlineDecoderTpl<FST> &decoder); // Calling this function gets the changes in weight that require us to modify // the stats... the output format is (frame-index, delta-weight). The // num_frames_ready argument is the number of frames available at the input // (or equivalently, output) of the online iVector extractor class, which may // be more than the currently available decoder traceback. How many frames // of weights it outputs depends on how much "num_frames_ready" increased // since last time we called this function, and whether the decoder traceback // changed. Negative delta_weights might occur if frames previously // classified as non-silence become classified as silence if the decoder's // traceback changes. You must call this function with "num_frames_ready" // arguments that only increase, not decrease, with time. You would provide // this output to class OnlineIvectorFeature by calling its function // UpdateFrameWeights with the output. void GetDeltaWeights( int32 num_frames_ready_in, std::vector<std::pair<int32, BaseFloat> > *delta_weights); private: const TransitionModel &trans_model_; const OnlineSilenceWeightingConfig &config_; int32 frame_subsampling_factor_; unordered_set<int32> silence_phones_; struct FrameInfo { // The only reason we need the token pointer is to know far back we have to // trace before the traceback is the same as what we previously traced back. void *token; int32 transition_id; // current_weight is the weight we've previously told the iVector // extractor to use for this frame, if any. It may not equal the // weight we "want" it to use (any difference between the two will // be output when the user calls GetDeltaWeights(). BaseFloat current_weight; FrameInfo(): token(NULL), transition_id(-1), current_weight(0.0) {} }; // gets the frame at which we need to begin our processing in // GetDeltaWeights... normally this is equal to // num_frames_output_and_correct_, but it may be earlier in case // max_state_duration is relevant. int32 GetBeginFrame(); // This contains information about any previously computed traceback; // when the traceback changes we use this variable to compare it with the // previous traceback. // It's indexed at the frame-rate of the decoder (may be different // by 'frame_subsampling_factor_' from the frame-rate of the features. std::vector<FrameInfo> frame_info_; // This records how many frames have been output and that currently reflect // the traceback accurately. It is used to avoid GetDeltaWeights() having to // visit each frame as far back as t = 0, each time it is called. // GetDeltaWeights() sets this to the number of frames that it output, and // ComputeCurrentTraceback() then reduces it to however far it traced back. // However, we may have to go further back in time than this in order to // properly honor the "max-state-duration" config. This, if needed, is done // in GetDeltaWeights() before outputting the delta weights. int32 num_frames_output_and_correct_; }; /// @} End of "addtogroup onlinefeat" } // namespace kaldi #endif // KALDI_ONLINE2_ONLINE_IVECTOR_FEATURE_H_ |