13 #ifndef QMCPLUSPLUS_CUBLAS_H 14 #define QMCPLUSPLUS_CUBLAS_H 22 #include <cublas_v2.h> 23 #define castNativeType castCUDAType 25 #include <hipblas/hipblas.h> 29 #define castNativeType casthipblasType 34 #define cublasErrorCheck(ans, cause) \ 36 cublasAssert((ans), cause, __FILE__, __LINE__); \ 44 std::string cublas_error;
48 cublas_error =
"CUBLAS_STATUS_NOT_INITIALIZED";
51 cublas_error =
"CUBLAS_STATUS_ALLOC_FAILED";
54 cublas_error =
"CUBLAS_STATUS_INVALID_VALUE";
57 cublas_error =
"CUBLAS_STATUS_ARCH_MISMATCH";
60 cublas_error =
"CUBLAS_STATUS_MAPPING_ERROR";
63 cublas_error =
"CUBLAS_STATUS_EXECUTION_FAILED";
66 cublas_error =
"CUBLAS_STATUS_INTERNAL_ERROR";
69 cublas_error =
"CUBLAS_STATUS_NOT_SUPPORTED";
72 case CUBLAS_STATUS_LICENSE_ERROR:
73 cublas_error =
"CUBLAS_STATUS_LICENSE_ERROR";
77 cublas_error =
"<unknown>";
80 std::ostringstream err;
81 err <<
"cublasAssert: " << cublas_error <<
", file " << file <<
" , line " << line << std::endl
82 << cause << std::endl;
83 std::cerr << err.str();
85 throw std::runtime_error(cause);
98 if (trans ==
'N' || trans ==
'n')
100 else if (trans ==
'T' || trans ==
't')
102 else if (trans ==
'C' || trans ==
'c')
105 throw std::runtime_error(
106 "cuBLAS::convertOperation trans can only be 'N', 'T', 'C', 'n', 't', 'c'. Input value is " +
107 std::string(1, trans));
124 return cublasSgeam(handle, transa, transb,
m,
n, alpha,
A,
lda, beta,
B, ldb,
C, ldc);
141 return cublasDgeam(handle, transa, transb,
m,
n, alpha,
A,
lda, beta,
B, ldb,
C, ldc);
149 const std::complex<double>* alpha,
150 const std::complex<double>*
A,
152 const std::complex<double>* beta,
153 const std::complex<double>*
B,
155 std::complex<double>*
C,
167 const std::complex<float>* alpha,
168 const std::complex<float>*
A,
170 const std::complex<float>* beta,
171 const std::complex<float>*
B,
173 std::complex<float>*
C,
204 std::complex<float>*
A[],
215 std::complex<double>*
A[],
252 std::complex<float>*
A[],
255 std::complex<float>*
C[],
265 std::complex<double>*
A[],
268 std::complex<double>*
C[],
279 #undef castNativeType 280 #endif // QMCPLUSPLUS_CUBLAS_H
#define cublasCgetriBatched
helper functions for EinsplineSetBuilder
#define CUBLAS_STATUS_INVALID_VALUE
cublasStatus_t getrf_batched(cublasHandle_t &handle, int n, float *A[], int lda, int *PivotArray, int *infoArray, int batchSize)
#define CUBLAS_STATUS_ALLOC_FAILED
#define cublasDgetriBatched
#define cublasSgetrfBatched
#define CUBLAS_STATUS_SUCCESS
#define cublasCgetrfBatched
#define CUBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_MAPPING_ERROR
#define cublasOperation_t
#define CUBLAS_STATUS_NOT_SUPPORTED
void cublasAssert(cublasStatus_t code, const std::string &cause, const char *file, int line, bool abort=true)
prints cuBLAS error messages. Always use cublasErrorCheck macro.
#define CUBLAS_STATUS_NOT_INITIALIZED
#define cublasZgetriBatched
CUDATypeMap< T > castCUDAType(T var)
#define cublasZgetrfBatched
cublasStatus_t getri_batched(cublasHandle_t &handle, int n, float *A[], int lda, int *PivotArray, float *C[], int ldc, int *infoArray, int batchSize)
cublasStatus_t geam(cublasHandle_t &handle, cublasOperation_t &transa, cublasOperation_t &transb, int m, int n, const float *alpha, const float *A, int lda, const float *beta, const float *B, int ldb, float *C, int ldc)
#define cublasDgetrfBatched
#define CUBLAS_STATUS_INTERNAL_ERROR
cublasOperation_t convertOperation(const char trans)
double B(double x, int k, int i, const std::vector< double > &t)
#define CUBLAS_STATUS_ARCH_MISMATCH
#define cublasSgetriBatched