// cudamatrix/cu-sp-matrix.h // Copyright 2009-2013 Karel Vesely // 2014 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_CUDAMATRIX_CU_SP_MATRIX_H_ #define KALDI_CUDAMATRIX_CU_SP_MATRIX_H_ #include #include "cudamatrix/cu-common.h" #include "matrix/matrix-common.h" #include "matrix/sp-matrix.h" #include "cudamatrix/cu-array.h" #include "cudamatrix/cu-math.h" #include "cudamatrix/cu-packed-matrix.h" #include "cudamatrix/cu-matrix.h" namespace kaldi { /// TraceSpSp returns tr(A B) template Real TraceSpSp(const CuSpMatrix &A, const CuSpMatrix &B); template class CuSpMatrix : public CuPackedMatrix { friend class CuMatrixBase; friend class CuVectorBase; friend class CuTpMatrix; friend class CuSubMatrix; friend class CuRand; template friend R TraceSpSp(const CuSpMatrix &A, const CuSpMatrix &B); public: CuSpMatrix(): CuPackedMatrix() {} explicit CuSpMatrix(MatrixIndexT r, MatrixResizeType resize_type = kSetZero) : CuPackedMatrix(r, resize_type) {} explicit CuSpMatrix(const SpMatrix &orig) : CuPackedMatrix(orig) {} // This constructor lacks the "explicit" keyword so that // we can include it in std::vector. CuSpMatrix(const CuSpMatrix &orig) : CuPackedMatrix(orig) {} explicit CuSpMatrix(const CuMatrixBase &orig, SpCopyType copy_type = kTakeLower) : CuPackedMatrix(orig.NumRows(), kUndefined) { CopyFromMat(orig, copy_type); } CuSpMatrix &operator = (const CuSpMatrix &in); ~CuSpMatrix() {} inline void Resize(MatrixIndexT nRows, MatrixResizeType resize_type = kSetZero) { CuPackedMatrix::Resize(nRows, resize_type); } Real FrobeniusNorm() const { return sqrt(TraceSpSp(*this, *this)); } bool IsUnit(Real tol = 0.001) const; bool ApproxEqual(const CuSpMatrix &other, Real tol = 0.001) const; void CopyFromSp(const CuSpMatrix &other) { CuPackedMatrix::CopyFromPacked(other); } void CopyFromSp(const SpMatrix &other) { CuPackedMatrix::CopyFromPacked(other); } void CopyFromMat(const CuMatrixBase &orig, SpCopyType copy_type = kTakeLower); void CopyToSp(SpMatrix *dst) const { CuPackedMatrix::CopyToPacked(dst); } inline CuValue operator() (MatrixIndexT r, MatrixIndexT c) { if (static_cast(c) > static_cast(r)) std::swap(c, r); KALDI_ASSERT(static_cast(r) < static_cast(this->num_rows_)); return CuValue(this->data_ + (r * (r+1)) / 2 + c); } inline Real operator() (MatrixIndexT r, MatrixIndexT c) const { if (static_cast(c) > static_cast(r)) std::swap(c, r); KALDI_ASSERT(static_cast(r) < static_cast(this->num_rows_)); return CuValue(this->data_ + (r * (r+1)) / 2 + c); // will be // casted to Real. } /// Note: the CuMatrix version of the Invert() function will only work for /// positive definite matrices; it is based on Cholesky. void Invert(); void AddVec2(const Real alpha, const CuVectorBase &v); void AddMat2(const Real alpha, const CuMatrixBase &M, MatrixTransposeType transM, const Real beta); void AddSp(const Real alpha, const CuSpMatrix &Ma) { this->AddPacked(alpha, Ma); } protected: inline const SpMatrix &Mat() const { return *(reinterpret_cast* >(this)); } inline SpMatrix &Mat() { return *(reinterpret_cast* >(this)); } }; template inline bool ApproxEqual(const CuSpMatrix &A, const CuSpMatrix &B, Real tol = 0.001) { return A.ApproxEqual(B, tol); } template inline void AssertEqual(const CuSpMatrix &A, const CuSpMatrix &B, Real tol = 0.001) { KALDI_ASSERT(ApproxEqual(A, B, tol)); } template SpMatrix::SpMatrix(const CuSpMatrix &cu) { Resize(cu.NumRows()); cu.CopyToSp(this); } } // namespace #endif