Blame view
src/cudadecoder/decodable-cumatrix.cc
1.98 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
// cudadecoder/decodable-cumatrix.cc
/*
 * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
 * Authors: Hugo Braun, Justin Luitjens, Ryan Leary
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#if HAVE_CUDA == 1

#include "decodable-cumatrix.h"

#include <cstddef>

namespace kaldi {
namespace cuda_decoder {

// Wraps a likelihood matrix for use by the CUDA decoder.
//
// `likes` is held by pointer (it is not copied), so the caller must keep it
// alive for as long as this object is used. Row r of `likes` holds the scores
// for frame `frame_offset + r`. The number of columns must equal
// tm.NumPdfs(), i.e. one column per pdf-id; otherwise KALDI_ERR is raised.
DecodableCuMatrixMapped::DecodableCuMatrixMapped(
    const TransitionModel &tm, const CuMatrixBase<BaseFloat> &likes,
    int32 frame_offset)
    : trans_model_(tm), likes_(&likes), frame_offset_(frame_offset) {
  if (likes.NumCols() != tm.NumPdfs())
    KALDI_ERR << "Mismatch, matrix has " << likes.NumCols()
              << " columns but transition-model has " << tm.NumPdfs()
              << " pdf-ids.";
}

// Number of frames available: the frame offset plus the number of rows in
// the likelihood matrix.
int32 DecodableCuMatrixMapped::NumFramesReady() const {
  return frame_offset_ + likes_->NumRows();
}

// Returns true iff `frame` is the last frame currently available.
// Asserts that `frame` is below NumFramesReady().
bool DecodableCuMatrixMapped::IsLastFrame(int32 frame) const {
  KALDI_ASSERT(frame < NumFramesReady());
  return (frame == NumFramesReady() - 1);
}

// Number of transition-ids in the transition model.
// Indices are one-based! This is for compatibility with OpenFst.
int32 DecodableCuMatrixMapped::NumIndices() const {
  return trans_model_.NumTransitionIds();
}

// Returns a pointer to the row of scores (one value per pdf-id) for
// `subsampled_frame`, which must lie in [frame_offset_, NumFramesReady()).
//
// The pointer refers to the CuMatrix storage, which may be GPU device memory;
// do not dereference it on the host. The wrapped matrix is const, so callers
// must treat the pointed-to memory as read-only. The non-const return type is
// kept only for interface compatibility.
BaseFloat *DecodableCuMatrixMapped::GetLogLikelihoodsCudaPointer(
    int32 subsampled_frame) {
  KALDI_ASSERT(subsampled_frame >= frame_offset_ &&
               subsampled_frame < NumFramesReady());
  // Widen before multiplying so that row * stride cannot overflow int32 on
  // very large matrices.
  const std::ptrdiff_t row_offset =
      static_cast<std::ptrdiff_t>(subsampled_frame - frame_offset_) *
      likes_->Stride();
  return const_cast<BaseFloat *>(likes_->Data()) + row_offset;
}

}  // end namespace cuda_decoder
}  // end namespace kaldi

#endif  // HAVE_CUDA == 1