QMCPACK
cuBLAS_missing_functions.hpp
Go to the documentation of this file.
1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2020 QMCPACK developers.
6 //
7 // File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
8 //
9 // File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
10 //////////////////////////////////////////////////////////////////////////////////////
11 
12 
13 #ifndef QMCPLUSPLUS_CUBLAS_MISSING_FUNCTIONS_H
14 #define QMCPLUSPLUS_CUBLAS_MISSING_FUNCTIONS_H
15 
16 #include <complex>
17 #include "config.h"
18 #include "CUDAruntime.hpp"
19 
20 namespace qmcplusplus
21 {
22 /** Implement selected batched BLAS1/2 calls using CUDA for different data types S/C/D/Z.
23  * cuBLAS_MFs stands for missing functions in cuBLAS.
24  * 1) column major just like the BLAS fortran API
25  * 2) all the functions are asynchronous
26  * 3) all the pointer arguments are expected as device pointers.
27  * 4) in batched APIs, alpha and beta are **not** scalars but pointers to arrays holding one value per batch entry (batch_count elements).
28  */
29 namespace cuBLAS_MFs
30 {
31 // BLAS2
32 /** Xgemv batched API
33  * @param handle stream (cudaStream_t) used for the asynchronous computation
34  * @param trans whether A matrices are transposed
35  * @param m number of rows in A
36  * @param n number of columns in A
37  * @param alpha array of batch_count factors, one for each matrix A in the batch
38  * @param A device array of device pointers of matrices
39  * @param lda leading dimension of A
40  * @param x device array of device pointers of vector
41  * @param incx increment for the elements of x. It cannot be zero.
42  * @param beta array of batch_count factors, one for each vector y in the batch
43  * @param y device array of device pointers of vector
44  * @param incy increment for the elements of y. It cannot be zero.
45  * @param batch_count batch size
46  */
48  const char trans,
49  const int m,
50  const int n,
51  const float* alpha,
52  const float* const A[],
53  const int lda,
54  const float* const x[],
55  const int incx,
56  const float* beta,
57  float* const y[],
58  const int incy,
59  const int batch_count);
60 
62  const char trans,
63  const int m,
64  const int n,
65  const double* alpha,
66  const double* const A[],
67  const int lda,
68  const double* const x[],
69  const int incx,
70  const double* beta,
71  double* const y[],
72  const int incy,
73  const int batch_count);
74 
76  const char trans,
77  const int m,
78  const int n,
79  const std::complex<float>* alpha,
80  const std::complex<float>* const A[],
81  const int lda,
82  const std::complex<float>* const x[],
83  const int incx,
84  const std::complex<float>* beta,
85  std::complex<float>* const y[],
86  const int incy,
87  const int batch_count);
88 
90  const char trans,
91  const int m,
92  const int n,
93  const std::complex<double>* alpha,
94  const std::complex<double>* const A[],
95  const int lda,
96  const std::complex<double>* const x[],
97  const int incx,
98  const std::complex<double>* beta,
99  std::complex<double>* const y[],
100  const int incy,
101  const int batch_count);
102 
103 /** Xger batched API
104  * @param handle stream (cudaStream_t) used for the asynchronous computation
105  * @param m number of rows in A
106  * @param n number of columns in A
107  * @param alpha array of batch_count factors, one for each rank-1 update in the batch
108  * @param x device array of device pointers of vector
109  * @param incx increment for the elements of x. It cannot be zero.
110  * @param y device array of device pointers of vector
111  * @param incy increment for the elements of y. It cannot be zero.
112  * @param A device array of device pointers of matrices
113  * @param lda leading dimension of A
114  * @param batch_count batch size
115  */
117  const int m,
118  const int n,
119  const float* alpha,
120  const float* const x[],
121  const int incx,
122  const float* const y[],
123  const int incy,
124  float* const A[],
125  const int lda,
126  const int batch_count);
127 
129  const int m,
130  const int n,
131  const double* alpha,
132  const double* const x[],
133  const int incx,
134  const double* const y[],
135  const int incy,
136  double* const A[],
137  const int lda,
138  const int batch_count);
139 
141  const int m,
142  const int n,
143  const std::complex<float>* alpha,
144  const std::complex<float>* const x[],
145  const int incx,
146  const std::complex<float>* const y[],
147  const int incy,
148  std::complex<float>* const A[],
149  const int lda,
150  const int batch_count);
151 
153  const int m,
154  const int n,
155  const std::complex<double>* alpha,
156  const std::complex<double>* const x[],
157  const int incx,
158  const std::complex<double>* const y[],
159  const int incy,
160  std::complex<double>* const A[],
161  const int lda,
162  const int batch_count);
163 
164 // BLAS1
165 /** Xcopy batched API
166  * @param hstream stream (cudaStream_t) used for the asynchronous computation
167  * @param n number of elements to be copied
168  * @param in device array of device pointers of vector
169  * @param incx increment for the elements of in. It cannot be zero.
170  * @param out device array of device pointers of vector
171  * @param incy increment for the elements of out. It cannot be zero.
172  * @param batch_count batch size
173  */
175  const int n,
176  const float* const in[],
177  const int incx,
178  float* const out[],
179  const int incy,
180  const int batch_count);
181 
183  const int n,
184  const double* const in[],
185  const int incx,
186  double* const out[],
187  const int incy,
188  const int batch_count);
189 
191  const int n,
192  const std::complex<float>* const in[],
193  const int incx,
194  std::complex<float>* const out[],
195  const int incy,
196  const int batch_count);
197 
199  const int n,
200  const std::complex<double>* const in[],
201  const int incx,
202  std::complex<double>* const out[],
203  const int incy,
204  const int batch_count);
205 
206 } // namespace cuBLAS_MFs
207 
208 } // namespace qmcplusplus
209 #endif // QMCPLUSPLUS_CUBLAS_MISSING_FUNCTIONS_H
helper functions for EinsplineSetBuilder
Definition: Configuration.h:43
handle CUDA/HIP runtime selection.
cudaError_t gemv_batched(cudaStream_t handle, const char trans, const int m, const int n, const float *alpha, const float *const A[], const int lda, const float *const x[], const int incx, const float *beta, float *const y[], const int incy, const int batch_count)
Xgemv batched API.
cudaError_t copy_batched(cudaStream_t hstream, const int n, const float *const in[], const int incx, float *const out[], const int incy, const int batch_count)
Xcopy batched API.
#define cudaError_t
Definition: cuda2hip.h:89
#define cudaStream_t
Definition: cuda2hip.h:149
cudaError_t ger_batched(cudaStream_t handle, const int m, const int n, const float *alpha, const float *const x[], const int incx, const float *const y[], const int incy, float *const A[], const int lda, const int batch_count)
Xger batched API.