QMCPACK
cuBLAS.hpp
Go to the documentation of this file.
1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2021 QMCPACK developers.
6 //
7 // File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
8 // Peter Doak, doakpw@ornl.gov, Oak Ridge National Laboratory
9 //
10 // File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
11 //////////////////////////////////////////////////////////////////////////////////////
12 
13 #ifndef QMCPLUSPLUS_CUBLAS_H
14 #define QMCPLUSPLUS_CUBLAS_H
15 
16 #include <complex>
17 #include <iostream>
18 #include <string>
19 #include <stdexcept>
20 #include "config.h"
21 #ifndef QMC_CUDA2HIP
22 #include <cublas_v2.h>
23 #define castNativeType castCUDAType
24 #else
25 #include <hipblas/hipblas.h>
29 #define castNativeType casthipblasType
30 #endif
31 #include "CUDATypeMapping.hpp"
33 
/** Checks the status of a cuBLAS call through cublasAssert, recording the call site.
 *
 *  The expansion is wrapped in do { } while (0) so that it forms exactly one statement:
 *  it takes a trailing semicolon like a function call and can be used as the unbraced
 *  body of an if/else without detaching the else branch.
 *  @param ans   expression yielding the status to check (evaluated once)
 *  @param cause description of the operation, forwarded to cublasAssert
 */
#define cublasErrorCheck(ans, cause)                \
  do                                                \
  {                                                 \
    cublasAssert((ans), cause, __FILE__, __LINE__); \
  } while (0)
38 
39 /// prints cuBLAS error messages. Always use cublasErrorCheck macro.
40 inline void cublasAssert(cublasStatus_t code, const std::string& cause, const char* file, int line, bool abort = true)
41 {
42  if (code != CUBLAS_STATUS_SUCCESS)
43  {
44  std::string cublas_error;
45  switch (code)
46  {
48  cublas_error = "CUBLAS_STATUS_NOT_INITIALIZED";
49  break;
51  cublas_error = "CUBLAS_STATUS_ALLOC_FAILED";
52  break;
54  cublas_error = "CUBLAS_STATUS_INVALID_VALUE";
55  break;
57  cublas_error = "CUBLAS_STATUS_ARCH_MISMATCH";
58  break;
60  cublas_error = "CUBLAS_STATUS_MAPPING_ERROR";
61  break;
63  cublas_error = "CUBLAS_STATUS_EXECUTION_FAILED";
64  break;
66  cublas_error = "CUBLAS_STATUS_INTERNAL_ERROR";
67  break;
69  cublas_error = "CUBLAS_STATUS_NOT_SUPPORTED";
70  break;
71 #ifndef QMC_CUDA2HIP
72  case CUBLAS_STATUS_LICENSE_ERROR:
73  cublas_error = "CUBLAS_STATUS_LICENSE_ERROR";
74  break;
75 #endif
76  default:
77  cublas_error = "<unknown>";
78  }
79 
80  std::ostringstream err;
81  err << "cublasAssert: " << cublas_error << ", file " << file << " , line " << line << std::endl
82  << cause << std::endl;
83  std::cerr << err.str();
84  //if (abort) exit(code);
85  throw std::runtime_error(cause);
86  }
87 }
88 
89 namespace qmcplusplus
90 {
91 /** interface to cuBLAS calls for different data types S/C/D/Z
92  */
93 namespace cuBLAS
94 {
95 
96 inline cublasOperation_t convertOperation(const char trans)
97 {
98  if (trans == 'N' || trans == 'n')
99  return CUBLAS_OP_N;
100  else if (trans == 'T' || trans == 't')
101  return CUBLAS_OP_T;
102  else if (trans == 'C' || trans == 'c')
103  return CUBLAS_OP_C;
104  else
105  throw std::runtime_error(
106  "cuBLAS::convertOperation trans can only be 'N', 'T', 'C', 'n', 't', 'c'. Input value is " +
107  std::string(1, trans));
108 }
109 
111  cublasOperation_t& transa,
112  cublasOperation_t& transb,
113  int m,
114  int n,
115  const float* alpha,
116  const float* A,
117  int lda,
118  const float* beta,
119  const float* B,
120  int ldb,
121  float* C,
122  int ldc)
123 {
124  return cublasSgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
125 }
126 
128  cublasOperation_t transa,
129  cublasOperation_t transb,
130  int m,
131  int n,
132  const double* alpha,
133  const double* A,
134  int lda,
135  const double* beta,
136  const double* B,
137  int ldb,
138  double* C,
139  int ldc)
140 {
141  return cublasDgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
142 }
143 
145  cublasOperation_t transa,
146  cublasOperation_t transb,
147  int m,
148  int n,
149  const std::complex<double>* alpha,
150  const std::complex<double>* A,
151  int lda,
152  const std::complex<double>* beta,
153  const std::complex<double>* B,
154  int ldb,
155  std::complex<double>* C,
156  int ldc)
157 {
158  return cublasZgeam(handle, transa, transb, m, n, castNativeType(alpha), castNativeType(A), lda, castNativeType(beta),
159  castNativeType(B), ldb, castNativeType(C), ldc);
160 }
161 
163  cublasOperation_t transa,
164  cublasOperation_t transb,
165  int m,
166  int n,
167  const std::complex<float>* alpha,
168  const std::complex<float>* A,
169  int lda,
170  const std::complex<float>* beta,
171  const std::complex<float>* B,
172  int ldb,
173  std::complex<float>* C,
174  int ldc)
175 {
176  return cublasCgeam(handle, transa, transb, m, n, castNativeType(alpha), castNativeType(A), lda, castNativeType(beta),
177  castNativeType(B), ldb, castNativeType(C), ldc);
178 }
179 
181  int n,
182  float* A[],
183  int lda,
184  int* PivotArray,
185  int* infoArray,
186  int batchSize)
187 {
188  return cublasSgetrfBatched(handle, n, A, lda, PivotArray, infoArray, batchSize);
189 }
190 
192  int n,
193  double* A[],
194  int lda,
195  int* PivotArray,
196  int* infoArray,
197  int batchSize)
198 {
199  return cublasDgetrfBatched(handle, n, A, lda, PivotArray, infoArray, batchSize);
200 }
201 
203  int n,
204  std::complex<float>* A[],
205  int lda,
206  int* PivotArray,
207  int* infoArray,
208  int batchSize)
209 {
210  return cublasCgetrfBatched(handle, n, castCUDAType(A), lda, PivotArray, infoArray, batchSize);
211 }
212 
214  int n,
215  std::complex<double>* A[],
216  int lda,
217  int* PivotArray,
218  int* infoArray,
219  int batchSize)
220 {
221  return cublasZgetrfBatched(handle, n, castCUDAType(A), lda, PivotArray, infoArray, batchSize);
222 }
223 
225  int n,
226  float* A[],
227  int lda,
228  int* PivotArray,
229  float* C[],
230  int ldc,
231  int* infoArray,
232  int batchSize)
233 {
234  return cublasSgetriBatched(handle, n, A, lda, PivotArray, C, ldc, infoArray, batchSize);
235 }
236 
238  int n,
239  double* A[],
240  int lda,
241  int* PivotArray,
242  double* C[],
243  int ldc,
244  int* infoArray,
245  int batchSize)
246 {
247  return cublasDgetriBatched(handle, n, A, lda, PivotArray, C, ldc, infoArray, batchSize);
248 }
249 
251  int n,
252  std::complex<float>* A[],
253  int lda,
254  int* PivotArray,
255  std::complex<float>* C[],
256  int ldc,
257  int* infoArray,
258  int batchSize)
259 {
260  return cublasCgetriBatched(handle, n, castCUDAType(A), lda, PivotArray, castCUDAType(C), ldc, infoArray, batchSize);
261 }
262 
264  int n,
265  std::complex<double>* A[],
266  int lda,
267  int* PivotArray,
268  std::complex<double>* C[],
269  int ldc,
270  int* infoArray,
271  int batchSize)
272 {
273  return cublasZgetriBatched(handle, n, castCUDAType(A), lda, PivotArray, castCUDAType(C), ldc, infoArray, batchSize);
274 }
275 
276 }; // namespace cuBLAS
277 
278 } // namespace qmcplusplus
279 #undef castNativeType
280 #endif // QMCPLUSPLUS_CUBLAS_H
#define cublasCgeam
Definition: cuda2hip.h:42
#define CUBLAS_OP_N
Definition: cuda2hip.h:19
#define cublasCgetriBatched
Definition: cuda2hip.h:48
helper functions for EinsplineSetBuilder
Definition: Configuration.h:43
#define CUBLAS_STATUS_INVALID_VALUE
Definition: cuda2hip.h:26
cublasStatus_t getrf_batched(cublasHandle_t &handle, int n, float *A[], int lda, int *PivotArray, int *infoArray, int batchSize)
Definition: cuBLAS.hpp:180
#define CUBLAS_STATUS_ALLOC_FAILED
Definition: cuda2hip.h:22
#define cublasDgetriBatched
Definition: cuda2hip.h:55
#define cublasSgetrfBatched
Definition: cuda2hip.h:61
#define CUBLAS_STATUS_SUCCESS
Definition: cuda2hip.h:31
#define cublasCgetrfBatched
Definition: cuda2hip.h:47
#define castNativeType
Definition: cuBLAS.hpp:23
#define CUBLAS_STATUS_EXECUTION_FAILED
Definition: cuda2hip.h:24
#define CUBLAS_STATUS_MAPPING_ERROR
Definition: cuda2hip.h:28
#define cublasZgeam
Definition: cuda2hip.h:63
#define cublasOperation_t
Definition: cuda2hip.h:41
#define cublasStatus_t
Definition: cuda2hip.h:36
#define CUBLAS_STATUS_NOT_SUPPORTED
Definition: cuda2hip.h:30
void cublasAssert(cublasStatus_t code, const std::string &cause, const char *file, int line, bool abort=true)
prints cuBLAS error messages. Always use cublasErrorCheck macro.
Definition: cuBLAS.hpp:40
#define CUBLAS_OP_C
Definition: cuda2hip.h:21
#define CUBLAS_STATUS_NOT_INITIALIZED
Definition: cuda2hip.h:29
#define cublasSgeam
Definition: cuda2hip.h:56
#define CUBLAS_OP_T
Definition: cuda2hip.h:20
#define cublasZgetriBatched
Definition: cuda2hip.h:69
CUDATypeMap< T > castCUDAType(T var)
#define cublasZgetrfBatched
Definition: cuda2hip.h:68
cublasStatus_t getri_batched(cublasHandle_t &handle, int n, float *A[], int lda, int *PivotArray, float *C[], int ldc, int *infoArray, int batchSize)
Definition: cuBLAS.hpp:224
#define cublasDgeam
Definition: cuda2hip.h:49
cublasStatus_t geam(cublasHandle_t &handle, cublasOperation_t &transa, cublasOperation_t &transb, int m, int n, const float *alpha, const float *A, int lda, const float *beta, const float *B, int ldb, float *C, int ldc)
Definition: cuBLAS.hpp:110
#define cublasDgetrfBatched
Definition: cuda2hip.h:54
#define CUBLAS_STATUS_INTERNAL_ERROR
Definition: cuda2hip.h:25
cublasOperation_t convertOperation(const char trans)
Definition: cuBLAS.hpp:96
double B(double x, int k, int i, const std::vector< double > &t)
#define CUBLAS_STATUS_ARCH_MISMATCH
Definition: cuda2hip.h:23
#define cublasSgetriBatched
Definition: cuda2hip.h:62
#define cublasHandle_t
Definition: cuda2hip.h:35