// nnet2/train-nnet-ensemble.cc

// Copyright 2012  Johns Hopkins University (author: Daniel Povey)
//           2014  Xiaohui Zhang

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS
// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY
// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "nnet2/train-nnet-ensemble.h"
#include <numeric> // for std::accumulate
namespace kaldi {
namespace nnet2 {

static inline Int32Pair MakePair(int32 first, int32 second) {
  Int32Pair ans;
  ans.first = first;
  ans.second = second;
  return ans;
}

NnetEnsembleTrainer::NnetEnsembleTrainer(
    const NnetEnsembleTrainerConfig &config,
    std::vector<Nnet*> nnet_ensemble):
    config_(config), nnet_ensemble_(nnet_ensemble) {
  beta_ = config_.beta;
  num_phases_ = 0;
  bool first_time = true;
  BeginNewPhase(first_time);
}

void NnetEnsembleTrainer::TrainOnExample(const NnetExample &value) {
  buffer_.push_back(value);
  if (static_cast<int32>(buffer_.size()) == config_.minibatch_size)
    TrainOneMinibatch();
}
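
// Trains all nets in the ensemble on one buffered minibatch.  The steps are:
//   1) propagate the minibatch through every net and sum the outputs;
//   2) turn that sum into an average posterior, scale it by beta_, and add the
//      original supervision labels, giving interpolated "soft" targets;
//   3) backprop each net against those targets, and accumulate the
//      log-probability of the original labels for diagnostics.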
void NnetEnsembleTrainer::TrainOneMinibatch() {
  KALDI_ASSERT(!buffer_.empty());
  int32 num_states = nnet_ensemble_[0]->GetComponent(nnet_ensemble_[0]->NumComponents() - 1).OutputDim();
  // average of posteriors matrix, storing averaged outputs of net ensemble.
  CuMatrix<BaseFloat> post_avg(buffer_.size(), num_states);
  // Delete the updaters left over from the previous minibatch before creating
  // new ones below; otherwise the vector grows (and leaks) every minibatch.
  for (size_t i = 0; i < updater_ensemble_.size(); i++)
    delete updater_ensemble_[i];
  updater_ensemble_.clear();
  updater_ensemble_.reserve(nnet_ensemble_.size());
  std::vector<CuMatrix<BaseFloat> > post_mat;
  post_mat.resize(nnet_ensemble_.size());
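
  // Propagate the minibatch through each net in the ensemble; each NnetUpdater
  // does the forward pass here and the corresponding backprop further below.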
  for (int32 i = 0; i < nnet_ensemble_.size(); i++) {
    updater_ensemble_.push_back(new NnetUpdater(*(nnet_ensemble_[i]), nnet_ensemble_[i]));
    updater_ensemble_[i]->FormatInput(buffer_);
    updater_ensemble_[i]->Propagate();
    // posterior matrix, storing output of one net.
    updater_ensemble_[i]->GetOutput(&post_mat[i]);
    CuVector<BaseFloat> row_sum(post_mat[i].NumRows());
    post_avg.AddMat(1.0, post_mat[i]);
  }
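
  // post_avg now holds the element-wise sum of the ensemble outputs; the
  // averaging and beta_ scaling happen after the labels are collected below.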

  // calculate the interpolated posteriors as new supervision labels, and also
  // collect the indices of the original supervision labels for later use (computing the objf).
std::vector<MatrixElement<BaseFloat> > sv_labels;
std::vector<Int32Pair > sv_labels_ind;
sv_labels.reserve(buffer_.size()); // We must have at least this many labels.
sv_labels_ind.reserve(buffer_.size()); // We must have at least this many labels.
for (int32 m = 0; m < buffer_.size(); m++) {
KALDI_ASSERT(buffer_[m].labels.size() == 1 &&
"Currently this code only supports single-frame egs.");
const std::vector<std::pair<int32,BaseFloat> > &labels = buffer_[m].labels[0];
for (size_t i = 0; i < labels.size(); i++) {
MatrixElement<BaseFloat>
tmp = {m, labels[i].first, labels[i].second};
sv_labels.push_back(tmp);
sv_labels_ind.push_back(MakePair(m, labels[i].first));
}
}
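
  // Form the interpolated targets: average the summed outputs over the nets,
  // scale by beta_, and add the original label weights on top, so post_avg
  // then holds (original labels + beta_ * average ensemble posterior) per frame.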
  post_avg.Scale(1.0 / nnet_ensemble_.size());
  post_avg.Scale(beta_);
  post_avg.AddElements(1.0, sv_labels);

  // calculate the deriv, do backprop, and calculate the objf.
  for (int32 i = 0; i < nnet_ensemble_.size(); i++) {
    CuMatrix<BaseFloat> tmp_deriv(post_mat[i]);
    post_mat[i].ApplyLog();
    std::vector<BaseFloat> log_post_correct;
    log_post_correct.resize(sv_labels_ind.size());
    post_mat[i].Lookup(sv_labels_ind, &(log_post_correct[0]));
    BaseFloat log_prob_this_net = std::accumulate(log_post_correct.begin(),
                                                  log_post_correct.end(),
                                                  static_cast<BaseFloat>(0));
    avg_logprob_this_phase_ += log_prob_this_net;
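    // This corresponds to maximizing sum_{t,j} target(t,j) * log(output(t,j)):
    // its derivative w.r.t. output(t,j) is target(t,j) / output(t,j), computed
    // below as the element-wise reciprocal of the output times post_avg.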
    tmp_deriv.InvertElements();
    tmp_deriv.MulElements(post_avg);
    updater_ensemble_[i]->Backprop(&tmp_deriv);
  }

  count_this_phase_ += buffer_.size();
  buffer_.clear();
  minibatches_seen_this_phase_++;
  if (minibatches_seen_this_phase_ == config_.minibatches_per_phase) {
    avg_logprob_this_phase_ /= static_cast<BaseFloat>(nnet_ensemble_.size());
    bool first_time = false;
    BeginNewPhase(first_time);
  }
}

void NnetEnsembleTrainer::BeginNewPhase(bool first_time) {
  if (!first_time)
    KALDI_LOG << "Average log-probability of the supervision labels under the ensemble output is "
              << (avg_logprob_this_phase_ / count_this_phase_) << " per frame, over "
              << count_this_phase_ << " frames in this phase.";
  avg_logprob_this_phase_ = 0.0;
  count_this_phase_ = 0.0;
  minibatches_seen_this_phase_ = 0;
  num_phases_++;
}
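
// The destructor flushes any examples still in the buffer as a final, smaller
// minibatch, so no training data is silently dropped.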
NnetEnsembleTrainer::~NnetEnsembleTrainer() {
  if (!buffer_.empty()) {
    KALDI_LOG << "Doing partial minibatch of size "
              << buffer_.size();
    TrainOneMinibatch();
    if (minibatches_seen_this_phase_ != 0) {
      bool first_time = false;
      BeginNewPhase(first_time);
    }
  }
  // Free the updaters created for the last minibatch.
  for (size_t i = 0; i < updater_ensemble_.size(); i++)
    delete updater_ensemble_[i];
  updater_ensemble_.clear();
}
} // namespace nnet2
} // namespace kaldi