Blame view

src/decoder/decodable-matrix.cc 3.6 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
  // decoder/decodable-matrix.cc
  
  // Copyright    2018 Johns Hopkins University (author: Daniel Povey)
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  #include "decoder/decodable-matrix.h"
  
  namespace kaldi {
  
  // Wraps a likelihood matrix owned by the caller: only its address is kept
  // (no copy is made) and it is never deleted by this object.  Row r of
  // 'likes' holds the scores for frame (frame_offset + r), and its columns are
  // indexed by pdf-id, so it must have tm.NumPdfs() columns.
  DecodableMatrixMapped::DecodableMatrixMapped(
      const TransitionModel &tm,
      const MatrixBase<BaseFloat> &likes,
      int32 frame_offset):
      trans_model_(tm), likes_(&likes), likes_to_delete_(NULL),
      frame_offset_(frame_offset) {
    stride_ = likes.Stride();
    // Shift the base pointer back by frame_offset_ rows so that absolute frame
    // indices can be used directly in LogLikelihood().
    raw_data_ = likes.Data() - (stride_ * frame_offset_);

    if (likes.NumCols() != tm.NumPdfs())
      KALDI_ERR << "Mismatch, matrix has "
                << likes.NumCols() << " columns but transition-model has "
                << tm.NumPdfs() << " pdf-ids.";
  }
  
  // Takes ownership of '*likes'; it is deleted in the destructor.  Row r of
  // '*likes' holds the scores for frame (frame_offset + r), and its columns are
  // indexed by pdf-id, so it must have tm.NumPdfs() columns.
  DecodableMatrixMapped::DecodableMatrixMapped(
      const TransitionModel &tm, const Matrix<BaseFloat> *likes,
      int32 frame_offset):
      trans_model_(tm), likes_(likes), likes_to_delete_(likes),
      frame_offset_(frame_offset) {
    stride_ = likes->Stride();
    // Shift the base pointer back by frame_offset_ rows so that absolute frame
    // indices can be used directly in LogLikelihood().
    raw_data_ = likes->Data() - (stride_ * frame_offset_);
    if (likes->NumCols() != tm.NumPdfs()) {
      int32 num_cols = likes->NumCols();
      // We already own 'likes', but a destructor does not run for an object
      // whose constructor throws, so release the matrix before KALDI_ERR
      // throws, otherwise it would leak.
      delete likes_to_delete_;
      likes_to_delete_ = NULL;
      KALDI_ERR << "Mismatch, matrix has "
                << num_cols << " columns but transition-model has "
                << tm.NumPdfs() << " pdf-ids.";
    }
  }
  
  
  BaseFloat DecodableMatrixMapped::LogLikelihood(int32 frame, int32 tid) {
    // A transition-id is scored by the matrix column of its pdf-id.
    const int32 column = trans_model_.TransitionIdToPdfFast(tid);
  #ifndef KALDI_PARANOID
    // Unchecked path: raw_data_ was shifted back by frame_offset_ rows when the
    // object was built, so 'frame' is usable as an absolute row index here.
    return raw_data_[frame * stride_ + column];
  #else
    // Checked path: go through the matrix accessor with a row index relative
    // to the first stored frame.
    const int32 row = frame - frame_offset_;
    return (*likes_)(row, column);
  #endif
  }
  
  int32 DecodableMatrixMapped::NumFramesReady() const {
    // The stored matrix begins at frame frame_offset_, so the number of frames
    // available is that offset plus the number of stored rows.
    const int32 stored_rows = likes_->NumRows();
    return stored_rows + frame_offset_;
  }
  
  bool DecodableMatrixMapped::IsLastFrame(int32 frame) const {
    // Asking about a frame that is not yet available is a caller error.
    const int32 num_ready = NumFramesReady();
    KALDI_ASSERT(frame < num_ready);
    return frame + 1 == num_ready;
  }
  
  // The indices used by the decoder are transition-ids, which are one-based
  // for compatibility with OpenFst (where label 0 means epsilon).  The largest
  // valid index is therefore the number of transition-ids.
  int32 DecodableMatrixMapped::NumIndices() const {
    const int32 num_transition_ids = trans_model_.NumTransitionIds();
    return num_transition_ids;
  }
  
  DecodableMatrixMapped::~DecodableMatrixMapped() {
    // likes_to_delete_ is non-NULL only when the pointer-taking constructor
    // transferred ownership of the matrix to us; the reference-taking
    // constructor leaves it NULL.
    if (likes_to_delete_ != NULL)
      delete likes_to_delete_;
  }
  
  
  // Drops the oldest 'frames_to_discard' stored rows, advancing frame_offset_
  // by that amount, and appends the rows of '*loglikes' after the remaining
  // ones.  '*loglikes' must have trans_model_.NumPdfs() columns.  On return
  // '*loglikes' is either empty (its storage was taken over via Swap()) or
  // unchanged (its contents were copied).
  //
  // NOTE(review): when '*loglikes' has no rows the call returns immediately,
  // so 'frames_to_discard' is ignored in that case -- confirm this is intended.
  void DecodableMatrixMappedOffset::AcceptLoglikes(
      Matrix<BaseFloat> *loglikes, int32 frames_to_discard) {
    if (loglikes->NumRows() == 0) return;
    KALDI_ASSERT(loglikes->NumCols() == trans_model_.NumPdfs());
    KALDI_ASSERT(frames_to_discard <= loglikes_.NumRows() &&
                 frames_to_discard >= 0);
    if (frames_to_discard == loglikes_.NumRows()) {
      // Every previously stored frame is discarded, so the incoming matrix can
      // simply take over; Swap() avoids copying it.  The old rows that end up
      // in '*loglikes' are released.
      loglikes_.Swap(loglikes);
      loglikes->Resize(0, 0);
    } else {
      int32 old_rows_kept = loglikes_.NumRows() - frames_to_discard,
          new_num_rows = old_rows_kept + loglikes->NumRows();
      // Every row of new_loglikes is overwritten by the two copies below, so
      // skip the zero-fill that the default kSetZero would perform.
      Matrix<BaseFloat> new_loglikes(new_num_rows, loglikes->NumCols(),
                                     kUndefined);
      new_loglikes.RowRange(0, old_rows_kept).CopyFromMat(
          loglikes_.RowRange(frames_to_discard, old_rows_kept));
      new_loglikes.RowRange(old_rows_kept, loglikes->NumRows()).CopyFromMat(
          *loglikes);
      loglikes_.Swap(&new_loglikes);
    }
    frame_offset_ += frames_to_discard;
    stride_ = loglikes_.Stride();
    // Shift the base pointer back by frame_offset_ rows so that absolute frame
    // indices can be used directly for lookups.
    raw_data_ = loglikes_.Data() - (frame_offset_ * stride_);
  }
  
  
  
  } // end namespace kaldi.