Blame view

src/cudadecoder/decodable-cumatrix.cc 1.98 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
  // cudadecoder/decodable-cumatrix.cc
  /*
   * Copyright (c) 2017, NVIDIA CORPORATION.  All rights reserved.
   * Authors:  Hugo Braun, Justin Luitjens, Ryan Leary
   *
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   * See the License for the specific language governing permissions and
   * limitations under the License.
   */
  
  #if HAVE_CUDA == 1
  
  #include "decodable-cumatrix.h"
  
  namespace kaldi {
  namespace cuda_decoder {
  
  // Constructs a decodable object over the likelihood matrix `likes`.
  //
  //  tm            transition model; its NumPdfs() must equal likes.NumCols().
  //  likes         matrix of likelihoods.  Only its address is stored (in
  //                likes_), so the matrix must stay alive for as long as this
  //                object is used.
  //  frame_offset  stored in frame_offset_; added to the row count in
  //                NumFramesReady() and subtracted from the requested frame in
  //                GetLogLikelihoodsCudaPointer().
  //
  // Raises KALDI_ERR if the matrix column count and the number of pdf-ids
  // disagree.
  DecodableCuMatrixMapped::DecodableCuMatrixMapped(
      const TransitionModel &tm, const CuMatrixBase<BaseFloat> &likes,
      int32 frame_offset)
      : trans_model_(tm), likes_(&likes), frame_offset_(frame_offset) {
    // The quantity being checked is the number of *columns*, so report it as
    // such (the message previously said "rows").
    if (likes.NumCols() != tm.NumPdfs())
      KALDI_ERR << "Mismatch, matrix has " << likes.NumCols()
                << " columns but transition-model has " << tm.NumPdfs()
                << " pdf-ids.";
  }
  
  // Returns frame_offset_ plus the number of rows in the likelihood matrix.
  int32 DecodableCuMatrixMapped::NumFramesReady() const {
    const int32 rows_in_matrix = likes_->NumRows();
    return frame_offset_ + rows_in_matrix;
  }
  
  // Returns true iff `frame` is the index of the final frame currently
  // available, i.e. NumFramesReady() - 1.  Asserts that `frame` is below
  // NumFramesReady().
  bool DecodableCuMatrixMapped::IsLastFrame(int32 frame) const {
    const int32 frames_ready = NumFramesReady();
    KALDI_ASSERT(frame < frames_ready);
    const int32 last_frame = frames_ready - 1;
    return frame == last_frame;
  }
  
  // Returns the number of transition-ids in the transition model.
  // Note: indices are one-based, for compatibility with OpenFst.
  int32 DecodableCuMatrixMapped::NumIndices() const {
    const int32 num_transition_ids = trans_model_.NumTransitionIds();
    return num_transition_ids;
  }
  
  // Returns a CUDA pointer to the nnet3 output for `subsampled_frame`.
  //
  // The result is likes_->Data() advanced by
  // (subsampled_frame - frame_offset_) * likes_->Stride() elements, i.e. the
  // start of the matrix row that corresponds to the requested frame.  No
  // bounds checking is done here; the caller is responsible for passing a
  // frame that lies inside the matrix.
  BaseFloat *
  DecodableCuMatrixMapped::GetLogLikelihoodsCudaPointer(int32 subsampled_frame) {
    // Row of the matrix that holds the requested frame.
    const int32 row = subsampled_frame - frame_offset_;
    // likes_ was set from the address of a const reference, while this
    // function hands out a mutable pointer.  Use an explicit const_cast so the
    // dropped qualifier is visible, rather than hiding it in a C-style cast.
    BaseFloat *matrix_begin = const_cast<BaseFloat *>(likes_->Data());
    return matrix_begin + row * likes_->Stride();
  }
  
  }  // end namespace cuda_decoder
  }  // end namespace kaldi
  
  #endif  // HAVE_CUDA == 1