cublas-wrappers.h 14.2 KB
// cudamatrix/cublas-wrappers.h

// Copyright 2013  Johns Hopkins University (author: Daniel Povey);
//           2017  Shiyin Kang

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

//  http://www.apache.org/licenses/LICENSE-2.0

// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_CUDAMATRIX_CUBLAS_WRAPPERS_H_
#define KALDI_CUDAMATRIX_CUBLAS_WRAPPERS_H_ 1

// Do not include this file directly.  It is to be included
// by .cc files in this directory.

namespace kaldi {
#if HAVE_CUDA == 1

inline cublasStatus_t cublas_gemm(
    cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n,int k, float alpha,
    const float *A, int lda, const float *B, int ldb, float beta,
    float *C, int ldc) {
  return cublasSgemm_v2(handle,transa,transb,m,n,k,&alpha,A,lda,B,ldb,&beta,C,ldc);
}
inline cublasStatus_t cublas_gemm(
    cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n,int k, double alpha,
    const double *A, int lda, const double *B, int ldb, double beta,
    double *C, int ldc) {
  return cublasDgemm_v2(handle,transa,transb,m,n,k,&alpha,A,lda,B,ldb,&beta,C,ldc);
}
inline cublasStatus_t cublas_ger(
    cublasHandle_t handle, int m, int n, float alpha,
    const float *x, int incx, const float *y, int incy, float *A, int lda ) {
  return cublasSger_v2(handle,m,n,&alpha,x,incx,y,incy,A,lda);
}
inline cublasStatus_t cublas_ger(cublasHandle_t handle, int m, int n, double alpha,
        const double *x, int incx, const double *y, int incy, double *A, int lda ) {
  return cublasDger_v2(handle,m,n,&alpha,x,incx,y,incy,A,lda);
}
inline cublasStatus_t cublas_gemmBatched(
    cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n, int k, float alpha,
    const float *A[], int lda, const float *B[], int ldb, float beta,
    float *C[], int ldc, int batchCount) {
  return cublasSgemmBatched(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc, batchCount);
}
inline cublasStatus_t cublas_gemmBatched(
    cublasHandle_t handle, cublasOperation_t transa,
    cublasOperation_t transb, int m, int n, int k, double alpha,
    const double *A[], int lda, const double *B[], int ldb, double beta,
    double *C[], int ldc, int batchCount) {
  return cublasDgemmBatched(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc, batchCount);
}
inline cublasStatus_t cublas_trsm(cublasHandle_t handle, int m, int n,
                                  float alpha, const float* A, int lda,
                                  float* B, int ldb) {
  return cublasStrsm_v2(handle,CUBLAS_SIDE_LEFT,CUBLAS_FILL_MODE_UPPER,CUBLAS_OP_N,CUBLAS_DIAG_NON_UNIT,m,n,&alpha,A,lda,B,ldb);
}
inline cublasStatus_t cublas_trsm(cublasHandle_t handle, int m, int n,
                                  double alpha, const double* A, int lda,
                                  double* B, int ldb) {
  return cublasDtrsm_v2(handle,CUBLAS_SIDE_LEFT,CUBLAS_FILL_MODE_UPPER,CUBLAS_OP_N,CUBLAS_DIAG_NON_UNIT,m,n,&alpha,A,lda,B,ldb);
}
inline cublasStatus_t cublas_syrk(
    cublasHandle_t handle, cublasFillMode_t uplo,
    cublasOperation_t trans, int n, int k, float alpha,
    const float *A, int lda, float beta, float *C, int ldc) {
  return cublasSsyrk_v2(handle,uplo,trans,n,k,&alpha,A,lda,&beta,C,ldc);
}
inline cublasStatus_t cublas_syrk(
    cublasHandle_t handle, cublasFillMode_t uplo,
    cublasOperation_t trans, int n, int k, double alpha,
    const double *A, int lda, double beta, double *C, int ldc) {
  return cublasDsyrk_v2(handle,uplo,trans,n,k,&alpha,A,lda,&beta,C,ldc);
}
inline cublasStatus_t cublas_dot(cublasHandle_t handle, int n, const float *x,
                                 int incx, const float *y, int incy,
                                 float *result) {
  return cublasSdot_v2(handle, n, x, incx, y, incy, result);
}
inline cublasStatus_t cublas_dot(cublasHandle_t handle, int n, const double *x,
                                 int incx, const double *y, int incy,
                                 double *result) {
  return cublasDdot_v2(handle, n, x, incx, y, incy, result);
}
inline cublasStatus_t cublas_asum(cublasHandle_t handle, int n, const float* x,
                                  int incx, float *result) {
  return cublasSasum_v2(handle, n, x, incx, result);
}
inline cublasStatus_t cublas_asum(cublasHandle_t handle, int n, const double* x,
                                  int incx, double *result) {
  return cublasDasum_v2(handle, n, x, incx, result);
}
inline cublasStatus_t cublas_nrm2(cublasHandle_t handle, int n, const float* x,
                                  int incx, float *result) {
  return cublasSnrm2_v2(handle, n, x, incx, result);
}
inline cublasStatus_t cublas_nrm2(cublasHandle_t handle, int n, const double* x,
                                  int incx, double *result) {
  return cublasDnrm2_v2(handle, n, x, incx, result);
}
inline cudaError_t cublas_copy(cublasHandle_t handle, int n, const float* x,
    int incx, double* y, int incy) {
  int dimBlock(CU1DBLOCK);
  int dimGrid(n_blocks(n, CU1DBLOCK));
  cublas_copy_kaldi_fd(dimGrid, dimBlock, n, x, incx, y, incy);
  return cudaGetLastError();
}
inline cudaError_t cublas_copy(cublasHandle_t handle, int n, const double* x,
    int incx, float* y, int incy) {
  int dimBlock(CU1DBLOCK);
  int dimGrid(n_blocks(n, CU1DBLOCK));
  cublas_copy_kaldi_df(dimGrid, dimBlock, n, x, incx, y, incy);
  return cudaGetLastError();
}
inline cublasStatus_t cublas_copy(cublasHandle_t handle, int n, const float* x,
                                  int incx, float* y, int incy) {
  return cublasScopy_v2(handle,n,x,incx,y,incy);
}
inline cublasStatus_t cublas_copy(cublasHandle_t handle, int n, const double* x,
                                  int incx, double* y, int incy) {
  return cublasDcopy_v2(handle,n,x,incx,y,incy);
}
inline cublasStatus_t cublas_scal(cublasHandle_t handle, int n, float alpha,
                                  float* mat, int incx) {
  return cublasSscal_v2(handle, n, &alpha, mat, incx);
}
inline cublasStatus_t cublas_scal(cublasHandle_t handle, int n, double alpha,
                                  double* mat, int incx) {
  return cublasDscal_v2(handle, n, &alpha, mat, incx);
}

inline cublasStatus_t cublas_axpy(cublasHandle_t handle, int n, float alpha,
                                  const float* x, int incx, float* y, int incy) {
  return cublasSaxpy_v2(handle, n, &alpha, x, incx, y, incy);
}
inline cublasStatus_t cublas_axpy(cublasHandle_t handle, int n, double alpha,
                                  const double* x, int incx, double* y, int incy) {
  return cublasDaxpy_v2(handle, n, &alpha, x, incx, y, incy);
}
inline cublasStatus_t cublas_gemv(
    cublasHandle_t handle, cublasOperation_t trans,
    int m, int n, float alpha, const float* A, int lda, const float* x,
    int incx, float beta, float* y, int incy) {
  return cublasSgemv_v2(handle,trans,m,n,&alpha,A,lda,x,incx,&beta,y,incy);
}
inline cublasStatus_t cublas_gemv(
    cublasHandle_t handle, cublasOperation_t trans,
    int m, int n, double alpha, const double* A, int lda, const double* x,
    int incx, double beta, double* y, int incy) {
  return cublasDgemv_v2(handle,trans,m,n,&alpha,A,lda,x,incx,&beta,y,incy);
}

inline cublasStatus_t cublas_spmv(
    cublasHandle_t handle, cublasFillMode_t uplo,
    int n, float alpha, const float *AP, const float *x, int incx,
    float beta, float *y, int incy) {
  return cublasSspmv_v2(handle, uplo, n, &alpha, AP, x, incx, &beta, y, incy);
}
inline cublasStatus_t cublas_spmv(
    cublasHandle_t handle, cublasFillMode_t uplo,
    int n, double alpha, const double *AP, const double *x, int incx,
    double beta, double *y, int incy) {
  return cublasDspmv_v2(handle, uplo, n, &alpha, AP, x, incx, &beta, y, incy);
}

// Use caution with these, the 'transpose' argument is the opposite of what it
// should really be, due to CUDA storing things in column major order.  We also
// had to switch 'l' to 'u'; we view our packed matrices as lower-triangular,
// row-by-row, but CUDA views the same layout as upper-triangular,
// column-by-column.
inline cublasStatus_t cublas_tpmv(cublasHandle_t handle, cublasOperation_t trans,
                                  int n, const float* Ap, float* x, int incx) {
  return cublasStpmv_v2(handle, CUBLAS_FILL_MODE_UPPER, trans, CUBLAS_DIAG_NON_UNIT, n, Ap, x, incx);
}
inline cublasStatus_t cublas_tpmv(cublasHandle_t handle, cublasOperation_t trans,
                                  int n, const double* Ap, double* x,int incx) {
  return cublasDtpmv_v2(handle, CUBLAS_FILL_MODE_UPPER, trans, CUBLAS_DIAG_NON_UNIT, n, Ap, x, incx);
}

inline cublasStatus_t cublas_spr(cublasHandle_t handle, cublasFillMode_t uplo,
                                 int n, float alpha, const float *x, int incx,
                                 float *AP) {
  return cublasSspr_v2(handle, uplo, n, &alpha, x, incx, AP);
}
inline cublasStatus_t cublas_spr(cublasHandle_t handle, cublasFillMode_t uplo,
                                 int n, double alpha, const double *x, int incx,
                                 double *AP) {
  return cublasDspr_v2(handle, uplo, n, &alpha, x, incx, AP);
}

//
// cuSPARSE wrappers
//

inline cusparseStatus_t cusparse_csr2csc(cusparseHandle_t handle, int m, int n,
                                         int nnz, const float *csrVal,
                                         const int *csrRowPtr,
                                         const int *csrColInd, float *cscVal,
                                         int *cscRowInd, int *cscColPtr,
                                         cusparseAction_t copyValues,
                                         cusparseIndexBase_t idxBase) {
  return cusparseScsr2csc(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd,
                          cscVal, cscRowInd, cscColPtr, copyValues, idxBase);
}
inline cusparseStatus_t cusparse_csr2csc(cusparseHandle_t handle, int m, int n,
                                         int nnz, const double *csrVal,
                                         const int *csrRowPtr,
                                         const int *csrColInd, double *cscVal,
                                         int *cscRowInd, int *cscColPtr,
                                         cusparseAction_t copyValues,
                                         cusparseIndexBase_t idxBase) {
  return cusparseDcsr2csc(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd,
                          cscVal, cscRowInd, cscColPtr, copyValues, idxBase);
}

inline cusparseStatus_t cusparse_csrmm(cusparseHandle_t handle,
                                       cusparseOperation_t transA, int m, int n,
                                       int k, int nnz, const float *alpha,
                                       const cusparseMatDescr_t descrA,
                                       const float *csrValA,
                                       const int *csrRowPtrA,
                                       const int *csrColIndA, const float *B,
                                       int ldb, const float *beta, float *C,
                                       int ldc) {
  return cusparseScsrmm(handle, transA, m, n, k, nnz, alpha, descrA, csrValA,
                        csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
}
inline cusparseStatus_t cusparse_csrmm(cusparseHandle_t handle,
                                       cusparseOperation_t transA, int m, int n,
                                       int k, int nnz, const double *alpha,
                                       const cusparseMatDescr_t descrA,
                                       const double *csrValA,
                                       const int *csrRowPtrA,
                                       const int *csrColIndA, const double *B,
                                       int ldb, const double *beta, double *C,
                                       int ldc) {
  return cusparseDcsrmm(handle, transA, m, n, k, nnz, alpha, descrA, csrValA,
                        csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
}

inline cusparseStatus_t cusparse_csrmm2(cusparseHandle_t handle,
                                        cusparseOperation_t transA,
                                        cusparseOperation_t transB, int m,
                                        int n, int k, int nnz,
                                        const float *alpha,
                                        const cusparseMatDescr_t descrA,
                                        const float *csrValA,
                                        const int *csrRowPtrA,
                                        const int *csrColIndA, const float *B,
                                        int ldb, const float *beta, float *C,
                                        int ldc) {
  return cusparseScsrmm2(handle, transA, transB, m, n, k, nnz, alpha, descrA,
                         csrValA, csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
}
inline cusparseStatus_t cusparse_csrmm2(cusparseHandle_t handle,
                                        cusparseOperation_t transA,
                                        cusparseOperation_t transB, int m,
                                        int n, int k, int nnz,
                                        const double *alpha,
                                        const cusparseMatDescr_t descrA,
                                        const double *csrValA,
                                        const int *csrRowPtrA,
                                        const int *csrColIndA, const double *B,
                                        int ldb, const double *beta, double *C,
                                        int ldc) {
  return cusparseDcsrmm2(handle, transA, transB, m, n, k, nnz, alpha, descrA,
                         csrValA, csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
}


#endif
}
// namespace kaldi

#endif