// doc/chain.dox
// Copyright 2015 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
namespace kaldi {
/**
\page chain 'Chain' models
\section chain_intro Introduction to 'chain' models
The 'chain' models are a type of DNN-HMM model, implemented using \ref dnn3 "nnet3", and differ from the
conventional model in various ways; you can think of them as a different
design point in the space of acoustic models.
- We use a 3 times smaller frame rate at the output of the neural net.
This significantly reduces the amount of computation required at
test time, making real-time decoding much easier.
- The models are trained right from the start with a sequence-level
objective function-- namely, the log probability of the correct sequence. It is
essentially MMI implemented without lattices on the GPU, by doing a full
forward-backward on a decoding graph derived from a phone n-gram language
model.
- Because of the reduced frame rate, we need to use unconventional
HMM topologies (allowing the traversal of the HMM in one state).
- We use fixed transition probabilities in the HMM, and don't train
them (we may decide to train them in future; but for the most part the neural-net
output probabilities can do the same job as the transition probabilities,
depending on the topology).
- Currently, only nnet3 DNNs are supported (see \ref dnn3), and
online decoding has not yet been implemented (we're aiming for April to June 2016).
- Currently the results are a bit better than those of conventional
DNN-HMMs (about 5\% relative better), but the system is about 3 times
faster to decode; training time is probably a bit faster too, but
we haven't compared it exactly.
\section chain_scripts Where to find scripts for the 'chain' models
The current best scripts for the 'chain' models can be found in the
Switchboard setup in egs/swbd/s5c; the script local/chain/run_tdnn_2o.sh is
the current best one. This is currently available in the 'chain' branch of
the official github repository (https://github.com/kaldi-asr/kaldi.git) and
eventually will be merged to the master.
This script uses TDNNs as the neural net (we've been doing the development
with TDNNs because they are easier to tune than LSTMs), and gives a better
WER than the baseline TDNN: 11.4\%, versus 12.1\% for the best TDNN baseline
(on the Switchboard-only portion of eval2000).
\section chain_model The chain model
The chain model itself is no different from a conventional DNN-HMM, used with
a (currently) 3-fold reduced frame rate at the output of the DNN. The input
features of the DNN are at the original frame rate of 100 per second; this makes
sense because all the neural nets we are currently using (LSTMs, TDNNs) have some kind
of recurrent connections or splicing inside them, i.e. they are not purely feedforward
nets.
The difference from a normal model is the objective function used to train it:
instead of a frame-level objective, we use the log-probability of the correct
phone sequence as the objective function. The training process is quite
similar in principle to MMI training, in which we compute numerator and
denominator 'occupation probabilities' and the difference between the two is
used in the derivative computation. There is no need to normalize the DNN
outputs to sum to one on each frame any more; such normalization makes no difference.
Because of the reduced frame rate (one frame every 30 ms), we need to use a
modified HMM topology. We would like the HMM to be traversable in one
transition (as opposed to the 3 transitions of a model at the normal frame
rate). The currently favored topology has a state that can only occur once,
and then another state that can appear zero or more times. The state-clustering
is obtained using the same procedure as for GMM-based models, although
of course with a different topology (we convert the alignments to the new topology
and frame-rate).
\section chain_training The training procedure for 'chain' models
The training procedure for chain models is a lattice-free version of
MMI, where the denominator state posteriors are obtained by the
forward-backward algorithm over a HMM formed from a phone-level decoding graph,
and the numerator state posteriors are obtained by a similar forward-backward
algorithm but limited to sequences corresponding to the transcript.
For each output index of the neural net (i.e. for each pdf-id), we
compute a derivative of the form (numerator occupation probability -
denominator occupation probability), and these are propagated back to the
network.
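As a minimal sketch of this derivative (not the actual Kaldi implementation), assume the numerator and denominator occupation probabilities for one chunk are already available as per-frame lists indexed by pdf-id. The function name and the toy numbers below are hypothetical.

```python
def chain_derivative(num_post, den_post):
    """Per-frame, per-pdf-id derivative: numerator minus denominator occupancy."""
    return [[n - d for n, d in zip(num_frame, den_frame)]
            for num_frame, den_frame in zip(num_post, den_post)]

# Toy example: 2 output frames, 3 pdf-ids (all values are made up).
num_post = [[1.0, 0.0, 0.0], [0.0, 0.5, 0.5]]
den_post = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]]
deriv = chain_derivative(num_post, den_post)
```

In this toy example each row of both inputs sums to one, so each row of the derivative sums to zero.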
\subsection chain_training_denominator The denominator FST
For the denominator part of the computation we do forward-backward over a HMM.
Actually, because we represent it as a finite state acceptor, the labels
(pdf-ids) are associated with the arcs and not the states, so it's not really a
HMM in the normal formulation, but it's easier to think of it as a HMM because
we use the forward-backward algorithm to get posteriors.
In the code and scripts we refer to it as the 'denominator FST'.
\subsubsection chain_training_denominator_phone_lm Phone language model for the denominator FST
The first stage in constructing the denominator FST is to create a phone
language model. This language model is learned from the training-data phone
alignments. This is an un-smoothed language model, meaning that we never do
backoff to lower order n-grams. However, some language-model states are
removed entirely, so transitions to those states go instead to the lower-order
n-gram's state. The reason we avoid smoothing is to reduce the number of
arcs that there will be in the compiled graph after phonetic context expansion.
The configuration that we settled on is to estimate a 4-gram language model,
and to never prune LM states below trigram (so we always maintain at least a
2-phone history). On top of the number of states dictated by the no-prune
trigram rule, we have a specifiable number (e.g. 2000) of 4-gram language
model states which are to be retained (all the rest are identified with the
corresponding trigram state), and the ones we choose to retain are determined
in a way that maximizes the training-data likelihood. All probabilities are
estimated to maximize the training-data likelihood. The reason not to prune
the trigrams is that any sparsity in which trigrams are allowed will tend to
minimize the size of the compiled graph. Note that if our phone LM was just a
simple phone loop (i.e. a unigram), it would get expanded to triphones anyway
due to phonetic context effects, but it would have arcs for all possible
trigrams in it. So any sparsity we get from using the un-pruned trigram model
is a bonus. Empirically, an un-smoothed trigram LM is what expands to the
smallest possible FST; and pruning some of the trigrams, while it increases
the size of the compiled FST, results in little or no WER improvement (at
least on 300 hours of data expanded 3-fold with speed perturbation; on less
data it might help).
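The following is a minimal sketch of the 'no smoothing, no backoff' idea only, assuming the phone alignments are given as plain lists of phone names. It estimates a maximum-likelihood trigram model and does not attempt the selection of retained 4-gram states described above. The function name, the padding symbols and the toy phone sequences are all hypothetical.

```python
from collections import Counter, defaultdict

def estimate_unsmoothed_trigram(phone_seqs):
    """Maximum-likelihood trigram probabilities with no smoothing or backoff."""
    counts = defaultdict(Counter)
    for seq in phone_seqs:
        # "BOS"/"EOS" are illustrative padding symbols, not Kaldi conventions.
        padded = ["BOS", "BOS"] + list(seq) + ["EOS"]
        for i in range(2, len(padded)):
            history = (padded[i - 2], padded[i - 1])
            counts[history][padded[i]] += 1
    return {history: {phone: c / sum(nxt.values()) for phone, c in nxt.items()}
            for history, nxt in counts.items()}

seqs = [["sil", "k", "ae", "t", "sil"], ["sil", "k", "ae", "b", "sil"]]
lm = estimate_unsmoothed_trigram(seqs)
```

Because there is no backoff, a trigram that never occurred in the alignments (for example "k ae k" here) simply has no entry, which is the sparsity that keeps the compiled graph small.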
On the Switchboard setups the phone-LM perplexities for the various models we
tried were in the range 5 to 7; the phone-LM perplexity with our chosen
configuration (4-gram, pruned to trigram for all but 2000 states) was about 6.
It was not the case that lower phone-LM perplexity always led to better WER
of the trained system; as for conventional (word-based) MMI training, an
intermediate strength of language model seemed to work best.
\subsubsection chain_training_denominator_compilation Compilation of the denominator FST
The phone language model described in the previous section is expanded into a
FST with 'pdf-ids' as the arcs, in a process that mirrors the process of
decoding-graph compilation in normal Kaldi decoding (see \ref
graph_recipe_test), except that no lexicon is involved, and at the
end we convert the transition-ids to pdf-ids.
One difference lies in how we minimize the size of the graph. The normal
recipe involves determinization and minimization. We were not able to
reduce the size of the graph using this procedure, or variants of it with
disambiguation symbols. Instead, our graph-minimization process can be described
compactly as follows: "Repeat 3 times: push, minimize, reverse; push, minimize, reverse."
Here 'push' refers to weight-pushing, and 'reverse' refers to reversing the directions of arcs and
swapping initial and final states.
\subsubsection chain_training_denominator_normalization Initial and final probabilities, and 'normalization FST'
The graph-creation process mentioned above naturally gives us an initial
state, and final probabilities for each state; but these are not the ones we
use in the forward-backward. The reason is that these probabilities are
applicable to utterance boundaries, but we train on split-up chunks of
utterances of a fixed length (e.g. 1.5 seconds). Constraining the HMM at these
arbitrarily chosen cut points to the initial and final states is not
appropriate. Instead, we use initial probabilities derived from 'running the HMM' for
a fixed number of iterations and averaging the probabilities; and final probabilities
equal to 1.0 for each state. We have a justification for this but don't have time to
explain it right now. In the denominator forward-backward process we apply these initial and
final probabilities to the initial and final frame as part of the computation. However, we also
write out a version of the denominator FST that has these initial and final probabilities, and we refer to
this as the 'normalization FST.' (The initial probabilities are emulated using epsilon arcs, because
FSTs do not support initial probabilities). This 'normalization FST' will be used to add probabilities to the
numerator FSTs in a way that we'll describe later.
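One way to picture 'running the HMM' and averaging is the sketch below. It assumes that the transition probabilities leaving each state sum to one, that the averaging includes the starting distribution, and that 100 iterations are used; all three are assumptions made for illustration, as are the function name and the toy HMM.

```python
def averaged_initial_probs(transitions, start_state, num_states, num_iters=100):
    """Run the HMM forward from start_state and average the state distributions.

    transitions: dict mapping state -> list of (next_state, prob) pairs.
    """
    current = [0.0] * num_states
    current[start_state] = 1.0
    average = [0.0] * num_states
    for _ in range(num_iters):
        # Accumulate the current distribution into the running average.
        for s in range(num_states):
            average[s] += current[s] / num_iters
        # Advance the distribution by one transition.
        nxt = [0.0] * num_states
        for s, p_s in enumerate(current):
            for t, p_t in transitions.get(s, []):
                nxt[t] += p_s * p_t
        current = nxt
    return average

# Toy 2-state HMM (made-up probabilities).
transitions = {0: [(0, 0.5), (1, 0.5)], 1: [(0, 0.2), (1, 0.8)]}
init = averaged_initial_probs(transitions, start_state=0, num_states=2)
final = [1.0, 1.0]  # final probabilities are 1.0 for every state
```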
\subsection chain_training_numerator Numerator FSTs
As part of our preparation for the training process we produce something
called a 'numerator FST' for each utterance. The numerator FST encodes the
supervision transcript, and also encodes an alignment of that transcript
(i.e. it forces similarity to a reference alignment obtained from a baseline
system), but it allows a little 'wiggle room' to vary from that reference.
By default we allow a phone to begin up to 0.05 seconds before its begin
position, and to end up to 0.05 seconds after its end position, in the lattice alignment.
Incorporating the alignment information is important because of the way we
train not on entire utterances but on split-up fixed-length pieces of
utterances (which, in turn, is important for GPU-based training): splitting up
the utterance into pieces is only possible if we know where the transcript aligns.
Instead of enforcing a particular pronunciation of the training data, we use as
our reference a lattice of alternative pronunciations of the training data,
generated by a lattice-generating decoding procedure using an
utterance-specific graph as the decoding graph. This generates all alignments
of pronunciations that were within a beam of the best-scoring pronunciation.
\subsubsection chain_training_numerator_splitting Splitting the numerator FSTs
As mentioned, we train on fixed sized pieces of utterances (e.g. 1.5 seconds in
length). This requires that we split up the numerator FSTs into fixed-size
pieces. This isn't hard, since the numerator FSTs (which, remember, encode
time-alignment information), naturally have a structure where we can identify
any FST state with a particular frame index. Note: at the stage where we do this
splitting, there are no costs in the numerator FST yet-- it's just viewed as
encoding a constraint on paths-- so we do not have to make a decision how to split up the costs
on the paths.
\subsubsection chain_training_numerator_normalization Normalizing the numerator FSTs
Above (\ref chain_training_denominator_normalization) we mentioned how we compute
initial and final probabilities for the denominator FST, and how we encode
these in a 'normalization FST'. We compose the split-up pieces of numerator
FST with this 'normalization FST' to ensure that the costs from the
denominator FST are reflected in the numerator FST. This ensures that
objective functions can never be positive (which makes them easier to
interpret), and also guards against the possibility that the numerator FST
could contain state sequences not allowed by the denominator FST, which in
principle could allow the objective function to increase without bound. The
reason why this could happen is that the phone LM lacks smoothing, and is
estimated from 1-best alignments, so the lattices could contain phone n-gram
sequences not seen in training.
It happens occasionally (but very rarely) that this normalization process
generates an empty FST: this can occur when the lattice contains triphones that
were not present in the 1-best alignment used to train the phone language
model, and does not have any alternative paths at that point in the lattice
that could make up for the resulting 'failed' paths. This can happen because
the 1-best alignment and the lattice-producing alignment chose different
pronunciations of a word. These pieces of utterances are just discarded.
\subsubsection chain_training_numerator_format Format of the numerator FSTs
The numerator FSTs are weighted acceptors where the labels correspond to
pdf-ids plus one. We can't use pdf-ids, because they could be zero; and zero
is treated specially (as epsilon) by OpenFst. When we form minibatches, instead
of storing an array of separate numerator FSTs we actually append them together to form a longer FST;
this enables us to do a single forward-backward over all utterances in the minibatch,
which directly computes the total numerator log-probability. (This isn't an important
feature, it's just a software detail, which we explain here lest it generate confusion).
\subsection chain_training_splitting Fixed-length chunks, and minibatches
In order to train on minibatches, we split up our utterances into fixed-length
chunks of speech (of length 1.5 seconds in our current scripts). Utterances
shorter than this are discarded; longer ones are split into chunks with
either overlaps between the chunks, or small gaps between the chunks. Note that
our acoustic models typically require left or right frames for acoustic
context; we add that, but this is a separate issue; the context is added after
the chunks are decided on.
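The sketch below shows one hypothetical way to place fixed-length chunks so that they either overlap or leave small gaps, assuming 100 input frames per second (so 1.5 seconds is 150 frames). The placement rule (round to the nearest number of chunks, then spread the chunks evenly) is an assumption for illustration and is not taken from the Kaldi scripts.

```python
def chunk_start_frames(num_frames, chunk_frames=150):
    """Return start frames of fixed-length chunks, or [] for short utterances."""
    if num_frames < chunk_frames:
        return []  # utterances shorter than one chunk are discarded
    # Rounding to the nearest chunk count gives overlaps when we round up
    # and gaps when we round down.
    num_chunks = max(1, round(num_frames / chunk_frames))
    if num_chunks == 1:
        return [(num_frames - chunk_frames) // 2]  # centre the single chunk
    last_start = num_frames - chunk_frames
    return [round(i * last_start / (num_chunks - 1)) for i in range(num_chunks)]

overlapping = chunk_start_frames(400)  # 3 chunks, neighbours overlap by 25 frames
gapped = chunk_start_frames(340)       # 2 chunks with a 40-frame gap between them
too_short = chunk_start_frames(120)    # discarded
```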
Our minibatch size is usually a power of 2, and it can be limited by GPU
memory considerations. Many of our example scripts use 128 chunks per
minibatch. The largest single consumer of GPU memory is the alpha
probabilities in the forward-backward computation. For instance, with a 1.5
second chunk, we have 50 time steps after the 3-fold subsampling. In our
Switchboard setup a typical denominator FST has 30,000 states in it. We use
single-precision floating point for the alphas, so the memory used in
gigabytes is (128 * 50 * 30000 * 4) / 10^9 = 0.768G.
This won't use up all the GPU memory, but there are other sources of memory,
e.g. we keep around two copies of the nnet outputs in memory, which takes a
fair amount of memory depending on the configuration-- e.g. replace the 30000
above with about 10000 and it will give you the amount of memory used for one
copy of the nnet outputs in a reasonable configuration.
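The arithmetic above can be reproduced with a small helper; the function name is hypothetical, and the numbers are the ones quoted in the text (128 chunks, 50 time steps, 30000 states or about 10000 nnet outputs, 4 bytes per value).

```python
def minibatch_memory_gb(chunks_per_minibatch, frames_per_chunk, values_per_frame,
                        bytes_per_value=4):
    """Memory in gigabytes (10^9 bytes) for one value per (chunk, frame, index)."""
    total_bytes = (chunks_per_minibatch * frames_per_chunk
                   * values_per_frame * bytes_per_value)
    return total_bytes / 1e9

alpha_gb = minibatch_memory_gb(128, 50, 30000)        # alpha probabilities
nnet_output_gb = minibatch_memory_gb(128, 50, 10000)  # one copy of the nnet outputs
```

With these inputs the alphas take 0.768G and one copy of the nnet outputs takes 0.256G.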
\subsection chain_training_shifting Training on frame-shifted data
In neural net training we already have ways of generating perturbed data to
artificially increase the amount of data we train on. Our standard nnet3
neural-net training example scripts do time-warping of the raw audio, by
factors of 0.9, 1.0 and 1.1, to create 3-fold augmented data. This is
orthogonal to the 'chain' models, and we do it (or not) just as we would for
the baseline. However, there is an extra way we can augment the data for the
chain models, by shifting the frames. The output frame rate for these models
is one third the regular frame rate (configurable, of course), meaning we only
evaluate nnet output at <code>t</code> values that are multiples of 3, so we
can generate different versions of the training data by shifting the training
examples by 0, 1 and 2 frames. This is done automatically in the training
script, and it's done 'on the fly' as we read the training examples from
disk-- the program <code>nnet3-chain-copy-egs</code> has a
<code>--frame-shift</code> option that is set by the script. This affects how
the number of epochs is interpreted. If the user requests, for instance, 4
epochs, then we actually train for 12 epochs; we just do so on 3
differently-shifted versions of the data. What the
<code>--frame-shift=t</code> option actually does is to shift the input frames
by <code>t</code> and shift the output frames by the closest multiple of 3 to
<code>t</code>. (In general it might not be 3, it's a configuration variable
named <code>--frame-subsampling-factor</code>).
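The rule just described can be written out as a short sketch. The function names are hypothetical and this is not the code of <code>nnet3-chain-copy-egs</code>; it only illustrates the rounding to the nearest multiple of the subsampling factor and the effect on the number of epochs.

```python
def frame_shift_offsets(t, frame_subsampling_factor=3):
    """Return (input_shift, output_shift) for a frame shift of t input frames."""
    f = frame_subsampling_factor
    # Nearest multiple of f; for even f an exact tie rounds up (an assumption).
    output_shift = f * ((t + f // 2) // f)
    return t, output_shift

def epochs_actually_trained(requested_epochs, frame_subsampling_factor=3):
    """Each requested epoch is repeated once per distinct frame shift."""
    return requested_epochs * frame_subsampling_factor

offsets = [frame_shift_offsets(t) for t in range(3)]
total_epochs = epochs_actually_trained(4)
```

For shifts of 0, 1 and 2 with a factor of 3, the output shifts are 0, 0 and 3 respectively, and a request for 4 epochs becomes 12.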
\subsection chain_training_gpu GPU issues in training
The parts of the computation that are specific to the 'chain' computation are
the forward-backward over the numerator FST and over the denominator HMM. The
numerator part of this is very fast. The denominator forward-backward takes
quite a lot of time, because there can be a lot of arcs in the
denominator FST (e.g. 200,000 arcs and 30,000 states in a typical Switchboard setup).
The time taken can be almost as much as the time taken in the neural-net
parts of the computation. We were quite careful to ensure memory locality.
The next step to further speed this up is probably to implement a pruned
version of the forward-backward computation (like pruned Viterbi, but
computing posteriors). In order to get a speedup we'd have to prune away a
very high percentage of states, because we'd need to make up for the loss of
memory locality that pruning would bring. In our current implementation we are
careful to ensure that a group of GPU threads are all processing the same
HMM-state and time, just from different chunks (we call these different
'sequences' in the code); and we make sure that the memory locations
corresponding to these different sequences are all next to each other in
memory, so the GPU can do 'coalesced memory access'. With state-level
pruning, since the memory access for the different sequences would no longer be
'in sync', we would lose this advantage. It should still be doable to get a
pruned version of the forward-backward algorithm, though.
For speed, we don't use log values in the alpha-beta computation for the
denominator graph. In order to keep all the numerical values in a suitable
range, we multiply all the acoustic probabilities (exponentiated nnet outputs)
on each frame, by an 'arbitrary value' selected to ensure that our alpha scores
stay in a good range. We call this an 'arbitrary value' because the algorithm
is designed so that we could choose any value here, and it would still be
mathematically correct. We designate one HMM state as a 'special state', and
the 'arbitrary value' is chosen as the inverse of that special state's alpha
on the previous frame. This keeps the special state's alpha values close to
one. As the 'special state' we choose a state that has high probability in the
limiting distribution of the HMM, and which can access the majority of states
of the HMM.
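To illustrate why the choice of scaling value does not change the result, here is a minimal CPU sketch of a forward pass in the probability domain. It assumes arc-labelled pdf-ids, final probabilities of 1.0 for every state, and a special state whose alpha is never zero. The function name, the toy graph and the use of Python double precision (rather than single precision on a GPU) are all assumptions for illustration.

```python
import math

def forward_logprob(arcs, init_probs, frame_probs, special_state=None):
    """Total log-probability from a forward pass that avoids per-state logs.

    arcs: list of (src_state, dst_state, pdf_id, transition_prob).
    init_probs: initial probability of each state; final probs are all 1.0.
    frame_probs: per-frame lists of exponentiated nnet outputs, by pdf-id.
    special_state: if given, each frame is scaled by 1 / alpha[special_state]
        from the previous frame; if None, no scaling is applied.
    """
    alpha = list(init_probs)
    log_correction = 0.0
    for probs in frame_probs:
        scale = 1.0 if special_state is None else 1.0 / alpha[special_state]
        log_correction -= math.log(scale)  # undo the scaling in the final total
        nxt = [0.0] * len(alpha)
        for src, dst, pdf, trans in arcs:
            nxt[dst] += alpha[src] * trans * probs[pdf] * scale
        alpha = nxt
    return math.log(sum(alpha)) + log_correction

# Toy 2-state graph with made-up probabilities.
arcs = [(0, 0, 0, 0.5), (0, 1, 1, 0.5), (1, 0, 0, 0.3), (1, 1, 1, 0.7)]
init = [0.6, 0.4]
short = [[0.02, 0.01]] * 20
unscaled = forward_logprob(arcs, init, short)
scaled = forward_logprob(arcs, init, short, special_state=0)
long_scaled = forward_logprob(arcs, init, [[0.02, 0.01]] * 2000, special_state=0)
```

On the short input the scaled and unscaled totals agree up to rounding. On the 2000-frame input only the scaled version is usable, since the unscaled alphas would underflow to zero.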
\section chain_decoding Decoding with 'chain' models
The decoding process with 'chain' models is exactly the same as for regular nnet3
neural-net based models, and in fact uses the same script (steps/nnet3/decode.sh).
There are a few configuration differences:
- Firstly, the graph is built with a different and simpler topology; but this requires
no special action by the user, as the graph-building script anyway takes the
topology from the 'final.mdl' produced by the 'chain' training script, which
contains the correct topology.
- By default when we compile the graph, we use a 'self-loop-scale' of 0.1.
This affects how the transition probabilities on self-loops are treated
(it generally works better). However, for the 'chain' models, because of
how they were trained, we need to use exactly the same
transition-probability scaling we trained with, which for simplicity we
have set to 1.0. So we supply the option <code>--self-loop-scale
1.0</code> to the <code>utils/mkgraph.sh</code> script.
- There is no 'division by the prior' necessary in these models. So we simply
don't set the vector of priors in the <code>.mdl</code> files; we made sure
that the decoder just omits the division by the prior if the priors are not set.
- The default acoustic scale we typically use in decoding (0.1) is not
suitable-- for 'chain' models the optimal acoustic scale is very close to 1.
So we supply the option <code>--acwt 1.0</code> to the script
<code>steps/nnet3/decode.sh</code>.
- The scoring scripts can only search the language-model scale in increments
of 1, which works well in typical setups where the optimal language model scale
is between 10 and 15, but not when the optimal language-model scale is close
to 1 as it is here. (Note: for current purposes you can treat the language-model
scale as the same as the inverse of the acoustic scale). In order to
work around this issue without changing the scoring scripts (which are
database-specific), we supply a new option <code>--post-decode-acwt 10.0</code>
to the script <code>steps/nnet3/decode.sh</code>,
which scales the acoustic probabilities by 10 before dumping the lattice.
After this, the optimal language-model scale will be around 10, which might
be a little confusing if you are not aware of this issue, but is convenient
for the way the scoring scripts are set up.
- The default decoding and lattice beams are suitable without modification
for the 'chain' models, once you use the <code>--acwt 1.0</code> option.
However, they won't show the full possible speedup and you can get faster
decoding by using slightly tighter beams. By tightening the beam in the
Switchboard setup we were able to get decoding time down from around 1.5
times real time to around 0.5 times real time, with only around 0.2\%
degradation in accuracy (this was with neural net evaluation on the CPU; on
the GPU it would have been even faster). Note from Dan: this is all to the best
of my recollection as I write this; actually the degradation may have been more than
that. And bear in mind that this was on high-powered modern server machines
(single-threaded).
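The effect of <code>--post-decode-acwt 10.0</code> on scoring can be illustrated with a simplified model in which a path is ranked by its graph cost plus its acoustic cost divided by the language-model weight. That ranking rule, the function names and the toy costs are assumptions made for illustration, not a description of any particular scoring script.

```python
def path_cost(graph_cost, acoustic_cost, lmwt):
    """Cost used to rank a path when scoring with language-model weight lmwt."""
    return graph_cost + acoustic_cost / lmwt

def apply_post_decode_acwt(paths, post_decode_acwt):
    """Scale the acoustic cost of each (graph_cost, acoustic_cost) pair."""
    return [(g, a * post_decode_acwt) for g, a in paths]

# Hypothetical (graph_cost, acoustic_cost) pairs as seen with --acwt 1.0.
paths = [(12.0, 30.0), (9.0, 34.5), (15.0, 28.0)]
dumped = apply_post_decode_acwt(paths, 10.0)

best_in_decoder = min(range(len(paths)), key=lambda i: path_cost(*paths[i], 1))
best_in_scoring = min(range(len(dumped)), key=lambda i: path_cost(*dumped[i], 10))

# Integer weights 9, 10 and 11 now correspond to effective scales near 1.
effective_scales = [lmwt / 10.0 for lmwt in (9, 10, 11)]
```

Under this model, scoring the dumped lattice with a language-model weight of 10 picks the same path as the decoder did at an acoustic scale of 1.0, and the integer search grid now covers the region around the optimum.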
You might notice in the current example scripts that we use iVectors. We do so
just because they generally help a bit, and because the baseline setup we were
comparing with uses them. There is no inherent connection with 'chain'
models, and no fundamental requirement to use them.
*/
}