rnnlm-compute-state.cc
6.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
// src/rnnlm/rnnlm-compute-state.cc
// Copyright 2017 Johns Hopkins University (author: Daniel Povey)
// 2017 Yiming Wang
// 2017 Hainan Xu
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "rnnlm/rnnlm-compute-state.h"
#include "nnet3/nnet-utils.h"
#include "nnet3/nnet-compile-looped.h"
namespace kaldi {
namespace rnnlm {
// Validates the RNNLM against the word-embedding matrix and the options, then
// compiles the looped (one word per chunk) nnet3 computation that every
// RnnlmComputeState built from this object will share.
//
// Fails (KALDI_ERR / KALDI_ASSERT) if:
//  - the nnet is not a "simple" nnet, or has non-zero left/right context;
//  - the nnet's "output" dimension differs from the embedding dimension;
//  - opts.bos_index or opts.eos_index is not in [1, vocab size).
RnnlmComputeStateInfo::RnnlmComputeStateInfo(
    const RnnlmComputeStateComputationOptions &opts,
    const kaldi::nnet3::Nnet &rnnlm,
    const CuMatrix<BaseFloat> &word_embedding_mat):
    opts(opts), rnnlm(rnnlm), word_embedding_mat(word_embedding_mat) {
  KALDI_ASSERT(IsSimpleNnet(rnnlm));

  // The state is advanced one word at a time, so the network must not look at
  // any frames other than the current one.
  int32 left_context, right_context;
  ComputeSimpleNnetContext(rnnlm, &left_context, &right_context);
  if (left_context != 0 || right_context != 0)
    KALDI_ERR << "Non-zero left or right context. Please check your script";

  // Each row of the embedding matrix is one word; index 0 is <eps>, which is
  // why valid bos/eos indexes start at 1.
  const int32 vocab_size = word_embedding_mat.NumRows();
  if (word_embedding_mat.NumCols() != rnnlm.OutputDim("output"))
    KALDI_ERR << "Embedding file and nnet have different embedding sizes. ";
  if (opts.bos_index <= 0 || opts.bos_index >= vocab_size)
    KALDI_ERR << "--bos-symbol option isn't set correctly.";
  if (opts.eos_index <= 0 || opts.eos_index >= vocab_size)
    KALDI_ERR << "--eos-symbol option isn't set correctly.";

  // Named arguments for the looped computation request: one frame (word) per
  // chunk, a single sequence, no subsampling and no extra context.
  const int32 num_frames = 1,
      frame_subsampling_factor = 1,
      ivector_period = 1,
      extra_left_context_initial = 0,
      extra_right_context = 0,
      num_sequences = 1;
  nnet3::ComputationRequest request1, request2, request3;
  CreateLoopedComputationRequestSimple(rnnlm, num_frames,
                                       frame_subsampling_factor,
                                       ivector_period,
                                       extra_left_context_initial,
                                       extra_right_context,
                                       num_sequences,
                                       &request1, &request2, &request3);
  CompileLooped(rnnlm, opts.optimize_config, request1, request2, request3,
                &computation);
  computation.ComputeCudaIndexes();

  if (GetVerboseLevel() >= 3) {
    KALDI_VLOG(3) << "Computation is:";
    computation.Print(std::cerr, rnnlm);
  }
}
// Creates the initial state for a sentence: sets up an NnetComputer over the
// shared looped computation (with no nnet to update) and immediately feeds
// `bos_index`, so the state is ready to score the first real word.
RnnlmComputeState::RnnlmComputeState(const RnnlmComputeStateInfo &info,
                                     int32 bos_index)
    : info_(info),
      computer_(info_.opts.compute_config, info_.computation, info_.rnnlm,
                NULL),  // NULL: no nnet is being updated.
      previous_word_(-1),
      normalization_factor_(0.0) {
  // Feeding the BOS word runs the first chunk and sets the predicted
  // embedding (and, if enabled, the normalization factor).
  AddWord(bos_index);
}
// Copy constructor: duplicates the computer (including its recurrent state),
// the last word fed and the normalization factor.
//
// predicted_word_embedding_ is deliberately NOT copied. other's pointer refers
// to the output matrix owned by other.computer_, not by this->computer_, so
// sharing it would alias (and later dangle on) the source object's data.
// Instead it is set to NULL; it becomes valid again after the next AddWord()
// (which runs AdvanceChunk()), as GetSuccessorState() does. Previously the
// member was left uninitialized, so any read before AddWord() was undefined
// behavior. Now it is a deterministic NULL dereference.
RnnlmComputeState::RnnlmComputeState(const RnnlmComputeState &other):
    info_(other.info_), computer_(other.computer_),
    previous_word_(other.previous_word_),
    normalization_factor_(other.normalization_factor_) {
  predicted_word_embedding_ = NULL;
}
// Returns a newly allocated state equal to this one advanced by `next_word`.
// This state is left unchanged. The caller owns the returned pointer and must
// delete it.
RnnlmComputeState* RnnlmComputeState::GetSuccessorState(int32 next_word) const {
  RnnlmComputeState *successor = new RnnlmComputeState(*this);
  successor->AddWord(next_word);  // Runs one chunk on the copied computer.
  return successor;
}
// Feeds `word_index` (must be in [1, vocab size); 0 is <eps>) to the RNNLM,
// advancing the recurrent state by one word. If opts.normalize_probs is set,
// it also recomputes normalization_factor_: the log of the sum of exp(score)
// over all non-<eps> words, where score = word_embedding . predicted_embedding.
void RnnlmComputeState::AddWord(int32 word_index) {
  KALDI_ASSERT(word_index > 0 && word_index < info_.word_embedding_mat.NumRows());
  previous_word_ = word_index;
  AdvanceChunk();

  if (info_.opts.normalize_probs) {
    const CuMatrix<BaseFloat> &word_embedding_mat = info_.word_embedding_mat;
    // Unnormalized log-probabilities (logits), one per vocabulary word.
    CuVector<BaseFloat> logits(word_embedding_mat.NumRows());
    logits.AddMatVec(1.0, word_embedding_mat, kNoTrans,
                     predicted_word_embedding_->Row(0), 0.0);
    // Exclude the <eps> symbol (index 0), which is never a real prediction.
    CuSubVector<BaseFloat> non_eps_logits = logits.Range(1, logits.Dim() - 1);
    // Numerically stable log-sum-exp. Subtracting the max first keeps exp()
    // in (0, 1], so it cannot overflow to inf in single precision.
    BaseFloat max_logit = non_eps_logits.Max();
    non_eps_logits.Add(-max_logit);
    non_eps_logits.ApplyExp();
    normalization_factor_ = max_logit + log(non_eps_logits.Sum());
  }
}
// Log-probability of `word_index` following the words fed so far. The score
// is the dot product of the predicted embedding with the word's embedding. If
// opts.normalize_probs is set, normalization_factor_ (computed in AddWord()) is
// subtracted from it.
BaseFloat RnnlmComputeState::LogProbOfWord(int32 word_index) const {
  BaseFloat score = VecVec(predicted_word_embedding_->Row(0),
                           info_.word_embedding_mat.Row(word_index));
  // Without explicit normalization the score is still close to a normalized
  // log-prob, because of how the model was trained.
  if (!info_.opts.normalize_probs)
    return score;
  score -= normalization_factor_;
  return score;
}
// Writes the log-probability of every vocabulary word into row 0 of `output`.
// `output` must be 1 x (vocab size), i.e. one column per row of the embedding
// matrix, because the scores are word_embedding_mat * predicted_embedding.
// When opts.normalize_probs is set, normalization_factor_ is subtracted, as in
// LogProbOfWord(). Column 0 (<eps>) is forced to -99 so it gets ~zero
// probability.
void RnnlmComputeState::GetLogProbOfWords(CuMatrixBase<BaseFloat> *output) const {
  const CuMatrix<BaseFloat> &word_embedding_mat = info_.word_embedding_mat;
  // AddMatVec with kNoTrans yields one value per row of word_embedding_mat
  // (per word), so the column count must match NumRows(), not NumCols().
  KALDI_ASSERT(output->NumRows() == 1
               && output->NumCols() == word_embedding_mat.NumRows());
  output->Row(0).AddMatVec(1.0, word_embedding_mat, kNoTrans,
                           predicted_word_embedding_->Row(0), 0.0);
  // Even without explicit normalization, the log-probs will be close to
  // correctly normalized due to the way the model was trained.
  if (info_.opts.normalize_probs) {
    // log p(w) = score(w) - log sum_w' exp(score(w')): subtract the
    // normalizer, consistent with LogProbOfWord().
    output->Add(-normalization_factor_);
  }
  // making sure <eps> has almost 0 prob
  output->ColRange(0, 1).Set(-99.0);
}
void RnnlmComputeState::AdvanceChunk() {
CuMatrix<BaseFloat> input_embeddings(1, info_.word_embedding_mat.NumCols());
input_embeddings.Row(0).AddVec(1.0,
info_.word_embedding_mat.Row(previous_word_));
computer_.AcceptInput("input", &input_embeddings);
computer_.Run();
{
// Note: here GetOutput() is used instead of GetOutputDestructive(), since
// here we have recurrence that goes directly from the output, and the call
// to GetOutputDestructive() would cause a crash on the next chunk.
const CuMatrixBase<BaseFloat> &output(computer_.GetOutput("output"));
predicted_word_embedding_ = &output;
}
}
} // namespace rnnlm
} // namespace kaldi