QMCPACK
AccelBLAS_CUDA.hpp
Go to the documentation of this file.
1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2024 QMCPACK developers.
6 //
7 // File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
8 //////////////////////////////////////////////////////////////////////////////////////
9 
10 #ifndef QMCPLUSPLUS_CUDA_ACCELBLAS_CUDA_H
11 #define QMCPLUSPLUS_CUDA_ACCELBLAS_CUDA_H
12 
13 #include "AccelBLASHandle.hpp"
14 #include "CUDA/CUDAruntime.hpp"
15 #include "CUDA/QueueCUDA.hpp"
16 #include "CUDA/cuBLAS.hpp"
18 
19 #ifndef QMC_CUDA2HIP
20 #define castNativeType castCUDAType
21 #else
22 #define castNativeType casthipblasType
23 #endif
24 
25 namespace qmcplusplus
26 {
27 namespace compute
28 {
29 template<>
31 {
32 public:
33  // cuda stream, not owned, reference-only
35  // cublas handle
37 
38  BLASHandle(Queue<PlatformKind::CUDA>& queue) : h_stream(queue.getNative())
39  {
40  cublasErrorCheck(cublasCreate(&h_cublas), "cublasCreate failed!");
41  cublasErrorCheck(cublasSetStream(h_cublas, h_stream), "cublasSetStream failed!");
42  }
43 
44  ~BLASHandle() { cublasErrorCheck(cublasDestroy(h_cublas), "cublasDestroy failed!"); }
45 };
46 
47 namespace BLAS
48 {
50  const char transa,
51  const char transb,
52  int m,
53  int n,
54  int k,
55  const float& alpha,
56  const float* A,
57  int lda,
58  const float* B,
59  int ldb,
60  const float& beta,
61  float* C,
62  int ldc)
63 {
65  n, k, &alpha, A, lda, B, ldb, &beta, C, ldc),
66  "cublasSgemm failed!");
67 }
68 
70  const char transa,
71  const char transb,
72  int m,
73  int n,
74  int k,
75  const double& alpha,
76  const double* A,
77  int lda,
78  const double* B,
79  int ldb,
80  const double& beta,
81  double* C,
82  int ldc)
83 {
85  n, k, &alpha, A, lda, B, ldb, &beta, C, ldc),
86  "cublasDgemm failed!");
87 }
88 
90  const char transa,
91  const char transb,
92  int m,
93  int n,
94  int k,
95  const std::complex<float>& alpha,
96  const std::complex<float>* A,
97  int lda,
98  const std::complex<float>* B,
99  int ldb,
100  const std::complex<float>& beta,
101  std::complex<float>* C,
102  int ldc)
103 {
105  n, k, castNativeType(&alpha), castNativeType(A), lda, castNativeType(B), ldb,
106  castNativeType(&beta), castNativeType(C), ldc),
107  "cublasCgemm failed!");
108 }
109 
111  const char transa,
112  const char transb,
113  int m,
114  int n,
115  int k,
116  const std::complex<double>& alpha,
117  const std::complex<double>* A,
118  int lda,
119  const std::complex<double>* B,
120  int ldb,
121  const std::complex<double>& beta,
122  std::complex<double>* C,
123  int ldc)
124 {
126  n, k, castNativeType(&alpha), castNativeType(A), lda, castNativeType(B), ldb,
127  castNativeType(&beta), castNativeType(C), ldc),
128  "cublasZgemm failed!");
129 }
130 
132  const char trans,
133  const int m,
134  const int n,
135  const float& alpha,
136  const float* const A,
137  const int lda,
138  const float* const x,
139  const int incx,
140  const float& beta,
141  float* const y,
142  const int incy)
143 {
144  cublasErrorCheck(cublasSgemv(handle.h_cublas, cuBLAS::convertOperation(trans), m, n, &alpha, A, lda, x, incx, &beta,
145  y, incy),
146  "cublasSgemv failed!");
147 }
148 
150  const char trans,
151  const int m,
152  const int n,
153  const double& alpha,
154  const double* const A,
155  const int lda,
156  const double* const x,
157  const int incx,
158  const double& beta,
159  double* const y,
160  const int incy)
161 {
162  cublasErrorCheck(cublasDgemv(handle.h_cublas, cuBLAS::convertOperation(trans), m, n, &alpha, A, lda, x, incx, &beta,
163  y, incy),
164  "cublasDgemv failed!");
165 }
166 
168  const char trans,
169  const int m,
170  const int n,
171  const std::complex<float>& alpha,
172  const std::complex<float>* A,
173  const int lda,
174  const std::complex<float>* x,
175  const int incx,
176  const std::complex<float>& beta,
177  std::complex<float>* y,
178  const int incy)
179 {
181  castNativeType(A), lda, castNativeType(x), incx, castNativeType(&beta),
182  castNativeType(y), incy),
183  "cublasCgemv failed!");
184 }
185 
187  const char trans,
188  const int m,
189  const int n,
190  const std::complex<double>& alpha,
191  const std::complex<double>* A,
192  const int lda,
193  const std::complex<double>* x,
194  const int incx,
195  const std::complex<double>& beta,
196  std::complex<double>* y,
197  const int incy)
198 {
200  castNativeType(A), lda, castNativeType(x), incx, castNativeType(&beta),
201  castNativeType(y), incy),
202  "cublasZgemv failed!");
203 }
204 
205 template<typename T>
207  const char trans,
208  const int m,
209  const int n,
210  const T* alpha,
211  const T* const A[],
212  const int lda,
213  const T* const x[],
214  const int incx,
215  const T* beta,
216  T* const y[],
217  const int incy,
218  const int batch_count)
219 {
220  cudaErrorCheck(cuBLAS_MFs::gemv_batched(handle.h_stream, trans, m, n, alpha, A, lda, x, incx, beta, y, incy,
221  batch_count),
222  "cuBLAS_MFs::gemv_batched failed!");
223 }
224 
226  const int m,
227  const int n,
228  const float& alpha,
229  const float* const x,
230  const int incx,
231  const float* const y,
232  const int incy,
233  float* const A,
234  const int lda)
235 {
236  cublasErrorCheck(cublasSger(handle.h_cublas, m, n, &alpha, x, incx, y, incy, A, lda), "cublasSger failed!");
237 }
238 
240  const int m,
241  const int n,
242  const double& alpha,
243  const double* const x,
244  const int incx,
245  const double* const y,
246  const int incy,
247  double* const A,
248  const int lda)
249 {
250  cublasErrorCheck(cublasDger(handle.h_cublas, m, n, &alpha, x, incx, y, incy, A, lda), "cublasDger failed!");
251 }
252 
254  const int m,
255  const int n,
256  const std::complex<float>& alpha,
257  const std::complex<float>* x,
258  const int incx,
259  const std::complex<float>* y,
260  const int incy,
261  std::complex<float>* A,
262  const int lda)
263 {
265  castNativeType(y), incy, castNativeType(A), lda),
266  "cublasCger failed!");
267 }
268 
270  const int m,
271  const int n,
272  const std::complex<double>& alpha,
273  const std::complex<double>* x,
274  const int incx,
275  const std::complex<double>* y,
276  const int incy,
277  std::complex<double>* A,
278  const int lda)
279 {
281  castNativeType(y), incy, castNativeType(A), lda),
282  "cublasZger failed!");
283 }
284 
285 template<typename T>
287  const int m,
288  const int n,
289  const T* alpha,
290  const T* const x[],
291  const int incx,
292  const T* const y[],
293  const int incy,
294  T* const A[],
295  const int lda,
296  const int batch_count)
297 {
298  cudaErrorCheck(cuBLAS_MFs::ger_batched(handle.h_stream, m, n, alpha, x, incx, y, incy, A, lda, batch_count),
299  "cuBLAS_MFs::ger_batched failed!");
300 }
301 
302 template<typename T>
304  const int n,
305  const T* const in[],
306  const int incx,
307  T* const out[],
308  const int incy,
309  const int batch_count)
310 {
311  cudaErrorCheck(cuBLAS_MFs::copy_batched(handle.h_stream, n, in, incx, out, incy, batch_count),
312  "cuBLAS_MFs::copy_batched failed!");
313 }
314 
316  const char transa,
317  const char transb,
318  int m,
319  int n,
320  int k,
321  const float& alpha,
322  const float* const A[],
323  int lda,
324  const float* const B[],
325  int ldb,
326  const float& beta,
327  float* const C[],
328  int ldc,
329  int batchCount)
330 {
332  cuBLAS::convertOperation(transb), m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc,
333  batchCount),
334  "cublasSgemmBatched failed!");
335 }
336 
338  const char transa,
339  const char transb,
340  int m,
341  int n,
342  int k,
343  const std::complex<float>& alpha,
344  const std::complex<float>* const A[],
345  int lda,
346  const std::complex<float>* const B[],
347  int ldb,
348  const std::complex<float>& beta,
349  std::complex<float>* const C[],
350  int ldc,
351  int batchCount)
352 {
353  // This is necessary to not break the complex CUDA type mapping semantics while
354  // dealing with the const cuComplex * A[] style API of cuBLAS
355  // C++ makes you jump through some hoops to remove the bottom const on a double pointer.
356  // see typetraits/type_manipulation.hpp
357  auto non_const_A = const_cast<BottomConstRemoved<decltype(A)>::type>(A);
358  auto non_const_B = const_cast<BottomConstRemoved<decltype(B)>::type>(B);
359  auto non_const_C = const_cast<BottomConstRemoved<decltype(C)>::type>(C);
360 
362  cuBLAS::convertOperation(transb), m, n, k, castNativeType(&alpha),
363  castNativeType(non_const_A), lda, castNativeType(non_const_B), ldb,
364  castNativeType(&beta), castNativeType(non_const_C), ldc, batchCount),
365  "cublasCgemmBatched failed!");
366 }
367 
369  const char transa,
370  const char transb,
371  int m,
372  int n,
373  int k,
374  const double& alpha,
375  const double* const A[],
376  int lda,
377  const double* const B[],
378  int ldb,
379  const double& beta,
380  double* const C[],
381  int ldc,
382  int batchCount)
383 {
385  cuBLAS::convertOperation(transb), m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc,
386  batchCount),
387  "cublasDgemmBatched failed!");
388 }
389 
391  const char transa,
392  const char transb,
393  int m,
394  int n,
395  int k,
396  const std::complex<double>& alpha,
397  const std::complex<double>* const A[],
398  int lda,
399  const std::complex<double>* const B[],
400  int ldb,
401  const std::complex<double>& beta,
402  std::complex<double>* const C[],
403  int ldc,
404  int batchCount)
405 {
406  auto non_const_A = const_cast<BottomConstRemoved<decltype(A)>::type>(A);
407  auto non_const_B = const_cast<BottomConstRemoved<decltype(B)>::type>(B);
408  auto non_const_C = const_cast<BottomConstRemoved<decltype(C)>::type>(C);
409 
411  cuBLAS::convertOperation(transb), m, n, k, castNativeType(&alpha),
412  castNativeType(non_const_A), lda, castNativeType(non_const_B), ldb,
413  castNativeType(&beta), castNativeType(non_const_C), ldc, batchCount),
414  "cublasZgemmBatched failed!");
415 }
416 
417 } // namespace BLAS
418 } // namespace compute
419 } // namespace qmcplusplus
420 #undef castNativeType
421 #endif
void gemm(BLASHandle< PlatformKind::CUDA > &handle, const char transa, const char transb, int m, int n, int k, const float &alpha, const float *A, int lda, const float *B, int ldb, const float &beta, float *C, int ldc)
#define cublasDgemmBatched
Definition: cuda2hip.h:53
helper functions for EinsplineSetBuilder
Definition: Configuration.h:43
#define cublasCgemm
Definition: cuda2hip.h:45
handle CUDA/HIP runtime selection.
#define cublasDgemm
Definition: cuda2hip.h:52
Interfaces to blas library.
Definition: BLAS.hpp:38
#define cublasSgemmBatched
Definition: cuda2hip.h:60
cudaError_t gemv_batched(cudaStream_t handle, const char trans, const int m, const int n, const float *alpha, const float *const A[], const int lda, const float *const x[], const int incx, const float *beta, float *const y[], const int incy, const int batch_count)
Xgemv batched API.
#define cublasZgeru
Definition: cuda2hip.h:65
#define cublasDestroy
Definition: cuda2hip.h:38
cudaError_t copy_batched(cudaStream_t hstream, const int n, const float *const in[], const int incx, float *const out[], const int incy, const int batch_count)
Xcopy batched API.
#define cudaStream_t
Definition: cuda2hip.h:149
void ger(BLASHandle< PlatformKind::CUDA > &handle, const int m, const int n, const float &alpha, const float *const x, const int incx, const float *const y, const int incy, float *const A, const int lda)
void gemv_batched(BLASHandle< PlatformKind::CUDA > &handle, const char trans, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const int batch_count)
cudaErrorCheck(cudaMemcpyAsync(dev_lu.data(), lu.data(), sizeof(decltype(lu)::value_type) *lu.size(), cudaMemcpyHostToDevice, hstream), "cudaMemcpyAsync failed copying log_values to device")
void gemm_batched(BLASHandle< PlatformKind::CUDA > &handle, const char transa, const char transb, int m, int n, int k, const float &alpha, const float *const A[], int lda, const float *const B[], int ldb, const float &beta, float *const C[], int ldc, int batchCount)
#define cublasZgemv
Definition: cuda2hip.h:64
#define cublasCgemv
Definition: cuda2hip.h:43
#define cublasSger
Definition: cuda2hip.h:58
#define castNativeType
#define cublasCreate
Definition: cuda2hip.h:37
#define cublasSetStream
Definition: cuda2hip.h:39
void copy_batched(BLASHandle< PlatformKind::CUDA > &handle, const int n, const T *const in[], const int incx, T *const out[], const int incy, const int batch_count)
typename std::add_pointer< typename std::remove_const< typename std::remove_pointer< CT >::type >::type >::type type
#define cublasSgemv
Definition: cuda2hip.h:57
#define cublasDgemv
Definition: cuda2hip.h:50
cudaError_t ger_batched(cudaStream_t handle, const int m, const int n, const float *alpha, const float *const x[], const int incx, const float *const y[], const int incy, float *const A[], const int lda, const int batch_count)
Xger batched API.
void ger_batched(BLASHandle< PlatformKind::CUDA > &handle, const int m, const int n, const T *alpha, const T *const x[], const int incx, const T *const y[], const int incy, T *const A[], const int lda, const int batch_count)
cublasOperation_t convertOperation(const char trans)
Definition: cuBLAS.hpp:96
BLASHandle(Queue< PlatformKind::CUDA > &queue)
double B(double x, int k, int i, const std::vector< double > &t)
#define cublasErrorCheck(ans, cause)
Definition: cuBLAS.hpp:34
void gemv(BLASHandle< PlatformKind::CUDA > &handle, const char trans, const int m, const int n, const float &alpha, const float *const A, const int lda, const float *const x, const int incx, const float &beta, float *const y, const int incy)
#define cublasCgemmBatched
Definition: cuda2hip.h:46
#define cublasDger
Definition: cuda2hip.h:51
#define cublasZgemm
Definition: cuda2hip.h:66
#define cublasZgemmBatched
Definition: cuda2hip.h:67
#define cublasSgemm
Definition: cuda2hip.h:59
#define cublasCgeru
Definition: cuda2hip.h:44
#define cublasHandle_t
Definition: cuda2hip.h:35