QMCPACK
syclBLAS.hpp
Go to the documentation of this file.
1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2022 QMCPACK developers.
6 //
7 // File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
8 //
9 // File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
10 //////////////////////////////////////////////////////////////////////////////////////
11 
12 
13 #ifndef QMCPLUSPLUS_SYCL_BLAS_H
14 #define QMCPLUSPLUS_SYCL_BLAS_H
15 
16 #include <complex>
17 #include <sycl/sycl.hpp>
18 #include <oneapi/mkl/blas.hpp>
19 
20 namespace qmcplusplus
21 {
22 namespace syclBLAS
23 {
24 using syclBLAS_int = std::int64_t;
25 using syclBLAS_status = sycl::event;
27 
29 {
30  if (trans == 'N' || trans == 'n')
31  return oneapi::mkl::transpose::nontrans;
32  else if (trans == 'T' || trans == 't')
33  return oneapi::mkl::transpose::trans;
34  else if (trans == 'C' || trans == 'c')
35  return oneapi::mkl::transpose::conjtrans;
36  else
37  throw std::runtime_error(
38  "syclBLAS::convertTransEnum trans can only be 'N', 'T', 'C', 'n', 't', 'c'. Input value is " +
39  std::string(1, trans));
40 }
41 
42 template<typename T>
43 sycl::event gemv(sycl::queue& handle,
44  const char trans,
45  const int m,
46  const int n,
47  const T alpha,
48  const T* const A,
49  const int lda,
50  const T* const x,
51  const int incx,
52  const T beta,
53  T* const y,
54  const int incy,
55  const std::vector<sycl::event>& events = {});
56 
57 /// in-house version of gemv_batch implemented in SYCL. Can be dropped if we have vendor optimized versions
58 template<typename T>
59 sycl::event gemv_batched(sycl::queue& handle,
60  const char trans,
61  const int m,
62  const int n,
63  const T* alpha,
64  const T* const A[],
65  const int lda,
66  const T* const x[],
67  const int incx,
68  const T* beta,
69  T* const y[],
70  const int incy,
71  const size_t batch_count,
72  const std::vector<sycl::event>& events = {});
73 
74 template<typename T>
75 sycl::event gemm(sycl::queue& handle,
76  const char tA,
77  const char tB,
78  const int m,
79  const int n,
80  const int k,
81  const T alpha,
82  const T* const A,
83  const int lda,
84  const T* const B,
85  const int ldb,
86  const T beta,
87  T* const C,
88  const int ldc,
89  const std::vector<sycl::event>& events = {});
90 
91 /// in-house version of ger_batch implemented in SYCL. Can be dropped if we have vendor optimized versions
92 template<typename T>
93 sycl::event ger_batched(sycl::queue& handle,
94  const int m,
95  const int n,
96  const T* alpha,
97  const T* const x[],
98  const int incx,
99  const T* const y[],
100  const int incy,
101  T* const A[],
102  const int lda,
103  const size_t batch_count,
104  const std::vector<sycl::event>& events = {});
105 
106 template<typename T1, typename T2>
107 sycl::event transpose(sycl::queue& q,
108  const T1* in,
109  int m,
110  int lda,
111  T2* out,
112  int n,
113  int ldb,
114  const std::vector<sycl::event>& events = {});
115 
116 template<typename T1, typename T2>
117 sycl::event copy_n(sycl::queue& aq,
118  const T1* VA,
119  size_t array_size,
120  T2* VC,
121  const std::vector<sycl::event>& events = {});
122 
123 } // namespace syclBLAS
124 
125 } // namespace qmcplusplus
126 #endif // QMCPLUSPLUS_OMPBLAS_H
helper functions for EinsplineSetBuilder
Definition: Configuration.h:43
sycl::event syclBLAS_status
Definition: syclBLAS.hpp:25
sycl::event ger_batched(sycl::queue &handle, const int m, const int n, const T *alpha, const T *const x[], const int incx, const T *const y[], const int incy, T *const A[], const int lda, const size_t batch_count, const std::vector< sycl::event > &events={})
in-house version of ger_batch implemented in SYCL. Can be dropped once vendor-optimized versions are available.
oneapi::mkl::transpose convertTransEnum(char trans)
Definition: syclBLAS.hpp:28
sycl::queue syclBLAS_handle
Definition: syclBLAS.hpp:26
sycl::event transpose(sycl::queue &q, const T1 *restrict in, int m, int lda, T2 *restrict out, int n, int ldb, const std::vector< sycl::event > &events)
Definition: syclBLAS.cpp:475
sycl::event gemv_batched(sycl::queue &handle, const char trans, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const size_t batch_count, const std::vector< sycl::event > &events={})
in-house version of gemv_batch implemented in SYCL. Can be dropped once vendor-optimized versions are available.
sycl::event gemv(sycl::queue &handle, const char trans, const int m, const int n, const T alpha, const T *const A, const int lda, const T *const x, const int incx, const T beta, T *const y, const int incy, const std::vector< sycl::event > &events)
Definition: syclBLAS.cpp:21
sycl::event gemm(sycl::queue &handle, const char tA, const char tB, const int m, const int n, const int k, const T alpha, const T *A, const int lda, const T *B, const int ldb, const T beta, T *C, const int ldc, const std::vector< sycl::event > &events)
Definition: syclBLAS.cpp:275
sycl::event copy_n(sycl::queue &aq, const T1 *restrict VA, size_t array_size, T2 *restrict VC, const std::vector< sycl::event > &events)
Definition: syclBLAS.cpp:548
double B(double x, int k, int i, const std::vector< double > &t)
sycl::event transpose(sycl::queue &q, const T1 *in, int m, int lda, T2 *out, int n, int ldb, const std::vector< sycl::event > &events={})
std::int64_t syclBLAS_int
Definition: syclBLAS.hpp:24