rnnlm-lattice-rescoring.cc
3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// rnnlm/rnnlm-lattice-rescoring.cc
// Copyright 2017 Johns Hopkins University (author: Daniel Povey)
// 2017 Yiming Wang
// 2017 Hainan Xu
//
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <utility>
#include "rnnlm/rnnlm-lattice-rescoring.h"
#include "util/stl-utils.h"
#include "util/text-utils.h"
namespace kaldi {
namespace rnnlm {
/// Destructor: frees every RnnlmComputeState pointer held in
/// state_to_rnnlm_state_ and then empties all of the bookkeeping containers.
KaldiRnnlmDeterministicFst::~KaldiRnnlmDeterministicFst() {
  // The pointers stored in state_to_rnnlm_state_ are deleted here, entry by
  // entry, in index order.
  for (size_t idx = 0; idx < state_to_rnnlm_state_.size(); ++idx) {
    delete state_to_rnnlm_state_[idx];
  }
  state_to_rnnlm_state_.clear();
  state_to_wseq_.clear();
  wseq_to_state_.clear();
}
void KaldiRnnlmDeterministicFst::Clear() {
// This function is similar to the destructor but we retain the 0-th entries
// in each map which corresponds to the <bos> state.
int32 size = state_to_rnnlm_state_.size();
for (int32 i = 1; i < size; i++)
delete state_to_rnnlm_state_[i];
state_to_rnnlm_state_.resize(1);
state_to_wseq_.resize(1);
wseq_to_state_.clear();
wseq_to_state_[state_to_wseq_[0]] = 0;
}
/// Constructs the FST with a single state, numbered 0, whose word history is
/// the one-element sequence [bos_index_].
///
/// @param max_ngram_order  Stored in max_ngram_order_; GetArc() uses it to
///                         limit the length of the word history kept per
///                         state (only when it is positive).
/// @param info             Supplies the <bos>/<eos> indexes (info.opts) and is
///                         passed to the RnnlmComputeState created for
///                         state 0.
KaldiRnnlmDeterministicFst::KaldiRnnlmDeterministicFst(int32 max_ngram_order,
    const RnnlmComputeStateInfo &info) {
  // Plain configuration values.
  max_ngram_order_ = max_ngram_order;
  bos_index_ = info.opts.bos_index;
  eos_index_ = info.opts.eos_index;
  start_state_ = 0;

  // Register state 0 in all three containers: its word history, its RNNLM
  // compute state, and the reverse lookup from history to state id.
  const std::vector<Label> initial_history(1, bos_index_);
  state_to_wseq_.push_back(initial_history);
  RnnlmComputeState *initial_rnnlm = new RnnlmComputeState(info, bos_index_);
  wseq_to_state_[initial_history] = 0;
  state_to_rnnlm_state_.push_back(initial_rnnlm);
}
/// Returns the final weight of state `s`: the negated value of
/// LogProbOfWord(eos_index_) as reported by the RNNLM state attached to `s`.
///
/// @param s  State id; it must already have been created (asserted).
fst::StdArc::Weight KaldiRnnlmDeterministicFst::Final(StateId s) {
  // The state must exist by the time its final weight is requested.
  KALDI_ASSERT(static_cast<size_t>(s) < state_to_wseq_.size());

  const RnnlmComputeState *state_rnnlm = state_to_rnnlm_state_[s];
  const BaseFloat eos_logprob = state_rnnlm->LogProbOfWord(eos_index_);
  return Weight(-eos_logprob);
}
/// Fills in the arc that leaves state `s` on word `ilabel`, creating the
/// destination state if it does not exist yet.
///
/// The destination is identified by the word history of `s` extended with
/// `ilabel`.  When max_ngram_order_ is positive, that history is truncated to
/// its last (max_ngram_order_ - 1) words, so different source states can share
/// a destination.  A non-positive max_ngram_order_ disables truncation.
///
/// @param s       Source state; it must already have been created (asserted).
/// @param ilabel  Word label consumed by the arc.
/// @param oarc    Output arc.  Both labels are set to `ilabel`, nextstate to
///                the destination state id, and weight to the negated value of
///                LogProbOfWord(ilabel) from the RNNLM state of `s`.
/// @return        Always true.
bool KaldiRnnlmDeterministicFst::GetArc(StateId s, Label ilabel,
                                        fst::StdArc *oarc) {
  // At this point, we have created the state.
  KALDI_ASSERT(static_cast<size_t>(s) < state_to_wseq_.size());

  const RnnlmComputeState *rnnlm = state_to_rnnlm_state_[s];
  const BaseFloat logprob = rnnlm->LogProbOfWord(ilabel);

  // Word history of the destination state: a copy of the source history with
  // the new word appended.
  std::vector<Label> word_seq = state_to_wseq_[s];
  word_seq.push_back(ilabel);
  if (max_ngram_order_ > 0) {
    // A history state holds at most <max_ngram_order_> - 1 words.  Drop the
    // oldest words with a single range erase.
    const size_t max_history = static_cast<size_t>(max_ngram_order_) - 1;
    if (word_seq.size() > max_history) {
      word_seq.erase(word_seq.begin(),
                     word_seq.begin() + (word_seq.size() - max_history));
    }
  }

  // Id the destination would get if it turns out to be a new state.
  const StateId candidate_state = static_cast<StateId>(state_to_wseq_.size());
  std::pair<const std::vector<Label>, StateId> wseq_state_pair(
      word_seq, candidate_state);

  // Attempts to insert the current <wseq_state_pair>.  If an entry for this
  // word sequence already exists, nothing is inserted, result.second is false
  // and result.first points at the existing entry.
  typedef MapType::iterator IterType;
  std::pair<IterType, bool> result = wseq_to_state_.insert(wseq_state_pair);

  // If the pair was just inserted, then also add it to state_to_* structures
  // so that all three containers stay in step.
  if (result.second) {
    RnnlmComputeState *successor = rnnlm->GetSuccessorState(ilabel);
    state_to_wseq_.push_back(word_seq);
    state_to_rnnlm_state_.push_back(successor);
  }

  // Creates the arc.
  oarc->ilabel = ilabel;
  oarc->olabel = ilabel;
  oarc->nextstate = result.first->second;
  oarc->weight = Weight(-logprob);
  return true;
}
} // namespace rnnlm
} // namespace kaldi