// nnet3/nnet-am-decodable-simple.cc // Copyright 2015 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include "nnet3/nnet-am-decodable-simple.h" #include "nnet3/nnet-utils.h" namespace kaldi { namespace nnet3 { DecodableNnetSimple::DecodableNnetSimple( const NnetSimpleComputationOptions &opts, const Nnet &nnet, const VectorBase &priors, const MatrixBase &feats, CachingOptimizingCompiler *compiler, const VectorBase *ivector, const MatrixBase *online_ivectors, int32 online_ivector_period): opts_(opts), nnet_(nnet), output_dim_(nnet_.OutputDim("output")), log_priors_(priors), feats_(feats), ivector_(ivector), online_ivector_feats_(online_ivectors), online_ivector_period_(online_ivector_period), compiler_(*compiler), current_log_post_subsampled_offset_(0) { num_subsampled_frames_ = (feats_.NumRows() + opts_.frame_subsampling_factor - 1) / opts_.frame_subsampling_factor; KALDI_ASSERT(IsSimpleNnet(nnet)); compiler_.GetSimpleNnetContext(&nnet_left_context_, &nnet_right_context_); KALDI_ASSERT(!(ivector != NULL && online_ivectors != NULL)); KALDI_ASSERT(!(online_ivectors != NULL && online_ivector_period <= 0 && "You need to set the --online-ivector-period option!")); log_priors_.ApplyLog(); CheckAndFixConfigs(); } DecodableAmNnetSimple::DecodableAmNnetSimple( const NnetSimpleComputationOptions &opts, const TransitionModel &trans_model, const AmNnetSimple &am_nnet, const MatrixBase &feats, const VectorBase *ivector, const MatrixBase *online_ivectors, int32 online_ivector_period, CachingOptimizingCompiler *compiler): compiler_(am_nnet.GetNnet(), opts.optimize_config, opts.compiler_config), decodable_nnet_(opts, am_nnet.GetNnet(), am_nnet.Priors(), feats, compiler != NULL ? compiler : &compiler_, ivector, online_ivectors, online_ivector_period), trans_model_(trans_model) { // note: we only use compiler_ if the passed-in 'compiler' is NULL. } BaseFloat DecodableAmNnetSimple::LogLikelihood(int32 frame, int32 transition_id) { int32 pdf_id = trans_model_.TransitionIdToPdfFast(transition_id); return decodable_nnet_.GetOutput(frame, pdf_id); } int32 DecodableNnetSimple::GetIvectorDim() const { if (ivector_ != NULL) return ivector_->Dim(); else if (online_ivector_feats_ != NULL) return online_ivector_feats_->NumCols(); else return 0; } void DecodableNnetSimple::EnsureFrameIsComputed(int32 subsampled_frame) { KALDI_ASSERT(subsampled_frame >= 0 && subsampled_frame < num_subsampled_frames_); int32 feature_dim = feats_.NumCols(), ivector_dim = GetIvectorDim(), nnet_input_dim = nnet_.InputDim("input"), nnet_ivector_dim = std::max(0, nnet_.InputDim("ivector")); if (feature_dim != nnet_input_dim) KALDI_ERR << "Neural net expects 'input' features with dimension " << nnet_input_dim << " but you provided " << feature_dim; if (ivector_dim != std::max(0, nnet_.InputDim("ivector"))) KALDI_ERR << "Neural net expects 'ivector' features with dimension " << nnet_ivector_dim << " but you provided " << ivector_dim; int32 current_subsampled_frames_computed = current_log_post_.NumRows(), current_subsampled_offset = current_log_post_subsampled_offset_; KALDI_ASSERT(subsampled_frame < current_subsampled_offset || subsampled_frame >= current_subsampled_offset + current_subsampled_frames_computed); // all subsampled frames pertain to the output of the network, // they are output frames divided by opts_.frame_subsampling_factor. int32 subsampling_factor = opts_.frame_subsampling_factor, subsampled_frames_per_chunk = opts_.frames_per_chunk / subsampling_factor, start_subsampled_frame = subsampled_frame, num_subsampled_frames = std::min(num_subsampled_frames_ - start_subsampled_frame, subsampled_frames_per_chunk), last_subsampled_frame = start_subsampled_frame + num_subsampled_frames - 1; KALDI_ASSERT(num_subsampled_frames > 0); // the output-frame numbers are the subsampled-frame numbers int32 first_output_frame = start_subsampled_frame * subsampling_factor, last_output_frame = last_subsampled_frame * subsampling_factor; KALDI_ASSERT(opts_.extra_left_context >= 0 && opts_.extra_right_context >= 0); int32 extra_left_context = opts_.extra_left_context, extra_right_context = opts_.extra_right_context; if (first_output_frame == 0 && opts_.extra_left_context_initial >= 0) extra_left_context = opts_.extra_left_context_initial; if (last_subsampled_frame == num_subsampled_frames_ - 1 && opts_.extra_right_context_final >= 0) extra_right_context = opts_.extra_right_context_final; int32 left_context = nnet_left_context_ + extra_left_context, right_context = nnet_right_context_ + extra_right_context; int32 first_input_frame = first_output_frame - left_context, last_input_frame = last_output_frame + right_context, num_input_frames = last_input_frame + 1 - first_input_frame; Vector ivector; GetCurrentIvector(first_output_frame, last_output_frame - first_output_frame, &ivector); Matrix input_feats; if (first_input_frame >= 0 && last_input_frame < feats_.NumRows()) { SubMatrix input_feats(feats_.RowRange(first_input_frame, num_input_frames)); DoNnetComputation(first_input_frame, input_feats, ivector, first_output_frame, num_subsampled_frames); } else { Matrix feats_block(num_input_frames, feats_.NumCols()); int32 tot_input_feats = feats_.NumRows(); for (int32 i = 0; i < num_input_frames; i++) { SubVector dest(feats_block, i); int32 t = i + first_input_frame; if (t < 0) t = 0; if (t >= tot_input_feats) t = tot_input_feats - 1; const SubVector src(feats_, t); dest.CopyFromVec(src); } DoNnetComputation(first_input_frame, feats_block, ivector, first_output_frame, num_subsampled_frames); } } // note: in the normal case (with no frame subsampling) you can ignore the // 'subsampled_' in the variable name. void DecodableNnetSimple::GetOutputForFrame(int32 subsampled_frame, VectorBase *output) { if (subsampled_frame < current_log_post_subsampled_offset_ || subsampled_frame >= current_log_post_subsampled_offset_ + current_log_post_.NumRows()) EnsureFrameIsComputed(subsampled_frame); output->CopyFromVec(current_log_post_.Row( subsampled_frame - current_log_post_subsampled_offset_)); } void DecodableNnetSimple::GetCurrentIvector(int32 output_t_start, int32 num_output_frames, Vector *ivector) { if (ivector_ != NULL) { *ivector = *ivector_; return; } else if (online_ivector_feats_ == NULL) { return; } KALDI_ASSERT(online_ivector_period_ > 0); // frame_to_search is the frame that we want to get the most recent iVector // for. We choose a point near the middle of the current window, the concept // being that this is the fairest comparison to nnet2. Obviously we could do // better by always taking the last frame's iVector, but decoding with // 'online' ivectors is only really a mechanism to simulate online operation. int32 frame_to_search = output_t_start + num_output_frames / 2; int32 ivector_frame = frame_to_search / online_ivector_period_; KALDI_ASSERT(ivector_frame >= 0); if (ivector_frame >= online_ivector_feats_->NumRows()) { int32 margin = ivector_frame - (online_ivector_feats_->NumRows() - 1); if (margin * online_ivector_period_ > 50) { // Half a second seems like too long to be explainable as edge effects. KALDI_ERR << "Could not get iVector for frame " << frame_to_search << ", only available till frame " << online_ivector_feats_->NumRows() << " * ivector-period=" << online_ivector_period_ << " (mismatched --online-ivector-period?)"; } ivector_frame = online_ivector_feats_->NumRows() - 1; } *ivector = online_ivector_feats_->Row(ivector_frame); } void DecodableNnetSimple::DoNnetComputation( int32 input_t_start, const MatrixBase &input_feats, const VectorBase &ivector, int32 output_t_start, int32 num_subsampled_frames) { ComputationRequest request; request.need_model_derivative = false; request.store_component_stats = false; bool shift_time = true; // shift the 'input' and 'output' to a consistent // time, to take advantage of caching in the compiler. // An optimization. int32 time_offset = (shift_time ? -output_t_start : 0); // First add the regular features-- named "input". request.inputs.reserve(2); request.inputs.push_back( IoSpecification("input", time_offset + input_t_start, time_offset + input_t_start + input_feats.NumRows())); if (ivector.Dim() != 0) { std::vector indexes; indexes.push_back(Index(0, 0, 0)); request.inputs.push_back(IoSpecification("ivector", indexes)); } IoSpecification output_spec; output_spec.name = "output"; output_spec.has_deriv = false; int32 subsample = opts_.frame_subsampling_factor; output_spec.indexes.resize(num_subsampled_frames); // leave n and x values at 0 (the constructor sets these). for (int32 i = 0; i < num_subsampled_frames; i++) output_spec.indexes[i].t = time_offset + output_t_start + i * subsample; request.outputs.resize(1); request.outputs[0].Swap(&output_spec); std::shared_ptr computation = compiler_.Compile(request); Nnet *nnet_to_update = NULL; // we're not doing any update. NnetComputer computer(opts_.compute_config, *computation, nnet_, nnet_to_update); CuMatrix input_feats_cu(input_feats); computer.AcceptInput("input", &input_feats_cu); CuMatrix ivector_feats_cu; if (ivector.Dim() > 0) { ivector_feats_cu.Resize(1, ivector.Dim()); ivector_feats_cu.Row(0).CopyFromVec(ivector); computer.AcceptInput("ivector", &ivector_feats_cu); } computer.Run(); CuMatrix cu_output; computer.GetOutputDestructive("output", &cu_output); // subtract log-prior (divide by prior) if (log_priors_.Dim() != 0) cu_output.AddVecToRows(-1.0, log_priors_); // apply the acoustic scale cu_output.Scale(opts_.acoustic_scale); current_log_post_.Resize(0, 0); // the following statement just swaps the pointers if we're not using a GPU. cu_output.Swap(¤t_log_post_); current_log_post_subsampled_offset_ = output_t_start / subsample; } void DecodableNnetSimple::CheckAndFixConfigs() { static bool warned_frames_per_chunk = false; int32 nnet_modulus = nnet_.Modulus(); if (opts_.frame_subsampling_factor < 1 || opts_.frames_per_chunk < 1) KALDI_ERR << "--frame-subsampling-factor and --frames-per-chunk must be > 0"; KALDI_ASSERT(nnet_modulus > 0); int32 n = Lcm(opts_.frame_subsampling_factor, nnet_modulus); if (opts_.frames_per_chunk % n != 0) { // round up to the nearest multiple of n. int32 frames_per_chunk = n * ((opts_.frames_per_chunk + n - 1) / n); if (!warned_frames_per_chunk) { warned_frames_per_chunk = true; if (nnet_modulus == 1) { // simpler error message. KALDI_LOG << "Increasing --frames-per-chunk from " << opts_.frames_per_chunk << " to " << frames_per_chunk << " to make it a multiple of " << "--frame-subsampling-factor=" << opts_.frame_subsampling_factor; } else { KALDI_LOG << "Increasing --frames-per-chunk from " << opts_.frames_per_chunk << " to " << frames_per_chunk << " due to " << "--frame-subsampling-factor=" << opts_.frame_subsampling_factor << " and " << "nnet shift-invariance modulus = " << nnet_modulus; } } opts_.frames_per_chunk = frames_per_chunk; } } DecodableAmNnetSimpleParallel::DecodableAmNnetSimpleParallel( const NnetSimpleComputationOptions &opts, const TransitionModel &trans_model, const AmNnetSimple &am_nnet, const MatrixBase &feats, const VectorBase *ivector, const MatrixBase *online_ivectors, int32 online_ivector_period): compiler_(am_nnet.GetNnet(), opts.optimize_config, opts.compiler_config), trans_model_(trans_model), feats_copy_(NULL), ivector_copy_(NULL), online_ivectors_copy_(NULL), decodable_nnet_(NULL) { try { feats_copy_ = new Matrix(feats); if (ivector != NULL) ivector_copy_ = new Vector(*ivector); if (online_ivectors != NULL) online_ivectors_copy_ = new Matrix(*online_ivectors); decodable_nnet_ = new DecodableNnetSimple(opts, am_nnet.GetNnet(), am_nnet.Priors(), *feats_copy_, &compiler_, ivector_copy_, online_ivectors_copy_, online_ivector_period); } catch (...) { DeletePointers(); KALDI_ERR << "Error occurred in constructor (see above)"; } } void DecodableAmNnetSimpleParallel::DeletePointers() { // delete[] does nothing for null pointers, so we have no checks. delete decodable_nnet_; decodable_nnet_ = NULL; delete feats_copy_; feats_copy_ = NULL; delete ivector_copy_; ivector_copy_ = NULL; delete online_ivectors_copy_; online_ivectors_copy_ = NULL; } BaseFloat DecodableAmNnetSimpleParallel::LogLikelihood(int32 frame, int32 transition_id) { int32 pdf_id = trans_model_.TransitionIdToPdfFast(transition_id); return decodable_nnet_->GetOutput(frame, pdf_id); } } // namespace nnet3 } // namespace kaldi