rnnlm-compute-state.h
5.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// src/rnnlm/rnnlm-compute-state.h
// Copyright 2017 Johns Hopkins University (author: Daniel Povey)
// 2017 Yiming Wang
// 2017 Hainan Xu
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_RNNLM_COMPUTE_STATE_H_
#define KALDI_RNNLM_COMPUTE_STATE_H_
#include <vector>
#include "base/kaldi-common.h"
#include "nnet3/nnet-optimize.h"
#include "nnet3/nnet-compute.h"
#include "nnet3/am-nnet-simple.h"
#include "rnnlm/rnnlm-core-compute.h"
namespace kaldi {
namespace rnnlm {
struct RnnlmComputeStateComputationOptions {
bool debug_computation;
bool normalize_probs;
// We need this when we initialize the RnnlmComputeState and pass the BOS history.
int32 bos_index;
// We need this to compute the Final() cost of a state.
int32 eos_index;
// This is not needed for computation; included only for ease of scripting.
int32 brk_index;
nnet3::NnetOptimizeOptions optimize_config;
nnet3::NnetComputeOptions compute_config;
RnnlmComputeStateComputationOptions():
debug_computation(false),
normalize_probs(false),
bos_index(-1),
eos_index(-1),
brk_index(-1)
{ }
void Register(OptionsItf *opts) {
opts->Register("debug-computation", &debug_computation, "If true, turn on "
"debug for the actual computation (very verbose!)");
opts->Register("normalize-probs", &normalize_probs, "If true, word "
"probabilities will be correctly normalized (otherwise the sum-to-one "
"normalization is approximate)");
opts->Register("bos-symbol", &bos_index, "Index in wordlist representing "
"the begin-of-sentence symbol");
opts->Register("eos-symbol", &eos_index, "Index in wordlist representing "
"the end-of-sentence symbol");
opts->Register("brk-symbol", &brk_index, "Index in wordlist representing "
"the break symbol. It is not needed in the computation "
"and we are including it for ease of scripting");
// Register the optimization options with the prefix "optimization".
ParseOptions optimization_opts("optimization", opts);
optimize_config.Register(&optimization_opts);
// Register the compute options with the prefix "computation".
ParseOptions compute_opts("computation", opts);
compute_config.Register(&compute_opts);
}
};
/*
This class holds const references to the word-embedding matrix, the nnet3 part
of the RNNLM, and the RnnlmComputeStateComputationOptions. It also holds the
compiled nnet3 computation.
*/
class RnnlmComputeStateInfo {
 public:
  /// Binds the options, the nnet3 network and the word-embedding matrix.
  /// None of the three is copied: this object keeps references to them, so
  /// each must outlive it (in particular, do not pass temporaries).
  /// NOTE(review): the constructor body is defined elsewhere; presumably it
  /// is what fills in `computation` — confirm in the .cc file.
  RnnlmComputeStateInfo(
      const RnnlmComputeStateComputationOptions &options,
      const kaldi::nnet3::Nnet &nnet,
      const CuMatrix<BaseFloat> &embedding);

  // Non-owning views of the constructor arguments. Because these are
  // reference members, the class has no usable copy-assignment operator.
  const RnnlmComputeStateComputationOptions &opts;
  const kaldi::nnet3::Nnet &rnnlm;
  const CuMatrix<BaseFloat> &word_embedding_mat;

  // The compiled, 'looped' nnet3 computation.
  nnet3::NnetComputation computation;
};
/*
This class handles the neural-net computation; it is mostly accessed
via other wrapper classes.
Each time this class takes a new word, it advances the nnet3 computation by
one step and works out the log-probs of words for use in lattice rescoring. */
class RnnlmComputeState {
 public:
  /// Builds the state that follows the BOS history, i.e. the history
  /// consisting of `bos_index` alone. `info` is held by reference and must
  /// outlive this object.
  RnnlmComputeState(const RnnlmComputeStateInfo &info, int32 bos_index);

  /// Copy constructor (defined out of line).
  RnnlmComputeState(const RnnlmComputeState &other);

  /// Returns a new state whose history is this state's history followed by
  /// `word`. The returned pointer is owned by the caller, who is responsible
  /// for deleting it.
  RnnlmComputeState* GetSuccessorState(int32 word) const;

  /// Log-probability the model assigns to `word` given the current history,
  /// i.e. the words passed to AddWord() so far, implicitly preceded by BOS.
  BaseFloat LogProbOfWord(int32 word) const;

  /// Writes the log-probs of all words into `log_probs`.
  /// Entry (0, 0) corresponds to the <eps> symbol. It is filled with a very
  /// small number so that accidental use is harmless, but callers must never
  /// rely on it.
  /// NOTE(review): presumably `log_probs` is one row indexed by word id —
  /// confirm against the implementation.
  void GetLogProbOfWords(CuMatrixBase<BaseFloat>* log_probs) const;

  /// Appends `word` to the word sequence, advancing the RNNLM state.
  void AddWord(int32 word);

 private:
  /// Runs the computation for the next chunk.
  void AdvanceChunk();

  // Shared, non-owning configuration and compiled computation.
  const RnnlmComputeStateInfo &info_;
  nnet3::NnetComputer computer_;
  // NOTE(review): presumably the most recently added word, used as input to
  // the next chunk — confirm in the .cc file.
  int32 previous_word_;
  // Log of the sum of the exp'ed output values. Meaningful only when
  // info_.opts.normalize_probs is true.
  BaseFloat normalization_factor_;
  // Non-owning pointer to the matrix returned by GetOutput() (presumably
  // called on computer_); this class must not delete it.
  const CuMatrixBase<BaseFloat> *predicted_word_embedding_;
};
} // namespace rnnlm
} // namespace kaldi
#endif // KALDI_RNNLM_COMPUTE_STATE_H_