Blame view

src/transform/decodable-am-diag-gmm-regtree.cc 8.98 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
  // transform/decodable-am-diag-gmm-regtree.cc
  
  // Copyright 2009-2011  Saarland University;  Lukas Burget
  //                2013  Johns Hopkins University (author: Daniel Povey)
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  #include <vector>
  using std::vector;
  
  #include "transform/decodable-am-diag-gmm-regtree.h"
  
  namespace kaldi {
  
  
  // Returns the log-likelihood of feature frame `frame` under pdf `state` of
  // the FMLLR-adapted model.  For each Gaussian the feature vector is first
  // transformed with the FMLLR transform of the Gaussian's regression class
  // (as given by the regression tree), scored against the unadapted Gaussian
  // parameters, and the per-regression-class log-determinant term obtained
  // from GetLogDets() is added.  The per-Gaussian terms are then combined
  // with a (pruned) LogSumExp.
  //
  // Caching: the log-determinants are fetched once; the transformed features
  // and their squares are recomputed only when the frame changes; and the
  // result for each state is remembered for the frame it was last computed on.
  BaseFloat DecodableAmDiagGmmRegtreeFmllr::LogLikelihoodZeroBased(int32 frame,
                                                            int32 state) {
    KALDI_ASSERT(frame < NumFramesReady() && frame >= 0);
    KALDI_ASSERT(state < NumIndices() && state >= 0);
  
    // Log-determinants depend only on the transforms, not on frame or state.
    if (!valid_logdets_) {
      logdets_.Resize(fmllr_xform_.NumRegClasses());
      fmllr_xform_.GetLogDets(&logdets_);
      valid_logdets_ = true;
    }
  
    // Already scored this state on this frame?
    if (log_like_cache_[state].hit_time == frame)
      return log_like_cache_[state].log_like;
  
    const DiagGmm &pdf = acoustic_model_.GetPdf(state);
    const VectorBase<BaseFloat> &data = feature_matrix_.Row(frame);
  
    // Sanity checks before scoring.
    if (pdf.Dim() != data.Dim()) {
      KALDI_ERR << "Dim mismatch: data dim = "  << data.Dim()
          << " vs. model dim = " << pdf.Dim();
    }
    if (!pdf.valid_gconsts()) {
      KALDI_ERR << "State "  << (state)  << ": Must call ComputeGconsts() "
          "before computing likelihood.";
    }
  
    // On a new frame, transform the features once per regression class and
    // keep their element-wise squares; all states of this frame reuse them.
    if (frame != previous_frame_) {
      fmllr_xform_.TransformFeature(data, &xformed_data_);
      xformed_data_squared_ = xformed_data_;
      for (size_t r = 0; r < xformed_data_squared_.size(); ++r)
        xformed_data_squared_[r].ApplyPow(2.0);
      previous_frame_ = frame;
    }
  
    // Per-Gaussian log-likelihood, starting from the unadapted gconsts:
    //   gconst + means_invvars . x_r - 0.5 * inv_vars . x_r^2 + logdet_r
    // where r is the regression class of the Gaussian.
    Vector<BaseFloat> loglikes(pdf.gconsts());  // need to recreate for each pdf
    const int32 num_gauss = pdf.NumGauss();
    for (int32 g = 0; g < num_gauss; ++g) {
      const int32 regclass =
          fmllr_xform_.Base2RegClass(regtree_.Gauss2BaseclassId(state, g));
      BaseFloat &loglike = loglikes(g);
      loglike += VecVec(pdf.means_invvars().Row(g), xformed_data_[regclass]);
      loglike -= 0.5 * VecVec(pdf.inv_vars().Row(g),
                              xformed_data_squared_[regclass]);
      loglike += logdets_(regclass);
    }
  
    const BaseFloat log_sum = loglikes.LogSumExp(log_sum_exp_prune_);
    if (KALDI_ISNAN(log_sum) || KALDI_ISINF(log_sum))
      KALDI_ERR << "Invalid answer (overflow or invalid variances/features?)";
  
    // Remember the answer for this (state, frame) pair.
    log_like_cache_[state].log_like = log_sum;
    log_like_cache_[state].hit_time = frame;
    return log_sum;
  }
  
  // Releases the per-pdf matrices of transformed means and vectors of gconsts
  // that GetXformedMeanInvVars() allocated with `new`.
  DecodableAmDiagGmmRegtreeMllr::~DecodableAmDiagGmmRegtreeMllr() {
    DeletePointers(&xformed_mean_invvars_);  // transformed means * inv(vars)
    DeletePointers(&xformed_gconsts_);       // matching gconsts
  }
  
  
  // (Re)initializes the per-pdf caches used by GetXformedMeanInvVars() and
  // GetXformedGconsts(). It frees any previously cached transformed means and
  // gconsts, sizes the caches to the number of pdfs with every entry empty
  // (NULL pointer, "not cached" flag), and resets the log-likelihood cache.
  void DecodableAmDiagGmmRegtreeMllr::InitCache() {
    if (xformed_mean_invvars_.size() != 0)
      DeletePointers(&xformed_mean_invvars_);
    if (xformed_gconsts_.size() != 0)
      DeletePointers(&xformed_gconsts_);
    int32 num_pdfs = acoustic_model_.NumPdfs();
    // Use clear() + resize() rather than a bare resize(). A bare resize() keeps
    // the old values of elements that already exist. On a repeated call, stale
    // "cached" flags would then survive for pointers that were just deleted,
    // and GetXformedMeanInvVars() would dereference a NULL or freed entry.
    xformed_mean_invvars_.clear();
    xformed_mean_invvars_.resize(num_pdfs, NULL);
    xformed_gconsts_.clear();
    xformed_gconsts_.resize(num_pdfs, NULL);
    is_cached_.clear();
    is_cached_.resize(num_pdfs, false);
    ResetLogLikeCache();
  }
  
  
  // This is almost the same code as DiagGmm::ComputeGconsts, except that
  // means are used instead of means * inv(vars). This saves some computation.
  static void ComputeGconsts(const VectorBase<BaseFloat> &weights,
                             const MatrixBase<BaseFloat> &means,
                             const MatrixBase<BaseFloat> &inv_vars,
                             VectorBase<BaseFloat> *gconsts_out) {
    int32 num_gauss = weights.Dim();
    int32 dim = means.NumCols();
    KALDI_ASSERT(means.NumRows() == num_gauss
        && inv_vars.NumRows() == num_gauss && inv_vars.NumCols() == dim);
    KALDI_ASSERT(gconsts_out->Dim() == num_gauss);
  
    BaseFloat offset = -0.5 * M_LOG_2PI * dim;  // constant term in gconst.
    int32 num_bad = 0;
  
    for (int32 gauss = 0; gauss < num_gauss; gauss++) {
      KALDI_ASSERT(weights(gauss) >= 0);  // Cannot have negative weights.
      BaseFloat gc = Log(weights(gauss)) + offset;  // May be -inf if weights == 0
      for (int32 d = 0; d < dim; d++) {
        gc += 0.5 * Log(inv_vars(gauss, d)) - 0.5 * means(gauss, d)
          * means(gauss, d) * inv_vars(gauss, d);  // diff from DiagGmm version.
      }
  
      if (KALDI_ISNAN(gc)) {  // negative infinity is OK but NaN is not acceptable
        KALDI_ERR << "At component "  << gauss
                  << ", not a number in gconst computation";
      }
      if (KALDI_ISINF(gc)) {
        num_bad++;
        // If positive infinity, make it negative infinity.
        // Want to make sure the answer becomes -inf in the end, not NaN.
        if (gc > 0) gc = -gc;
      }
      (*gconsts_out)(gauss) = gc;
    }
    if (num_bad > 0)
      KALDI_WARN << num_bad << " unusable components found while computing "
                 << "gconsts.";
  }
  
  
  // Returns the MLLR-transformed means of pdf `state`, element-wise multiplied
  // by the pdf's inverse variances. On the first request for a state, it
  // transforms the means, derives the matching gconsts (later returned by
  // GetXformedGconsts()), and caches both. Subsequent calls return the cached
  // matrix.
  const Matrix<BaseFloat>& DecodableAmDiagGmmRegtreeMllr::GetXformedMeanInvVars(
      int32 state) {
    if (is_cached_[state]) {
      KALDI_ASSERT(xformed_mean_invvars_[state] != NULL);
      KALDI_VLOG(3) << "For PDF index " << state << ": transformed means "
                    << "found in cache.";
      return *xformed_mean_invvars_[state];
    }
  
    // Cache miss: transform the means and compute the gconsts now.
    KALDI_ASSERT(xformed_mean_invvars_[state] == NULL);
    KALDI_VLOG(3) << "For PDF index " << state << ": transforming means.";
    const DiagGmm &pdf = acoustic_model_.GetPdf(state);
    const int32 num_gauss = pdf.NumGauss(), dim = acoustic_model_.Dim();
    const Vector<BaseFloat> &weights = pdf.weights();
    const Matrix<BaseFloat> &inv_vars = pdf.inv_vars();
  
    // Ownership of both allocations passes to the cache vectors.
    Matrix<BaseFloat> *xformed = new Matrix<BaseFloat>(num_gauss, dim);
    xformed_mean_invvars_[state] = xformed;
    mllr_xform_.GetTransformedMeans(regtree_, acoustic_model_, state, xformed);
  
    Vector<BaseFloat> *gconsts = new Vector<BaseFloat>(num_gauss);
    xformed_gconsts_[state] = gconsts;
    // The gconsts need the plain transformed means, so compute them before
    // the means are scaled by the inverse variances.
    ComputeGconsts(weights, *xformed, inv_vars, gconsts);
    xformed->MulElements(inv_vars);
  
    is_cached_[state] = true;
    return *xformed;
  }
  
  // Returns the gconsts that match the transformed means of pdf `state`.
  // They are produced as a side effect of GetXformedMeanInvVars(), which must
  // already have been called for this state. If it has not, this raises an
  // error.
  const Vector<BaseFloat>& DecodableAmDiagGmmRegtreeMllr::GetXformedGconsts(
      int32 state) {
    const bool computed = is_cached_[state];
    if (!computed) {
      KALDI_ERR << "GConsts not cached for state: " << state << ". Must call "
                << "GetXformedMeanInvVars() first.";
    }
    KALDI_ASSERT(xformed_gconsts_[state] != NULL);
    return *xformed_gconsts_[state];
  }
  
  // Returns the log-likelihood of feature frame `frame` under pdf `state` of
  // the MLLR-adapted model. The transformed means * inv(vars) and gconsts come
  // from the per-pdf cache (filled on first use by GetXformedMeanInvVars()).
  // The inverse variances are the pdf's unadapted ones. The squared features
  // are recomputed only when the frame changes, and the result for each state
  // is remembered for the frame it was last computed on.
  BaseFloat DecodableAmDiagGmmRegtreeMllr::LogLikelihoodZeroBased(int32 frame,
                                                                  int32 state) {
    KALDI_ASSERT(frame < NumFramesReady() && frame >= 0);
    KALDI_ASSERT(state < NumIndices() && state >= 0);
  
    // Already scored this state on this frame?
    if (log_like_cache_[state].hit_time == frame)
      return log_like_cache_[state].log_like;
  
    const DiagGmm &pdf = acoustic_model_.GetPdf(state);
    const VectorBase<BaseFloat> &data = feature_matrix_.Row(frame);
  
    if (pdf.Dim() != data.Dim()) {
      KALDI_ERR << "Dim mismatch: data dim = "  << data.Dim()
          << " vs. model dim = " << pdf.Dim();
    }
  
    // Square the features once per frame; every state of the frame reuses it.
    if (frame != previous_frame_) {
      data_squared_.CopyFromVec(data);
      data_squared_.ApplyPow(2.0);
      previous_frame_ = frame;
    }
  
    // The gconsts are only available after the means have been transformed,
    // so the means must be fetched first.
    const Matrix<BaseFloat> &xformed_means_invvars = GetXformedMeanInvVars(state);
    const Vector<BaseFloat> &xformed_gconsts = GetXformedGconsts(state);
  
    // loglikes = gconsts + (means * inv(vars)) x - 0.5 * inv(vars) x^2
    Vector<BaseFloat> loglikes(xformed_gconsts);  // fresh copy per pdf
    loglikes.AddMatVec(1.0, xformed_means_invvars, kNoTrans, data, 1.0);
    loglikes.AddMatVec(-0.5, pdf.inv_vars(), kNoTrans, data_squared_, 1.0);
  
    const BaseFloat log_sum = loglikes.LogSumExp(log_sum_exp_prune_);
    if (KALDI_ISNAN(log_sum) || KALDI_ISINF(log_sum))
      KALDI_ERR << "Invalid answer (overflow or invalid variances/features?)";
  
    // Remember the answer for this (state, frame) pair.
    log_like_cache_[state].log_like = log_sum;
    log_like_cache_[state].hit_time = frame;
    return log_sum;
  }
  
  }  // namespace kaldi