10 #ifndef QMCPLUSPLUS_CUDA_ACCELBLAS_CUDA_H 11 #define QMCPLUSPLUS_CUDA_ACCELBLAS_CUDA_H 20 #define castNativeType castCUDAType 22 #define castNativeType casthipblasType 65 n, k, &alpha,
A,
lda,
B, ldb, &beta,
C, ldc),
66 "cublasSgemm failed!");
85 n, k, &alpha,
A,
lda,
B, ldb, &beta,
C, ldc),
86 "cublasDgemm failed!");
95 const std::complex<float>& alpha,
96 const std::complex<float>*
A,
98 const std::complex<float>*
B,
100 const std::complex<float>& beta,
101 std::complex<float>*
C,
107 "cublasCgemm failed!");
116 const std::complex<double>& alpha,
117 const std::complex<double>*
A,
119 const std::complex<double>*
B,
121 const std::complex<double>& beta,
122 std::complex<double>*
C,
128 "cublasZgemm failed!");
136 const float*
const A,
138 const float*
const x,
144 cublasErrorCheck(
cublasSgemv(handle.
h_cublas,
cuBLAS::convertOperation(trans),
m,
n, &alpha,
A,
lda, x, incx, &beta,
146 "cublasSgemv failed!");
154 const double*
const A,
156 const double*
const x,
162 cublasErrorCheck(
cublasDgemv(handle.
h_cublas,
cuBLAS::convertOperation(trans),
m,
n, &alpha,
A,
lda, x, incx, &beta,
164 "cublasDgemv failed!");
171 const std::complex<float>& alpha,
172 const std::complex<float>*
A,
174 const std::complex<float>* x,
176 const std::complex<float>& beta,
177 std::complex<float>* y,
183 "cublasCgemv failed!");
190 const std::complex<double>& alpha,
191 const std::complex<double>*
A,
193 const std::complex<double>* x,
195 const std::complex<double>& beta,
196 std::complex<double>* y,
202 "cublasZgemv failed!");
218 const int batch_count)
220 cudaErrorCheck(
cuBLAS_MFs::gemv_batched(handle.
h_stream, trans,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy,
222 "cuBLAS_MFs::gemv_batched failed!");
229 const float*
const x,
231 const float*
const y,
236 cublasErrorCheck(
cublasSger(handle.
h_cublas,
m,
n, &alpha, x, incx, y, incy,
A,
lda),
"cublasSger failed!");
243 const double*
const x,
245 const double*
const y,
250 cublasErrorCheck(
cublasDger(handle.
h_cublas,
m,
n, &alpha, x, incx, y, incy,
A,
lda),
"cublasDger failed!");
256 const std::complex<float>& alpha,
257 const std::complex<float>* x,
259 const std::complex<float>* y,
261 std::complex<float>*
A,
266 "cublasCger failed!");
272 const std::complex<double>& alpha,
273 const std::complex<double>* x,
275 const std::complex<double>* y,
277 std::complex<double>*
A,
282 "cublasZger failed!");
296 const int batch_count)
298 cudaErrorCheck(
cuBLAS_MFs::ger_batched(handle.
h_stream,
m,
n, alpha, x, incx, y, incy,
A,
lda, batch_count),
299 "cuBLAS_MFs::ger_batched failed!");
309 const int batch_count)
312 "cuBLAS_MFs::copy_batched failed!");
322 const float*
const A[],
324 const float*
const B[],
332 cuBLAS::convertOperation(transb),
m,
n, k, &alpha,
A,
lda,
B, ldb, &beta,
C, ldc,
334 "cublasSgemmBatched failed!");
343 const std::complex<float>& alpha,
344 const std::complex<float>*
const A[],
346 const std::complex<float>*
const B[],
348 const std::complex<float>& beta,
349 std::complex<float>*
const C[],
365 "cublasCgemmBatched failed!");
375 const double*
const A[],
377 const double*
const B[],
385 cuBLAS::convertOperation(transb),
m,
n, k, &alpha,
A,
lda,
B, ldb, &beta,
C, ldc,
387 "cublasDgemmBatched failed!");
396 const std::complex<double>& alpha,
397 const std::complex<double>*
const A[],
399 const std::complex<double>*
const B[],
401 const std::complex<double>& beta,
402 std::complex<double>*
const C[],
414 "cublasZgemmBatched failed!");
420 #undef castNativeType void gemm(BLASHandle< PlatformKind::CUDA > &handle, const char transa, const char transb, int m, int n, int k, const float &alpha, const float *A, int lda, const float *B, int ldb, const float &beta, float *C, int ldc)
#define cublasDgemmBatched
helper functions for EinsplineSetBuilder
handle CUDA/HIP runtime selection.
Interfaces to blas library.
#define cublasSgemmBatched
cudaError_t gemv_batched(cudaStream_t handle, const char trans, const int m, const int n, const float *alpha, const float *const A[], const int lda, const float *const x[], const int incx, const float *beta, float *const y[], const int incy, const int batch_count)
Xgemv batched API.
cudaError_t copy_batched(cudaStream_t hstream, const int n, const float *const in[], const int incx, float *const out[], const int incy, const int batch_count)
Xcopy batched API.
void ger(BLASHandle< PlatformKind::CUDA > &handle, const int m, const int n, const float &alpha, const float *const x, const int incx, const float *const y, const int incy, float *const A, const int lda)
void gemv_batched(BLASHandle< PlatformKind::CUDA > &handle, const char trans, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const int batch_count)
cudaErrorCheck(cudaMemcpyAsync(dev_lu.data(), lu.data(), sizeof(decltype(lu)::value_type) *lu.size(), cudaMemcpyHostToDevice, hstream), "cudaMemcpyAsync failed copying log_values to device")
void gemm_batched(BLASHandle< PlatformKind::CUDA > &handle, const char transa, const char transb, int m, int n, int k, const float &alpha, const float *const A[], int lda, const float *const B[], int ldb, const float &beta, float *const C[], int ldc, int batchCount)
const cudaStream_t h_stream
void copy_batched(BLASHandle< PlatformKind::CUDA > &handle, const int n, const T *const in[], const int incx, T *const out[], const int incy, const int batch_count)
typename std::add_pointer< typename std::remove_const< typename std::remove_pointer< CT >::type >::type >::type type
cudaError_t ger_batched(cudaStream_t handle, const int m, const int n, const float *alpha, const float *const x[], const int incx, const float *const y[], const int incy, float *const A[], const int lda, const int batch_count)
Xger batched API.
void ger_batched(BLASHandle< PlatformKind::CUDA > &handle, const int m, const int n, const T *alpha, const T *const x[], const int incx, const T *const y[], const int incy, T *const A[], const int lda, const int batch_count)
cublasOperation_t convertOperation(const char trans)
BLASHandle(Queue< PlatformKind::CUDA > &queue)
double B(double x, int k, int i, const std::vector< double > &t)
#define cublasErrorCheck(ans, cause)
void gemv(BLASHandle< PlatformKind::CUDA > &handle, const char trans, const int m, const int n, const float &alpha, const float *const A, const int lda, const float *const x, const int incx, const float &beta, float *const y, const int incy)
#define cublasCgemmBatched
#define cublasZgemmBatched