QMCPACK
AccelBLAS_SYCL.hpp
Go to the documentation of this file.
1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2024 QMCPACK developers.
6 //
7 // File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
8 //////////////////////////////////////////////////////////////////////////////////////
9 
10 #ifndef QMCPLUSPLUS_SYCL_ACCELBLAS_SYCL_H
11 #define QMCPLUSPLUS_SYCL_ACCELBLAS_SYCL_H
12 
13 #include "AccelBLASHandle.hpp"
14 #include "SYCL/QueueSYCL.hpp"
15 #include "SYCL/syclBLAS.hpp"
16 
17 namespace qmcplusplus
18 {
19 namespace compute
20 {
// Explicit specialization of the BLAS handle for the SYCL platform.
// NOTE(review): this listing skips original lines 22 and 27 -- presumably the
// `class BLASHandle<PlatformKind::SYCL>` header and the `queue_` member
// declaration. Confirm against the real header before editing.
21 template<>
23 {
24 public:
 // Initializes queue_ from the native queue exposed by queue.getNative().
25  BLASHandle(Queue<PlatformKind::SYCL>& queue) : queue_(queue.getNative()) {}
26  // sycl queue, not owned, reference-only
28 };
29 
30 namespace BLAS
31 {
// gemm wrapper for the SYCL platform. Per the cross-reference at the end of this
// listing, the signature is
//   void gemm(BLASHandle<PlatformKind::SYCL>& handle, const char transa, ...).
// NOTE(review): original lines 33 (signature start) and 50 (start of the
// forwarded call) are missing from this listing; the call is presumably
// oneapi::mkl::blas::gemm on handle.queue_ -- confirm against the real header.
// Any oneapi::mkl::exception raised inside the try block is rethrown as
// std::runtime_error with the prefix "AccelBLAS::gemm exception: ".
32 template<typename T>
34  const char transa,
35  const char transb,
36  int m,
37  int n,
38  int k,
39  const T& alpha,
40  const T* A,
41  int lda,
42  const T* B,
43  int ldb,
44  const T& beta,
45  T* C,
46  int ldc)
47 {
48  try
49  {
51  k, alpha, A, lda, B, ldb, beta, C, ldc);
52  }
 // NOTE(review): only oneapi::mkl::exception is translated; any other exception
 // type propagates unchanged. Catching by const reference would be more idiomatic.
53  catch (oneapi::mkl::exception& e)
54  {
55  throw std::runtime_error(std::string("AccelBLAS::gemm exception: ") + e.what());
56  }
57 }
58 
// gemv wrapper for the SYCL platform. Forwards every argument to
// oneapi::mkl::blas::gemv on handle.queue_, translating the char `trans` flag
// with syclBLAS::convertTransEnum first.
// NOTE(review): original line 60 (the signature start) is missing from this
// listing; per the cross-reference at the end of the listing it is
//   void gemv(BLASHandle<PlatformKind::SYCL>& handle, ...).
// Any oneapi::mkl::exception is rethrown as std::runtime_error with the prefix
// "AccelBLAS::gemv exception: ".
59 template<typename T>
61  const char trans,
62  const int m,
63  const int n,
64  const T& alpha,
65  const T* const A,
66  const int lda,
67  const T* const x,
68  const int incx,
69  const T& beta,
70  T* const y,
71  const int incy)
72 {
73  try
74  {
75  oneapi::mkl::blas::gemv(handle.queue_, syclBLAS::convertTransEnum(trans), m, n, alpha, A, lda, x, incx, beta, y,
76  incy);
77  }
 // NOTE(review): only oneapi::mkl::exception is translated; other exception
 // types propagate unchanged. Catching by const reference would be more idiomatic.
78  catch (oneapi::mkl::exception& e)
79  {
80  throw std::runtime_error(std::string("AccelBLAS::gemv exception: ") + e.what());
81  }
82 }
83 
// Batched gemv wrapper for the SYCL platform. Forwards to the in-house
// syclBLAS::gemv_batched on handle.queue_; `trans` is passed through as a char,
// with no convertTransEnum call at this level.
// alpha and beta are pointers here, unlike the by-reference scalars of gemv.
// NOTE(review): original line 85 (the signature start) is missing from this listing.
// NOTE(review): the cross-reference at the end of the listing shows
// syclBLAS::gemv_batched returning a sycl::event; that return value is discarded
// here -- confirm that callers do not need it for synchronization.
84 template<typename T>
86  const char trans,
87  const int m,
88  const int n,
89  const T* alpha,
90  const T* const A[],
91  const int lda,
92  const T* const x[],
93  const int incx,
94  const T* beta,
95  T* const y[],
96  const int incy,
97  const size_t batch_count)
98 {
99  try
100  { // calling makeshift version for now due to the lack of vendor optimized versions
101  syclBLAS::gemv_batched(handle.queue_, trans, m, n, alpha, A, lda, x, incx, beta, y, incy, batch_count);
102  }
 // This wrapper catches sycl::exception because the helper is plain SYCL code,
 // not a oneMKL call.
 // NOTE(review): the message label below reads "gemv_batch", while the
 // ger_batched wrapper in this file uses the full name "ger_batched" in its
 // label -- possibly an unintended inconsistency.
103  catch (sycl::exception& e)
104  {
105  throw std::runtime_error(std::string("AccelBLAS::gemv_batch exception: ") + e.what());
106  }
107 }
108 
// ger wrapper for the SYCL platform. Forwards all arguments unchanged to
// oneapi::mkl::blas::ger on handle.queue_.
// NOTE(review): original line 110 (the signature start) is missing from this
// listing; per the cross-reference at the end of the listing it is
//   void ger(BLASHandle<PlatformKind::SYCL>& handle, ...).
// Any oneapi::mkl::exception is rethrown as std::runtime_error with the prefix
// "AccelBLAS::ger exception: ".
109 template<typename T>
111  const int m,
112  const int n,
113  const T& alpha,
114  const T* const x,
115  const int incx,
116  const T* const y,
117  const int incy,
118  T* const A,
119  const int lda)
120 {
121  try
122  {
123  oneapi::mkl::blas::ger(handle.queue_, m, n, alpha, x, incx, y, incy, A, lda);
124  }
 // NOTE(review): only oneapi::mkl::exception is translated; other exception
 // types propagate unchanged. Catching by const reference would be more idiomatic.
125  catch (oneapi::mkl::exception& e)
126  {
127  throw std::runtime_error(std::string("AccelBLAS::ger exception: ") + e.what());
128  }
129 }
130 
// Batched ger wrapper for the SYCL platform. Forwards all arguments unchanged to
// the in-house syclBLAS::ger_batched on handle.queue_.
// alpha is a pointer here, unlike the by-reference scalar of ger.
// NOTE(review): original line 132 (the signature start) is missing from this listing.
// NOTE(review): the cross-reference at the end of the listing shows
// syclBLAS::ger_batched returning a sycl::event; that return value is discarded
// here -- confirm that callers do not need it for synchronization.
131 template<typename T>
133  const int m,
134  const int n,
135  const T* alpha,
136  const T* const x[],
137  const int incx,
138  const T* const y[],
139  const int incy,
140  T* const A[],
141  const int lda,
142  const size_t batch_count)
143 {
144  try
145  { // calling makeshift version for now due to the lack of vendor optimized versions
146  syclBLAS::ger_batched(handle.queue_, m, n, alpha, x, incx, y, incy, A, lda, batch_count);
147  }
 // This wrapper catches sycl::exception because the helper is plain SYCL code,
 // not a oneMKL call.
148  catch (sycl::exception& e)
149  {
150  throw std::runtime_error(std::string("AccelBLAS::ger_batched exception: ") + e.what());
151  }
152 }
153 
// Batched vector copy for the SYCL platform, forwarded to
// oneapi::mkl::blas::copy_batch on handle.queue_.
// NOTE(review): original lines 155, 156, 158 and 160 are missing from this
// listing -- presumably the signature start and the declarations of n, incx and
// incy. Their addresses are passed straight to copy_batch below, so they are
// presumably declared as syclBLAS::syclBLAS_int rather than int -- confirm
// against the real header.
154 template<typename T>
157  const T* const in[],
159  T* const out[],
161  const size_t batch_count)
162 {
163  try
164  {
 // copy_batch takes the batch size by pointer, so batch_count (size_t) is first
 // converted into a local syclBLAS_int.
165  syclBLAS::syclBLAS_int bc = batch_count;
 // The const_casts only drop the const on the pointer arrays themselves.
 // The literal 1 is presumably the group count -- confirm against the
 // oneMKL copy_batch reference.
166  oneapi::mkl::blas::copy_batch(handle.queue_, &n, const_cast<const T**>(in), &incx, const_cast<T**>(out), &incy, 1,
167  &bc);
168  }
 // NOTE(review): the label below reads "copy_batch"; the CUDA counterpart shown
 // in the cross-references is named copy_batched -- confirm the intended label.
169  catch (oneapi::mkl::exception& e)
170  {
171  throw std::runtime_error(std::string("AccelBLAS::copy_batch exception: ") + e.what());
172  }
173 }
174 
// Batched gemm for the SYCL platform, forwarded to oneapi::mkl::blas::gemm_batch
// on handle.queue_. Two call forms are selected at compile time by GEMM_BATCH_SPAN.
// NOTE(review): original lines 176, 179-181, 184, 186 and 189 are missing from
// this listing -- presumably the signature start and the declarations of m, n, k,
// lda, ldb and ldc. Their addresses are passed straight to gemm_batch below, so
// they are presumably syclBLAS::syclBLAS_int rather than int -- confirm against
// the real header.
// Any oneapi::mkl::exception is rethrown as std::runtime_error with the prefix
// "AccelBLAS::gemm_batched exception: ".
175 template<typename T>
177  const char transa,
178  const char transb,
182  const T& alpha,
183  const T* const A[],
185  const T* const B[],
187  const T& beta,
188  T* const C[],
190  const size_t batch_count)
191 {
 // Translate the char flags once; both branches take the address of these locals.
192  auto trans_a = syclBLAS::convertTransEnum(transa);
193  auto trans_b = syclBLAS::convertTransEnum(transb);
194  try
195  {
196 #if defined(GEMM_BATCH_SPAN)
 // Span form: alpha and beta are copied into one-element shared allocations.
 // NOTE(review): if gemm_batch (or the second malloc_shared) throws, the
 // sycl::free calls below are skipped and the allocations leak. The result of
 // malloc_shared is also used without a null check.
197  sycl::span alpha_span(sycl::malloc_shared<T>(1, handle.queue_), 1);
198  alpha_span[0] = alpha;
199  sycl::span beta_span(sycl::malloc_shared<T>(1, handle.queue_), 1);
200  beta_span[0] = beta;
201 
 // The literal 1 before the last span is presumably the group count -- confirm
 // against the oneMKL gemm_batch reference. The const_cast on &batch_count
 // removes the const of the by-value parameter to build a span of size_t.
202  oneapi::mkl::blas::gemm_batch(handle.queue_, sycl::span{&trans_a, 1}, sycl::span{&trans_b, 1}, sycl::span{&m, 1},
203  sycl::span{&n, 1}, sycl::span{&k, 1}, alpha_span,
204  sycl::span{const_cast<const T**>(A), batch_count}, sycl::span{&lda, 1},
205  sycl::span{const_cast<const T**>(B), batch_count}, sycl::span{&ldb, 1}, beta_span,
206  sycl::span{const_cast<T**>(C), batch_count}, sycl::span{&ldc, 1}, 1,
207  sycl::span{const_cast<size_t*>(&batch_count), 1});
 // NOTE(review): the buffers are released immediately after the call with no
 // wait in between; if gemm_batch only enqueues work, this may free memory
 // that is still in use -- confirm.
208  sycl::free(alpha_span.data(), handle.queue_);
209  sycl::free(beta_span.data(), handle.queue_);
210 #else
 // Pointer form: batch_count (size_t) is converted into a local syclBLAS_int,
 // and alpha/beta are passed by the address of the caller's references.
211  syclBLAS::syclBLAS_int bc = batch_count;
212  oneapi::mkl::blas::gemm_batch(handle.queue_, &trans_a, &trans_b, &m, &n, &k, const_cast<const T*>(&alpha),
213  const_cast<const T**>(A), &lda, const_cast<const T**>(B), &ldb,
214  const_cast<const T*>(&beta), const_cast<T**>(C), &ldc, 1, &bc);
215 #endif
216  }
217  catch (oneapi::mkl::exception& e)
218  {
219  throw std::runtime_error(std::string("AccelBLAS::gemm_batched exception: ") + e.what());
220  }
221 }
222 
223 } // namespace BLAS
224 } // namespace compute
225 } // namespace qmcplusplus
// NOTE(review): castNativeType is not defined anywhere in the visible part of
// this header; unless one of the included headers defines it, this #undef has
// no effect and may be a leftover -- confirm before removing.
226 #undef castNativeType
227 #endif
void gemm(BLASHandle< PlatformKind::CUDA > &handle, const char transa, const char transb, int m, int n, int k, const float &alpha, const float *A, int lda, const float *B, int ldb, const float &beta, float *C, int ldc)
helper functions for EinsplineSetBuilder
Definition: Configuration.h:43
sycl::event ger_batched(sycl::queue &handle, const int m, const int n, const T *alpha, const T *const x[], const int incx, const T *const y[], const int incy, T *const A[], const int lda, const size_t batch_count, const std::vector< sycl::event > &events={})
In-house version of ger_batch implemented in SYCL. Can be dropped if we have vendor optimized versions.
oneapi::mkl::transpose convertTransEnum(char trans)
Definition: syclBLAS.hpp:28
Interfaces to blas library.
Definition: BLAS.hpp:38
void ger(BLASHandle< PlatformKind::CUDA > &handle, const int m, const int n, const float &alpha, const float *const x, const int incx, const float *const y, const int incy, float *const A, const int lda)
void gemv_batched(BLASHandle< PlatformKind::CUDA > &handle, const char trans, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const int batch_count)
sycl::event gemv_batched(sycl::queue &handle, const char trans, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const size_t batch_count, const std::vector< sycl::event > &events={})
In-house version of gemv_batch implemented in SYCL. Can be dropped if we have vendor optimized versions.
void gemm_batched(BLASHandle< PlatformKind::CUDA > &handle, const char transa, const char transb, int m, int n, int k, const float &alpha, const float *const A[], int lda, const float *const B[], int ldb, const float &beta, float *const C[], int ldc, int batchCount)
void gemv(BLASHandle< PlatformKind::SYCL > &handle, const char trans, const int m, const int n, const T &alpha, const T *const A, const int lda, const T *const x, const int incx, const T &beta, T *const y, const int incy)
void copy_batched(BLASHandle< PlatformKind::CUDA > &handle, const int n, const T *const in[], const int incx, T *const out[], const int incy, const int batch_count)
void gemm(BLASHandle< PlatformKind::SYCL > &handle, const char transa, const char transb, int m, int n, int k, const T &alpha, const T *A, int lda, const T *B, int ldb, const T &beta, T *C, int ldc)
void ger(BLASHandle< PlatformKind::SYCL > &handle, const int m, const int n, const T &alpha, const T *const x, const int incx, const T *const y, const int incy, T *const A, const int lda)
BLASHandle(Queue< PlatformKind::SYCL > &queue)
void ger_batched(BLASHandle< PlatformKind::CUDA > &handle, const int m, const int n, const T *alpha, const T *const x[], const int incx, const T *const y[], const int incy, T *const A[], const int lda, const int batch_count)
double B(double x, int k, int i, const std::vector< double > &t)
void gemv(BLASHandle< PlatformKind::CUDA > &handle, const char trans, const int m, const int n, const float &alpha, const float *const A, const int lda, const float *const x, const int incx, const float &beta, float *const y, const int incy)
std::int64_t syclBLAS_int
Definition: syclBLAS.hpp:24