rnnlm-train.cc
7.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
// rnnlmbin/rnnlm-train.cc
// Copyright 2015-2017 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "rnnlm/rnnlm-training.h"
#include "rnnlm/rnnlm-example-utils.h"
#include "nnet3/nnet-utils.h"
#include "cudamatrix/cu-allocator.h"
// Entry point of rnnlm-train.
//
// Flow, as implemented below:
//   1. Register and parse the command-line options; exactly one positional
//      argument (the examples rspecifier) is expected.
//   2. Require --read-rnnlm, --write-rnnlm and --read-embedding.
//   3. Read the RNNLM, the embedding matrix and (optionally) the sparse
//      word-feature matrix.
//   4. Feed every example from the rspecifier to an RnnlmTrainer.
//   5. Write the RNNLM, and the embedding matrix if --write-embedding was
//      given.
//
// Returns 0 on success and -1 if a std::exception reaches the outer handler;
// calls exit(1) after printing the usage message when the command line is
// incomplete.
int main(int argc, char *argv[]) {
  try {
    using namespace kaldi;
    using namespace kaldi::rnnlm;
    typedef kaldi::int32 int32;

    // rnnlm_rxfilename must be supplied, via --read-rnnlm option.
    std::string rnnlm_rxfilename;
    // For now, rnnlm_wxfilename must be supplied (later we could make it
    // possible to train the embedding matrix without training the RNNLM
    // itself, if there is a need).
    std::string rnnlm_wxfilename;
    // embedding_rxfilename must be supplied, via --read-embedding option.
    std::string embedding_rxfilename;
    // Optional: an empty embedding_wxfilename means the embedding matrix is
    // not trained and not written.
    std::string embedding_wxfilename;
    // Optional: an empty value means no sparse word-feature matrix is used.
    std::string word_features_rxfilename;

    // binary mode for writing output.
    bool binary = true;

    RnnlmCoreTrainerOptions core_config;
    RnnlmEmbeddingTrainerOptions embedding_config;
    RnnlmObjectiveOptions objective_config;

    const char *usage =
        "Train nnet3-based RNNLM language model (reads minibatches prepared\n"
        "by rnnlm-get-egs). Supports various modes depending which parameters\n"
        "we are training.\n"
        "Usage:\n"
        " rnnlm-train [options] <egs-rspecifier>\n"
        "e.g.:\n"
        " rnnlm-get-egs ... ark:- | \\\n"
        " rnnlm-train --read-rnnlm=foo/0.raw --write-rnnlm=foo/1.raw --read-embedding=foo/0.embedding \\\n"
        " --write-embedding=foo/1.embedding --read-sparse-word-features=foo/word_feats.txt ark:-\n"
        "See also: rnnlm-get-egs\n";

    std::string use_gpu = "yes";

    ParseOptions po(usage);
    po.Register("use-gpu", &use_gpu,
                "yes|no|optional|wait, only has effect if compiled with CUDA");
    po.Register("read-rnnlm", &rnnlm_rxfilename,
                "Read RNNLM from this location (e.g. 0.raw). Must be supplied.");
    // The help text states the requirement that is actually enforced further
    // down: an empty --write-rnnlm is rejected.
    po.Register("write-rnnlm", &rnnlm_wxfilename,
                "Write RNNLM to this location (e.g. 1.raw). Must be supplied.");
    po.Register("read-embedding", &embedding_rxfilename,
                "Location to read dense (feature or word) embedding matrix, "
                "of dimension (num-words or num-features) by (embedding-dim). "
                "Must be supplied.");
    po.Register("write-embedding", &embedding_wxfilename,
                "Location to write embedding matrix (c.f. --read-embedding). "
                "If not provided, the embedding will not be trained.");
    po.Register("read-sparse-word-features", &word_features_rxfilename,
                "Location to read sparse word-feature matrix, e.g. "
                "word_feats.txt. Format is lines like: '1 30 1.0 516 1.0': "
                "starting with word-index, then a list of pairs "
                "(feature-index, value) only including nonzero features. "
                "This will usually be determined in an ad-hoc way based on "
                "letters and other hand-built features; it's not trainable."
                " If present, the embedding matrix read via --read-embedding "
                "will be interpreted as a feature-embedding matrix.");
    po.Register("binary", &binary,
                "If true, write outputs in binary form.");

    objective_config.Register(&po);
    RegisterCuAllocatorOptions(&po);

    // register the core RNNLM training options with the prefix "rnnlm",
    // so they will appear as --rnnlm.max-change and the like. This is done
    // with a prefix because later we may add a neural net to transform the
    // word embedding, and it would have options that would have a name
    // conflict with some of these options.
    ParseOptions core_opts("rnnlm", &po);
    core_config.Register(&core_opts);

    // ... and register the embedding options with the prefix "embedding".
    ParseOptions embedding_opts("embedding", &po);
    embedding_config.Register(&embedding_opts);

    po.Read(argc, argv);

    if (po.NumArgs() != 1) {
      po.PrintUsage();
      exit(1);
    }

    if (rnnlm_rxfilename == "" ||
        rnnlm_wxfilename == "" ||
        embedding_rxfilename == "") {
      KALDI_WARN << "--read-rnnlm, --write-rnnlm and --read-embedding "
          "options are required.";
      po.PrintUsage();
      exit(1);
    }

    std::string examples_rspecifier = po.GetArg(1);

#if HAVE_CUDA==1
    CuDevice::Instantiate().SelectGpuId(use_gpu);
    CuDevice::Instantiate().AllowMultithreading();
#endif

    kaldi::nnet3::Nnet rnnlm;
    ReadKaldiObject(rnnlm_rxfilename, &rnnlm);

    if (!IsSimpleNnet(rnnlm))
      KALDI_ERR << "Input RNNLM in " << rnnlm_rxfilename
                << " is not the type of neural net we were looking for; "
          "failed IsSimpleNnet().";

    CuMatrix<BaseFloat> embedding_mat;
    ReadKaldiObject(embedding_rxfilename, &embedding_mat);

    // Stays empty unless --read-sparse-word-features was given.
    CuSparseMatrix<BaseFloat> word_feature_mat;

    if (word_features_rxfilename != "") {
      // binary mode is not supported here; it's a text format.
      Input input(word_features_rxfilename);
      // The row count of the embedding matrix is passed to the reader as the
      // feature dimension.
      int32 feature_dim = embedding_mat.NumRows();
      SparseMatrix<BaseFloat> cpu_word_feature_mat;
      ReadSparseWordFeatures(input.Stream(), feature_dim,
                             &cpu_word_feature_mat);
      word_feature_mat.Swap(&cpu_word_feature_mat);  // copy to GPU, if we have
                                                     // one.
    }

    {
      // 'trainer' lives in its own scope so that it is destroyed before the
      // model is written out below.
      bool train_embedding = (embedding_wxfilename != "");

      RnnlmTrainer trainer(
          train_embedding, core_config, embedding_config, objective_config,
          (word_features_rxfilename != "" ? &word_feature_mat : NULL),
          &embedding_mat, &rnnlm);

      SequentialRnnlmExampleReader example_reader(examples_rspecifier);

      for (; !example_reader.Done(); example_reader.Next())
        trainer.Train(&(example_reader.Value()));

      if (trainer.NumMinibatchesProcessed() == 0)
        KALDI_ERR << "There was no data to train on.";
      // The destructor of 'trainer' trains on the last minibatch
      // and writes out anything we need to write out.
    }

    WriteKaldiObject(rnnlm, rnnlm_wxfilename, binary);
    KALDI_LOG << "Wrote RNNLM to "
              << PrintableWxfilename(rnnlm_wxfilename);

    if (embedding_wxfilename != "") {
      WriteKaldiObject(embedding_mat, embedding_wxfilename, binary);
      KALDI_LOG << "Wrote embedding matrix to "
                << PrintableWxfilename(embedding_wxfilename);
    }

#if HAVE_CUDA==1
    CuDevice::Instantiate().PrintProfile();
#endif
    return 0;
  } catch(const std::exception &e) {
    std::cerr << e.what() << '\n';
    return -1;
  }
}