QMCPACK
AccelBLAS_OMPTarget.hpp
Go to the documentation of this file.
1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2024 QMCPACK developers.
6 //
7 // File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
8 //////////////////////////////////////////////////////////////////////////////////////
9 
10 #ifndef QMCPLUSPLUS_ACCELBLAS_OMPTARGET_H
11 #define QMCPLUSPLUS_ACCELBLAS_OMPTARGET_H
12 
#include <stdexcept>

#include "AccelBLASHandle.hpp"
#include "QueueOMPTarget.hpp"
#include "ompBLAS.hpp"
16 
17 namespace qmcplusplus
18 {
19 namespace compute
20 {
21 template<>
23 {
24 public:
26 
27  BLASHandle(Queue<PlatformKind::OMPTARGET>& queue) : h_ompblas(0) {}
28 };
29 
30 namespace BLAS
31 {
32 
33 template<typename T>
35  const char transa,
36  const char transb,
37  int m,
38  int n,
39  int k,
40  const T& alpha,
41  const T* A,
42  int lda,
43  const T* B,
44  int ldb,
45  const T& beta,
46  T* C,
47  int ldc)
48 {
49  if (ompBLAS::gemm(handle.h_ompblas, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) != 0)
50  throw std::runtime_error("ompBLAS::gemm failed!");
51 }
52 
53 template<typename T>
55  const char transa,
56  const char transb,
57  int m,
58  int n,
59  int k,
60  const T& alpha,
61  const T* const A[],
62  int lda,
63  const T* const B[],
64  int ldb,
65  const T& beta,
66  T* const C[],
67  int ldc,
68  int batchCount)
69 {
70  if (ompBLAS::gemm_batched(handle.h_ompblas, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc,
71  batchCount) != 0)
72  throw std::runtime_error("ompBLAS::gemm_batched failed!");
73 }
74 
75 
76 template<typename T>
78  const char trans,
79  const int m,
80  const int n,
81  const T& alpha,
82  const T* const A,
83  const int lda,
84  const T* const x,
85  const int incx,
86  const T& beta,
87  T* const y,
88  const int incy)
89 {
90  if (ompBLAS::gemv(handle.h_ompblas, trans, m, n, alpha, A, lda, x, incx, beta, y, incy) != 0)
91  throw std::runtime_error("ompBLAS::gemv_batched failed!");
92 }
93 
94 template<typename T>
96  const char trans,
97  const int m,
98  const int n,
99  const T* alpha,
100  const T* const A[],
101  const int lda,
102  const T* const x[],
103  const int incx,
104  const T* beta,
105  T* const y[],
106  const int incy,
107  const int batch_count)
108 {
109  if (ompBLAS::gemv_batched(handle.h_ompblas, trans, m, n, alpha, A, lda, x, incx, beta, y, incy, batch_count) != 0)
110  throw std::runtime_error("ompBLAS::gemv_batched failed!");
111 }
112 
113 template<typename T>
115  const int m,
116  const int n,
117  const T& alpha,
118  const T* const x,
119  const int incx,
120  const T* const y,
121  const int incy,
122  T* const A,
123  const int lda)
124 {
125  if (ompBLAS::ger(handle.h_ompblas, m, n, alpha, x, incx, y, incy, A, lda) != 0)
126  throw std::runtime_error("ompBLAS::ger_batched failed!");
127 }
128 
129 template<typename T>
131  const int m,
132  const int n,
133  const T* alpha,
134  const T* const x[],
135  const int incx,
136  const T* const y[],
137  const int incy,
138  T* const A[],
139  const int lda,
140  const int batch_count)
141 {
142  if (ompBLAS::ger_batched(handle.h_ompblas, m, n, alpha, x, incx, y, incy, A, lda, batch_count) != 0)
143  throw std::runtime_error("ompBLAS::ger_batched failed!");
144 }
145 
146 template<typename T>
148  const int n,
149  const T* const x[],
150  const int incx,
151  T* const y[],
152  const int incy,
153  const int batch_count)
154 {
155  if (ompBLAS::copy_batched(handle.h_ompblas, n, x, incx, y, incy, batch_count) != 0)
156  throw std::runtime_error("ompBLAS::copy_batched failed!");
157 }
158 
159 } // namespace BLAS
160 } // namespace compute
161 } // namespace qmcplusplus
162 #undef castNativeType
163 #endif
void gemm(BLASHandle< PlatformKind::CUDA > &handle, const char transa, const char transb, int m, int n, int k, const float &alpha, const float *A, int lda, const float *B, int ldb, const float &beta, float *C, int ldc)
ompBLAS_status gemm(ompBLAS_handle &handle, const char transa, const char transb, const int M, const int N, const int K, const T &alpha, const T *const A, const int lda, const T *const B, const int ldb, const T &beta, T *const C, const int ldc)
ompBLAS_status gemv(ompBLAS_handle &handle, const char trans, const int m, const int n, const T alpha, const T *const A, const int lda, const T *const x, const int incx, const T beta, T *const y, const int incy)
helper functions for EinsplineSetBuilder
Definition: Configuration.h:43
Interfaces to blas library.
Definition: BLAS.hpp:38
ompBLAS_status gemv_batched(ompBLAS_handle &handle, const char trans, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const int batch_count)
void ger(BLASHandle< PlatformKind::CUDA > &handle, const int m, const int n, const float &alpha, const float *const x, const int incx, const float *const y, const int incy, float *const A, const int lda)
void gemv_batched(BLASHandle< PlatformKind::CUDA > &handle, const char trans, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const int batch_count)
ompBLAS_status gemm_batched(ompBLAS_handle &handle, const char transa, const char transb, const int M, const int N, const int K, const T &alpha, const T *const A[], const int lda, const T *const B[], const int ldb, const T &beta, T *const C[], const int ldc, const int batch_count)
void gemm_batched(BLASHandle< PlatformKind::CUDA > &handle, const char transa, const char transb, int m, int n, int k, const float &alpha, const float *const A[], int lda, const float *const B[], int ldb, const float &beta, float *const C[], int ldc, int batchCount)
ompBLAS_status ger_batched(ompBLAS_handle &handle, const int m, const int n, const T *alpha, const T *const x[], const int incx, const T *const y[], const int incy, T *const A[], const int lda, const int batch_count)
void copy_batched(BLASHandle< PlatformKind::CUDA > &handle, const int n, const T *const in[], const int incx, T *const out[], const int incy, const int batch_count)
ompBLAS_status ger(ompBLAS_handle &handle, const int m, const int n, const T alpha, const T *const x, const int incx, const T *const y, const int incy, T *const A, const int lda)
void ger_batched(BLASHandle< PlatformKind::CUDA > &handle, const int m, const int n, const T *alpha, const T *const x[], const int incx, const T *const y[], const int incy, T *const A[], const int lda, const int batch_count)
double B(double x, int k, int i, const std::vector< double > &t)
void gemv(BLASHandle< PlatformKind::CUDA > &handle, const char trans, const int m, const int n, const float &alpha, const float *const A, const int lda, const float *const x, const int incx, const float &beta, float *const y, const int incy)
ompBLAS_status copy_batched(ompBLAS_handle &handle, const int n, const T *const x[], const int incx, T *const y[], const int incy, const int batch_count)
copy device data from x to y