10 #ifndef QMCPLUSPLUS_SYCL_ACCELBLAS_SYCL_H 11 #define QMCPLUSPLUS_SYCL_ACCELBLAS_SYCL_H 51 k, alpha,
A,
lda,
B, ldb, beta,
C, ldc);
53 catch (oneapi::mkl::exception&
e)
55 throw std::runtime_error(std::string(
"AccelBLAS::gemm exception: ") +
e.what());
75 oneapi::mkl::blas::gemv(handle.
queue_,
syclBLAS::convertTransEnum(trans),
m,
n, alpha,
A,
lda, x, incx, beta, y,
78 catch (oneapi::mkl::exception&
e)
80 throw std::runtime_error(std::string(
"AccelBLAS::gemv exception: ") +
e.what());
97 const size_t batch_count)
101 syclBLAS::gemv_batched(handle.
queue_, trans,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy, batch_count);
103 catch (sycl::exception&
e)
105 throw std::runtime_error(std::string(
"AccelBLAS::gemv_batch exception: ") +
e.what());
125 catch (oneapi::mkl::exception&
e)
127 throw std::runtime_error(std::string(
"AccelBLAS::ger exception: ") +
e.what());
142 const size_t batch_count)
146 syclBLAS::ger_batched(handle.
queue_,
m,
n, alpha, x, incx, y, incy,
A,
lda, batch_count);
148 catch (sycl::exception&
e)
150 throw std::runtime_error(std::string(
"AccelBLAS::ger_batched exception: ") +
e.what());
161 const size_t batch_count)
166 oneapi::mkl::blas::copy_batch(handle.
queue_, &
n, const_cast<const T**>(in), &incx, const_cast<T**>(out), &incy, 1,
169 catch (oneapi::mkl::exception&
e)
171 throw std::runtime_error(std::string(
"AccelBLAS::copy_batch exception: ") +
e.what());
190 const size_t batch_count)
196 #if defined(GEMM_BATCH_SPAN) 197 sycl::span alpha_span(sycl::malloc_shared<T>(1, handle.
queue_), 1);
198 alpha_span[0] = alpha;
199 sycl::span beta_span(sycl::malloc_shared<T>(1, handle.
queue_), 1);
202 oneapi::mkl::blas::gemm_batch(handle.
queue_, sycl::span{&trans_a, 1}, sycl::span{&trans_b, 1}, sycl::span{&m, 1},
203 sycl::span{&n, 1}, sycl::span{&k, 1}, alpha_span,
204 sycl::span{const_cast<const T**>(A), batch_count}, sycl::span{&lda, 1},
205 sycl::span{const_cast<const T**>(B), batch_count}, sycl::span{&ldb, 1}, beta_span,
206 sycl::span{const_cast<T**>(C), batch_count}, sycl::span{&ldc, 1}, 1,
207 sycl::span{const_cast<size_t*>(&batch_count), 1});
208 sycl::free(alpha_span.data(), handle.
queue_);
209 sycl::free(beta_span.data(), handle.
queue_);
212 oneapi::mkl::blas::gemm_batch(handle.
queue_, &trans_a, &trans_b, &
m, &
n, &k, const_cast<const T*>(&alpha),
213 const_cast<const T**>(
A), &
lda, const_cast<const T**>(
B), &ldb,
214 const_cast<const T*>(&beta), const_cast<T**>(
C), &ldc, 1, &bc);
217 catch (oneapi::mkl::exception&
e)
219 throw std::runtime_error(std::string(
"AccelBLAS::gemm_batched exception: ") +
e.what());
226 #undef castNativeType void gemm(BLASHandle< PlatformKind::CUDA > &handle, const char transa, const char transb, int m, int n, int k, const float &alpha, const float *A, int lda, const float *B, int ldb, const float &beta, float *C, int ldc)
helper functions for EinsplineSetBuilder
sycl::event ger_batched(sycl::queue &handle, const int m, const int n, const T *alpha, const T *const x[], const int incx, const T *const y[], const int incy, T *const A[], const int lda, const size_t batch_count, const std::vector< sycl::event > &events={})
in-house version of ger_batch implemented in SYCL. Can be dropped if we have vendor optimized version...
oneapi::mkl::transpose convertTransEnum(char trans)
Interfaces to blas library.
void ger(BLASHandle< PlatformKind::CUDA > &handle, const int m, const int n, const float &alpha, const float *const x, const int incx, const float *const y, const int incy, float *const A, const int lda)
void gemv_batched(BLASHandle< PlatformKind::CUDA > &handle, const char trans, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const int batch_count)
sycl::event gemv_batched(sycl::queue &handle, const char trans, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const size_t batch_count, const std::vector< sycl::event > &events={})
in-house version of gemv_batch implemented in SYCL. Can be dropped if we have vendor optimized versio...
void gemm_batched(BLASHandle< PlatformKind::CUDA > &handle, const char transa, const char transb, int m, int n, int k, const float &alpha, const float *const A[], int lda, const float *const B[], int ldb, const float &beta, float *const C[], int ldc, int batchCount)
void gemv(BLASHandle< PlatformKind::SYCL > &handle, const char trans, const int m, const int n, const T &alpha, const T *const A, const int lda, const T *const x, const int incx, const T &beta, T *const y, const int incy)
void copy_batched(BLASHandle< PlatformKind::CUDA > &handle, const int n, const T *const in[], const int incx, T *const out[], const int incy, const int batch_count)
void gemm(BLASHandle< PlatformKind::SYCL > &handle, const char transa, const char transb, int m, int n, int k, const T &alpha, const T *A, int lda, const T *B, int ldb, const T &beta, T *C, int ldc)
void ger(BLASHandle< PlatformKind::SYCL > &handle, const int m, const int n, const T &alpha, const T *const x, const int incx, const T *const y, const int incy, T *const A, const int lda)
BLASHandle(Queue< PlatformKind::SYCL > &queue)
void ger_batched(BLASHandle< PlatformKind::CUDA > &handle, const int m, const int n, const T *alpha, const T *const x[], const int incx, const T *const y[], const int incy, T *const A[], const int lda, const int batch_count)
double B(double x, int k, int i, const std::vector< double > &t)
void gemv(BLASHandle< PlatformKind::CUDA > &handle, const char trans, const int m, const int n, const float &alpha, const float *const A, const int lda, const float *const x, const int incx, const float &beta, float *const y, const int incy)
std::int64_t syclBLAS_int