// nnet/nnet-kl-hmm.h // Copyright 2013 Idiap Research Institute (Author: David Imseng) // Karlsruhe Institute of Technology (Author: Ngoc Thang Vu) // Brno University of Technology (Author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_NNET_NNET_KL_HMM_H_ #define KALDI_NNET_NNET_KL_HMM_H_ #include #include "nnet/nnet-component.h" #include "cudamatrix/cu-math.h" #include "cudamatrix/cu-rand.h" #include "matrix/kaldi-vector.h" #include "matrix/kaldi-matrix.h" namespace kaldi { namespace nnet1 { class KlHmm : public Component { public: KlHmm(int32 dim_in, int32 dim_out): Component(dim_in, dim_out), kl_stats_(dim_out, dim_in, kSetZero) { } ~KlHmm() { } Component* Copy() const { return new KlHmm(*this); } ComponentType GetType() const { return kKlHmm; } void PropagateFnc(const CuMatrixBase &in, CuMatrixBase *out) { if (kl_inv_q_.NumRows() == 0) { // Copy the CudaMatrix to a Matrix Matrix in_tmp(in.NumRows(), in.NumCols()); in.CopyToMat(&in_tmp); // Check if there are posteriors in the Matrix (check on first row), BaseFloat post_sum = in_tmp.Row(0).Sum(); KALDI_ASSERT(ApproxEqual(post_sum, 1.0)); // Get a tmp Matrix of the stats Matrix kl_stats_tmp(kl_stats_); // Init a vector to get the sum of the rows (for normalization) Vector row_sum(kl_stats_.NumRows(), kSetZero); // Get the sum of the posteriors for normalization row_sum.AddColSumMat(1, kl_stats_tmp); // Apply floor to make sure there is no zero row_sum.ApplyFloor(1e-20); // Invert the sum (to normalize) row_sum.InvertElements(); // Normalizing the statistics vector kl_stats_tmp.MulRowsVec(row_sum); // Apply floor before inversion and logarithm kl_stats_tmp.ApplyFloor(1e-20); // Apply invesion kl_stats_tmp.InvertElements(); // Apply logarithm kl_stats_tmp.ApplyLog(); // Inverted and logged values kl_inv_q_.Resize(kl_stats_.NumRows(), kl_stats_.NumCols()); // Holds now log (1/Q) kl_inv_q_.CopyFromMat(kl_stats_tmp); } // Get the logarithm of the features for the Entropy calculation // Copy the CudaMatrix to a Matrix Matrix in_log_tmp(in.NumRows(), in.NumCols()); in.CopyToMat(&in_log_tmp); // Flooring and log in_log_tmp.ApplyFloor(1e-20); in_log_tmp.ApplyLog(); CuMatrix log_in(in.NumRows(), in.NumCols()); log_in.CopyFromMat(in_log_tmp); // P*logP CuMatrix tmp_entropy(in); tmp_entropy.MulElements(log_in); // Getting the entropy (sum P*logP) CuVector in_entropy(in.NumRows(), kSetZero); in_entropy.AddColSumMat(1, tmp_entropy); // sum P*log (1/Q) out->AddMatMat(1, in, kNoTrans, kl_inv_q_, kTrans, 0); // (sum P*logP) + (sum P*log(1/Q) out->AddVecToCols(1, in_entropy); // return the negative KL-divergence out->Scale(-1); } void BackpropagateFnc(const CuMatrixBase &in, const CuMatrixBase &out, const CuMatrixBase &out_diff, CuMatrixBase *in_diff) { KALDI_ERR << "Unimplemented"; } /// Reads the component content void ReadData(std::istream &is, bool binary) { kl_stats_.Read(is, binary); KALDI_ASSERT(kl_stats_.NumRows() == output_dim_); KALDI_ASSERT(kl_stats_.NumCols() == input_dim_); } /// Writes the component content void WriteData(std::ostream &os, bool binary) const { kl_stats_.Write(os, binary); } /// Set the statistics matrix void SetStats(const Matrix mat) { KALDI_ASSERT(mat.NumRows() == output_dim_); KALDI_ASSERT(mat.NumCols() == input_dim_); kl_stats_.Resize(mat.NumRows(), mat.NumCols()); kl_stats_.CopyFromMat(mat); } /// Accumulate the statistics for KL-HMM paramter estimation, void Accumulate(const Matrix &posteriors, const std::vector &alignment) { KALDI_ASSERT(posteriors.NumRows() == alignment.size()); KALDI_ASSERT(posteriors.NumCols() == kl_stats_.NumCols()); int32 num_frames = alignment.size(); for (int32 i = 0; i < num_frames; i++) { // Casting float posterior to double (fixing numerical issue), Vector temp(posteriors.Row(i)); // Sum the postiors grouped by states from the alignment, kl_stats_.Row(alignment[i]).AddVec(1, temp); } } private: Matrix kl_stats_; CuMatrix kl_inv_q_; }; } // namespace nnet1 } // namespace kaldi #endif // KALDI_NNET_NNET_KL_HMM_H_