14 #include "oneapi/mkl/blas.hpp" 33 const std::vector<sycl::event>& events)
35 return oneapi::mkl::blas::gemv(handle,
convertTransEnum(trans),
m,
n, alpha,
A,
lda, x, incx, beta, y, incy, events);
43 const double*
const A,
45 const double*
const x,
50 const std::vector<sycl::event>& events);
64 const std::vector<sycl::event>& events);
70 const std::complex<double> alpha,
71 const std::complex<double>*
const A,
73 const std::complex<double>*
const x,
75 const std::complex<double> beta,
76 std::complex<double>*
const y,
78 const std::vector<sycl::event>& events);
84 const std::complex<float> alpha,
85 const std::complex<float>*
const A,
87 const std::complex<float>*
const x,
89 const std::complex<float> beta,
90 std::complex<float>*
const y,
92 const std::vector<sycl::event>& events);
96 template<
typename T,
unsigned COLBS>
108 const size_t batch_count,
109 const std::vector<sycl::event>& events = {})
111 if (
m == 0 ||
n == 0 || batch_count == 0)
112 return sycl::event();
114 const int num_col_blocks = (
n + COLBS - 1) / COLBS;
115 return handle.parallel_for(sycl::nd_range<2>{{batch_count, num_col_blocks * COLBS}, {1, COLBS}},
116 [=](sycl::nd_item<2> item) {
117 const unsigned batch = item.get_group(0);
118 const int col = item.get_global_id(1);
122 for (
int row = 0; row <
m; row++)
123 sum +=
A[batch][col *
lda + row] * x[batch][row * incx];
124 if (beta[batch] == T(0))
125 y[batch][col * incy] = alpha[batch] * sum;
127 y[batch][col * incy] = alpha[batch] * sum + beta[batch] * y[batch][col * incy];
134 template<
typename T,
unsigned ROWBS>
146 const size_t batch_count,
147 const std::vector<sycl::event>& events = {})
149 if (
m == 0 ||
n == 0 || batch_count == 0)
150 return sycl::event();
152 const int num_row_blocks = (
m + ROWBS - 1) / ROWBS;
153 return handle.parallel_for(sycl::nd_range<2>{{batch_count, num_row_blocks * ROWBS}, {1, ROWBS}},
154 [=](sycl::nd_item<2> item) {
155 const unsigned batch = item.get_group(0);
156 const int row = item.get_global_id(1);
160 for (
int col = 0; col <
n; col++)
161 sum +=
A[batch][col *
lda + row] * x[batch][col * incx];
162 if (beta[batch] == T(0))
163 y[batch][row * incy] = alpha[batch] * sum;
165 y[batch][row * incy] = alpha[batch] * sum + beta[batch] * y[batch][row * incy];
176 const float*
const A[],
178 const float*
const x[],
183 const size_t batch_count,
184 const std::vector<sycl::event>& events)
186 if (trans ==
'N' || trans ==
'n')
187 return gemvN_batched_impl<float, 64>(handle,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy, batch_count);
188 else if (trans ==
'T' || trans ==
't')
189 return gemvT_batched_impl<float, 64>(handle,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy, batch_count);
191 throw std::runtime_error(
"syclBLAS::gemv_batched only supports 'N', 'T', 'C', 'n'. Input value is " +
192 std::string(1, trans));
201 const double*
const A[],
203 const double*
const x[],
208 const size_t batch_count,
209 const std::vector<sycl::event>& events)
211 if (trans ==
'N' || trans ==
'n')
212 return gemvN_batched_impl<double, 64>(handle,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy, batch_count);
213 else if (trans ==
'T' || trans ==
't')
214 return gemvT_batched_impl<double, 64>(handle,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy, batch_count);
216 throw std::runtime_error(
"syclBLAS::gemv_batched only supports 'N', 'T', 'C', 'n'. Input value is " +
217 std::string(1, trans));
221 sycl::event gemv_batched<std::complex<float>>(
sycl::queue& handle,
225 const std::complex<float>* alpha,
226 const std::complex<float>*
const A[],
228 const std::complex<float>*
const x[],
230 const std::complex<float>* beta,
231 std::complex<float>*
const y[],
233 const size_t batch_count,
234 const std::vector<sycl::event>& events)
236 if (trans ==
'N' || trans ==
'n')
237 return gemvN_batched_impl<std::complex<float>, 64>(handle,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy,
239 else if (trans ==
'T' || trans ==
't')
240 return gemvT_batched_impl<std::complex<float>, 64>(handle,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy,
243 throw std::runtime_error(
"syclBLAS::gemv_batched only supports 'N', 'T', 'C', 'n'. Input value is " +
244 std::string(1, trans));
248 sycl::event gemv_batched<std::complex<double>>(
sycl::queue& handle,
252 const std::complex<double>* alpha,
253 const std::complex<double>*
const A[],
255 const std::complex<double>*
const x[],
257 const std::complex<double>* beta,
258 std::complex<double>*
const y[],
260 const size_t batch_count,
261 const std::vector<sycl::event>& events)
263 if (trans ==
'N' || trans ==
'n')
264 return gemvN_batched_impl<std::complex<double>, 64>(handle,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy,
266 else if (trans ==
'T' || trans ==
't')
267 return gemvT_batched_impl<std::complex<double>, 64>(handle,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy,
270 throw std::runtime_error(
"syclBLAS::gemv_batched only supports 'N', 'T', 'C', 'n'. Input value is " +
271 std::string(1, trans));
289 const std::vector<sycl::event>& events)
291 return oneapi::mkl::blas::gemm(handle,
convertTransEnum(tA),
convertTransEnum(tB),
m,
n, k, alpha,
A,
lda,
B, ldb,
292 beta,
C, ldc, events);
303 const float*
const A,
305 const float*
const B,
310 const std::vector<sycl::event>& events);
319 const double*
const A,
321 const double*
const B,
326 const std::vector<sycl::event>& events);
334 const std::complex<float> alpha,
335 const std::complex<float>*
const A,
337 const std::complex<float>*
const B,
339 const std::complex<float> beta,
340 std::complex<float>*
const C,
342 const std::vector<sycl::event>& events);
350 const std::complex<double> alpha,
351 const std::complex<double>*
const A,
353 const std::complex<double>*
const B,
355 const std::complex<double> beta,
356 std::complex<double>*
const C,
358 const std::vector<sycl::event>& events);
360 template<
typename T,
int TILE_SIZE,
int ROWBS>
371 const size_t batch_count,
372 const std::vector<sycl::event>& events)
374 static_assert(ROWBS <= TILE_SIZE,
"ROWBS cannot be larger than TILE_SIZE!");
375 if (
m == 0 ||
n == 0 || batch_count == 0)
376 return sycl::event();
379 constexpr
size_t tile_size = TILE_SIZE;
380 constexpr
size_t block_rows = ROWBS;
382 const size_t row_tiles = (
n + tile_size - 1) / tile_size;
383 const size_t col_tiles = (
m + tile_size - 1) / tile_size;
385 return handle.parallel_for(sycl::nd_range<3>{{batch_count, row_tiles * block_rows, col_tiles * tile_size},
386 {1, block_rows, tile_size}},
387 [=](sycl::nd_item<3> item) {
388 const unsigned batch = item.get_group(0);
389 const unsigned thX = item.get_local_id(2);
390 const unsigned thY = item.get_local_id(1);
391 const unsigned column = item.get_group(2) * tile_size + thX;
392 const unsigned row_offset = item.get_group(1) * tile_size + thY;
395 const T alphaX = alpha[batch] * x[batch][column * incx];
396 for (
unsigned j = 0; j < tile_size; j += block_rows)
397 if (
const unsigned row = row_offset + j; row <
n)
398 A[batch][row *
lda + column] += alphaX * y[batch][row * incy];
408 const float*
const x[],
410 const float*
const y[],
414 const size_t batch_count,
415 const std::vector<sycl::event>& events)
417 return ger_batched_impl<float, 32, 8>(handle,
m,
n, alpha, x, incx, y, incy,
A,
lda, batch_count, events);
425 const double*
const x[],
427 const double*
const y[],
431 const size_t batch_count,
432 const std::vector<sycl::event>& events)
434 return ger_batched_impl<double, 32, 8>(handle,
m,
n, alpha, x, incx, y, incy,
A,
lda, batch_count, events);
441 const std::complex<float>* alpha,
442 const std::complex<float>*
const x[],
444 const std::complex<float>*
const y[],
446 std::complex<float>*
const A[],
448 const size_t batch_count,
449 const std::vector<sycl::event>& events)
451 return ger_batched_impl<std::complex<float>, 32, 8>(handle,
m,
n, alpha, x, incx, y, incy,
A,
lda, batch_count,
456 sycl::event ger_batched<std::complex<double>>(
sycl::queue& handle,
459 const std::complex<double>* alpha,
460 const std::complex<double>*
const x[],
462 const std::complex<double>*
const y[],
464 std::complex<double>*
const A[],
466 const size_t batch_count,
467 const std::vector<sycl::event>& events)
469 return ger_batched_impl<std::complex<double>, 32, 8>(handle,
m,
n, alpha, x, incx, y, incy,
A,
lda, batch_count,
474 template<
typename T1,
typename T2>
476 const T1* restrict in,
482 const std::vector<sycl::event>& events)
484 constexpr
size_t tile_size = 16;
485 const size_t m_max = ((
m + tile_size - 1) / tile_size) * tile_size;
486 const size_t n_max = ((
n + tile_size - 1) / tile_size) * tile_size;
488 return q.submit([&](sycl::handler& cgh) {
489 cgh.depends_on(events);
490 sycl::local_accessor<T2, 2> tile(sycl::range<2>(tile_size, tile_size + 1), cgh);
492 cgh.parallel_for(sycl::nd_range<2>{{m_max, n_max}, {tile_size, tile_size}}, [=](sycl::nd_item<2> item) {
493 unsigned x = item.get_global_id(1);
494 unsigned y = item.get_global_id(0);
495 unsigned xth = item.get_local_id(1);
496 unsigned yth = item.get_local_id(0);
499 tile[yth][xth] = in[(y)*
lda + x];
500 item.barrier(sycl::access::fence_space::local_space);
502 x = item.get_group(0) * tile_size + xth;
503 y = item.get_group(1) * tile_size + yth;
505 out[(y)*ldb + x] = tile[xth][yth];
511 const float* restrict in,
514 double* restrict out,
517 const std::vector<sycl::event>& events);
520 const double* restrict in,
523 double* restrict out,
526 const std::vector<sycl::event>& events);
529 const std::complex<float>* restrict in,
532 std::complex<double>* restrict out,
535 const std::vector<sycl::event>& events);
538 const std::complex<double>* restrict in,
541 std::complex<double>* restrict out,
544 const std::vector<sycl::event>& events);
547 template<
typename T1,
typename T2>
549 const T1* restrict VA,
552 const std::vector<sycl::event>& events)
555 return sycl::event();
556 constexpr
size_t tile_size = 64;
557 const size_t a_max = ((array_size + tile_size - 1) / tile_size) * tile_size;
558 return aq.parallel_for(sycl::range<1>{a_max}, events, [=](sycl::id<1> id) {
560 VC[id] =
static_cast<T2
>(VA[id]);
565 const double* restrict VA,
568 const std::vector<sycl::event>& events);
571 const std::complex<double>* restrict VA,
573 std::complex<float>* restrict VC,
574 const std::vector<sycl::event>& events);
sycl::event ger_batched< float >(sycl::queue &handle, const int m, const int n, const float *alpha, const float *const x[], const int incx, const float *const y[], const int incy, float *const A[], const int lda, const size_t batch_count, const std::vector< sycl::event > &events)
helper functions for EinsplineSetBuilder
sycl::event gemvT_batched_impl(sycl::queue &handle, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const size_t batch_count, const std::vector< sycl::event > &events={})
gemv trans = 'T' case.
sycl::event gemv_batched< float >(sycl::queue &handle, const char trans, const int m, const int n, const float *alpha, const float *const A[], const int lda, const float *const x[], const int incx, const float *beta, float *const y[], const int incy, const size_t batch_count, const std::vector< sycl::event > &events)
oneapi::mkl::transpose convertTransEnum(char trans)
sycl::event transpose(sycl::queue &q, const T1 *restrict in, int m, int lda, T2 *restrict out, int n, int ldb, const std::vector< sycl::event > &events)
sycl::event ger_batched< double >(sycl::queue &handle, const int m, const int n, const double *alpha, const double *const x[], const int incx, const double *const y[], const int incy, double *const A[], const int lda, const size_t batch_count, const std::vector< sycl::event > &events)
template sycl::event gemm(sycl::queue &handle, const char tA, const char tB, const int m, const int n, const int k, const std::complex< double > alpha, const std::complex< double > *const A, const int lda, const std::complex< double > *const B, const int ldb, const std::complex< double > beta, std::complex< double > *const C, const int ldc, const std::vector< sycl::event > &events)
sycl::event ger_batched_impl(sycl::queue &handle, const int m, const int n, const T *alpha, const T *const x[], const int incx, const T *const y[], const int incy, T *const A[], const int lda, const size_t batch_count, const std::vector< sycl::event > &events)
sycl::event gemv(sycl::queue &handle, const char trans, const int m, const int n, const T alpha, const T *const A, const int lda, const T *const x, const int incx, const T beta, T *const y, const int incy, const std::vector< sycl::event > &events)
sycl::event gemm(sycl::queue &handle, const char tA, const char tB, const int m, const int n, const int k, const T alpha, const T *A, const int lda, const T *B, const int ldb, const T beta, T *C, const int ldc, const std::vector< sycl::event > &events)
sycl::event copy_n(sycl::queue &aq, const T1 *restrict VA, size_t array_size, T2 *restrict VC, const std::vector< sycl::event > &events)
sycl::event gemvN_batched_impl(sycl::queue &handle, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const size_t batch_count, const std::vector< sycl::event > &events={})
gemv trans = 'N' case.
sycl::event gemv_batched< double >(sycl::queue &handle, const char trans, const int m, const int n, const double *alpha, const double *const A[], const int lda, const double *const x[], const int incx, const double *beta, double *const y[], const int incy, const size_t batch_count, const std::vector< sycl::event > &events)
template sycl::event gemv(sycl::queue &handle, const char trans, const int m, const int n, const std::complex< float > alpha, const std::complex< float > *const A, const int lda, const std::complex< float > *const x, const int incx, const std::complex< float > beta, std::complex< float > *const y, const int incy, const std::vector< sycl::event > &events)
double B(double x, int k, int i, const std::vector< double > &t)