QMCPACK
syclBLAS.cpp
Go to the documentation of this file.
1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2022 QMCPACK developers.
6 //
7 // File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
8 //
9 // File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
10 //////////////////////////////////////////////////////////////////////////////////////
11 
12 
13 #include "syclBLAS.hpp"
14 #include "oneapi/mkl/blas.hpp"
15 
16 namespace qmcplusplus
17 {
18 namespace syclBLAS
19 {
20 template<typename T>
21 sycl::event gemv(sycl::queue& handle,
22  const char trans,
23  const int m,
24  const int n,
25  const T alpha,
26  const T* const A,
27  const int lda,
28  const T* const x,
29  const int incx,
30  const T beta,
31  T* const y,
32  const int incy,
33  const std::vector<sycl::event>& events)
34 {
35  return oneapi::mkl::blas::gemv(handle, convertTransEnum(trans), m, n, alpha, A, lda, x, incx, beta, y, incy, events);
36 }
37 
38 template sycl::event gemv(sycl::queue& handle,
39  const char trans,
40  const int m,
41  const int n,
42  const double alpha,
43  const double* const A,
44  const int lda,
45  const double* const x,
46  const int incx,
47  const double beta,
48  double* const y,
49  const int incy,
50  const std::vector<sycl::event>& events);
51 
52 template sycl::event gemv(sycl::queue& handle,
53  const char trans,
54  const int m,
55  const int n,
56  const float alpha,
57  const float* const A,
58  const int lda,
59  const float* const x,
60  const int incx,
61  const float beta,
62  float* const y,
63  const int incy,
64  const std::vector<sycl::event>& events);
65 
66 template sycl::event gemv(sycl::queue& handle,
67  const char trans,
68  const int m,
69  const int n,
70  const std::complex<double> alpha,
71  const std::complex<double>* const A,
72  const int lda,
73  const std::complex<double>* const x,
74  const int incx,
75  const std::complex<double> beta,
76  std::complex<double>* const y,
77  const int incy,
78  const std::vector<sycl::event>& events);
79 
80 template sycl::event gemv(sycl::queue& handle,
81  const char trans,
82  const int m,
83  const int n,
84  const std::complex<float> alpha,
85  const std::complex<float>* const A,
86  const int lda,
87  const std::complex<float>* const x,
88  const int incx,
89  const std::complex<float> beta,
90  std::complex<float>* const y,
91  const int incy,
92  const std::vector<sycl::event>& events);
93 
94 /** gemv trans = 'T' case. COLS refers to columns of the m x n column-major Fortran matrix A.
95  */
96 template<typename T, unsigned COLBS>
97 sycl::event gemvT_batched_impl(sycl::queue& handle,
98  const int m,
99  const int n,
100  const T* alpha,
101  const T* const A[],
102  const int lda,
103  const T* const x[],
104  const int incx,
105  const T* beta,
106  T* const y[],
107  const int incy,
108  const size_t batch_count,
109  const std::vector<sycl::event>& events = {})
110 {
111  if (m == 0 || n == 0 || batch_count == 0)
112  return sycl::event();
113 
114  const int num_col_blocks = (n + COLBS - 1) / COLBS;
115  return handle.parallel_for(sycl::nd_range<2>{{batch_count, num_col_blocks * COLBS}, {1, COLBS}},
116  [=](sycl::nd_item<2> item) {
117  const unsigned batch = item.get_group(0);
118  const int col = item.get_global_id(1);
119  if (col < n)
120  {
121  T sum(0);
122  for (int row = 0; row < m; row++)
123  sum += A[batch][col * lda + row] * x[batch][row * incx];
124  if (beta[batch] == T(0))
125  y[batch][col * incy] = alpha[batch] * sum; // protecting NaN from y_iw
126  else
127  y[batch][col * incy] = alpha[batch] * sum + beta[batch] * y[batch][col * incy];
128  }
129  });
130 }
131 
132 /** gemv trans = 'N' case. ROW refers to rows of the m x n column-major Fortran matrix A.
133  */
134 template<typename T, unsigned ROWBS>
135 sycl::event gemvN_batched_impl(sycl::queue& handle,
136  const int m,
137  const int n,
138  const T* alpha,
139  const T* const A[],
140  const int lda,
141  const T* const x[],
142  const int incx,
143  const T* beta,
144  T* const y[],
145  const int incy,
146  const size_t batch_count,
147  const std::vector<sycl::event>& events = {})
148 {
149  if (m == 0 || n == 0 || batch_count == 0)
150  return sycl::event();
151 
152  const int num_row_blocks = (m + ROWBS - 1) / ROWBS;
153  return handle.parallel_for(sycl::nd_range<2>{{batch_count, num_row_blocks * ROWBS}, {1, ROWBS}},
154  [=](sycl::nd_item<2> item) {
155  const unsigned batch = item.get_group(0);
156  const int row = item.get_global_id(1);
157  if (row < m)
158  {
159  T sum(0);
160  for (int col = 0; col < n; col++)
161  sum += A[batch][col * lda + row] * x[batch][col * incx];
162  if (beta[batch] == T(0))
163  y[batch][row * incy] = alpha[batch] * sum; // protecting NaN from y_iw
164  else
165  y[batch][row * incy] = alpha[batch] * sum + beta[batch] * y[batch][row * incy];
166  }
167  });
168 }
169 
170 template<>
171 sycl::event gemv_batched<float>(sycl::queue& handle,
172  const char trans,
173  const int m,
174  const int n,
175  const float* alpha,
176  const float* const A[],
177  const int lda,
178  const float* const x[],
179  const int incx,
180  const float* beta,
181  float* const y[],
182  const int incy,
183  const size_t batch_count,
184  const std::vector<sycl::event>& events)
185 {
186  if (trans == 'N' || trans == 'n')
187  return gemvN_batched_impl<float, 64>(handle, m, n, alpha, A, lda, x, incx, beta, y, incy, batch_count);
188  else if (trans == 'T' || trans == 't')
189  return gemvT_batched_impl<float, 64>(handle, m, n, alpha, A, lda, x, incx, beta, y, incy, batch_count);
190  else
191  throw std::runtime_error("syclBLAS::gemv_batched only supports 'N', 'T', 'C', 'n'. Input value is " +
192  std::string(1, trans));
193 }
194 
195 template<>
196 sycl::event gemv_batched<double>(sycl::queue& handle,
197  const char trans,
198  const int m,
199  const int n,
200  const double* alpha,
201  const double* const A[],
202  const int lda,
203  const double* const x[],
204  const int incx,
205  const double* beta,
206  double* const y[],
207  const int incy,
208  const size_t batch_count,
209  const std::vector<sycl::event>& events)
210 {
211  if (trans == 'N' || trans == 'n')
212  return gemvN_batched_impl<double, 64>(handle, m, n, alpha, A, lda, x, incx, beta, y, incy, batch_count);
213  else if (trans == 'T' || trans == 't')
214  return gemvT_batched_impl<double, 64>(handle, m, n, alpha, A, lda, x, incx, beta, y, incy, batch_count);
215  else
216  throw std::runtime_error("syclBLAS::gemv_batched only supports 'N', 'T', 'C', 'n'. Input value is " +
217  std::string(1, trans));
218 }
219 
220 template<>
221 sycl::event gemv_batched<std::complex<float>>(sycl::queue& handle,
222  const char trans,
223  const int m,
224  const int n,
225  const std::complex<float>* alpha,
226  const std::complex<float>* const A[],
227  const int lda,
228  const std::complex<float>* const x[],
229  const int incx,
230  const std::complex<float>* beta,
231  std::complex<float>* const y[],
232  const int incy,
233  const size_t batch_count,
234  const std::vector<sycl::event>& events)
235 {
236  if (trans == 'N' || trans == 'n')
237  return gemvN_batched_impl<std::complex<float>, 64>(handle, m, n, alpha, A, lda, x, incx, beta, y, incy,
238  batch_count);
239  else if (trans == 'T' || trans == 't')
240  return gemvT_batched_impl<std::complex<float>, 64>(handle, m, n, alpha, A, lda, x, incx, beta, y, incy,
241  batch_count);
242  else
243  throw std::runtime_error("syclBLAS::gemv_batched only supports 'N', 'T', 'C', 'n'. Input value is " +
244  std::string(1, trans));
245 }
246 
247 template<>
248 sycl::event gemv_batched<std::complex<double>>(sycl::queue& handle,
249  const char trans,
250  const int m,
251  const int n,
252  const std::complex<double>* alpha,
253  const std::complex<double>* const A[],
254  const int lda,
255  const std::complex<double>* const x[],
256  const int incx,
257  const std::complex<double>* beta,
258  std::complex<double>* const y[],
259  const int incy,
260  const size_t batch_count,
261  const std::vector<sycl::event>& events)
262 {
263  if (trans == 'N' || trans == 'n')
264  return gemvN_batched_impl<std::complex<double>, 64>(handle, m, n, alpha, A, lda, x, incx, beta, y, incy,
265  batch_count);
266  else if (trans == 'T' || trans == 't')
267  return gemvT_batched_impl<std::complex<double>, 64>(handle, m, n, alpha, A, lda, x, incx, beta, y, incy,
268  batch_count);
269  else
270  throw std::runtime_error("syclBLAS::gemv_batched only supports 'N', 'T', 'C', 'n'. Input value is " +
271  std::string(1, trans));
272 }
273 
274 template<typename T>
275 sycl::event gemm(sycl::queue& handle,
276  const char tA,
277  const char tB,
278  const int m,
279  const int n,
280  const int k,
281  const T alpha,
282  const T* A,
283  const int lda,
284  const T* B,
285  const int ldb,
286  const T beta,
287  T* C,
288  const int ldc,
289  const std::vector<sycl::event>& events)
290 {
291  return oneapi::mkl::blas::gemm(handle, convertTransEnum(tA), convertTransEnum(tB), m, n, k, alpha, A, lda, B, ldb,
292  beta, C, ldc, events);
293 }
294 
295 
296 template sycl::event gemm(sycl::queue& handle,
297  const char tA,
298  const char tB,
299  const int m,
300  const int n,
301  const int k,
302  const float alpha,
303  const float* const A,
304  const int lda,
305  const float* const B,
306  const int ldb,
307  const float beta,
308  float* const C,
309  const int ldc,
310  const std::vector<sycl::event>& events);
311 
312 template sycl::event gemm(sycl::queue& handle,
313  const char tA,
314  const char tB,
315  const int m,
316  const int n,
317  const int k,
318  const double alpha,
319  const double* const A,
320  const int lda,
321  const double* const B,
322  const int ldb,
323  const double beta,
324  double* const C,
325  const int ldc,
326  const std::vector<sycl::event>& events);
327 
328 template sycl::event gemm(sycl::queue& handle,
329  const char tA,
330  const char tB,
331  const int m,
332  const int n,
333  const int k,
334  const std::complex<float> alpha,
335  const std::complex<float>* const A,
336  const int lda,
337  const std::complex<float>* const B,
338  const int ldb,
339  const std::complex<float> beta,
340  std::complex<float>* const C,
341  const int ldc,
342  const std::vector<sycl::event>& events);
343 
344 template sycl::event gemm(sycl::queue& handle,
345  const char tA,
346  const char tB,
347  const int m,
348  const int n,
349  const int k,
350  const std::complex<double> alpha,
351  const std::complex<double>* const A,
352  const int lda,
353  const std::complex<double>* const B,
354  const int ldb,
355  const std::complex<double> beta,
356  std::complex<double>* const C,
357  const int ldc,
358  const std::vector<sycl::event>& events);
359 
360 template<typename T, int TILE_SIZE, int ROWBS>
361 sycl::event ger_batched_impl(sycl::queue& handle,
362  const int m,
363  const int n,
364  const T* alpha,
365  const T* const x[],
366  const int incx,
367  const T* const y[],
368  const int incy,
369  T* const A[],
370  const int lda,
371  const size_t batch_count,
372  const std::vector<sycl::event>& events)
373 {
374  static_assert(ROWBS <= TILE_SIZE, "ROWBS cannot be larger than TILE_SIZE!");
375  if (m == 0 || n == 0 || batch_count == 0)
376  return sycl::event();
377 
378  // A is m x n in Fortran, n x m in C.
379  constexpr size_t tile_size = TILE_SIZE;
380  constexpr size_t block_rows = ROWBS;
381  // the computation is tiled and distributed.
382  const size_t row_tiles = (n + tile_size - 1) / tile_size;
383  const size_t col_tiles = (m + tile_size - 1) / tile_size;
384 
385  return handle.parallel_for(sycl::nd_range<3>{{batch_count, row_tiles * block_rows, col_tiles * tile_size},
386  {1, block_rows, tile_size}},
387  [=](sycl::nd_item<3> item) {
388  const unsigned batch = item.get_group(0);
389  const unsigned thX = item.get_local_id(2);
390  const unsigned thY = item.get_local_id(1);
391  const unsigned column = item.get_group(2) * tile_size + thX;
392  const unsigned row_offset = item.get_group(1) * tile_size + thY;
393  if (column < m)
394  {
395  const T alphaX = alpha[batch] * x[batch][column * incx];
396  for (unsigned j = 0; j < tile_size; j += block_rows)
397  if (const unsigned row = row_offset + j; row < n)
398  A[batch][row * lda + column] += alphaX * y[batch][row * incy];
399  }
400  });
401 }
402 
403 template<>
404 sycl::event ger_batched<float>(sycl::queue& handle,
405  const int m,
406  const int n,
407  const float* alpha,
408  const float* const x[],
409  const int incx,
410  const float* const y[],
411  const int incy,
412  float* const A[],
413  const int lda,
414  const size_t batch_count,
415  const std::vector<sycl::event>& events)
416 {
417  return ger_batched_impl<float, 32, 8>(handle, m, n, alpha, x, incx, y, incy, A, lda, batch_count, events);
418 }
419 
420 template<>
421 sycl::event ger_batched<double>(sycl::queue& handle,
422  const int m,
423  const int n,
424  const double* alpha,
425  const double* const x[],
426  const int incx,
427  const double* const y[],
428  const int incy,
429  double* const A[],
430  const int lda,
431  const size_t batch_count,
432  const std::vector<sycl::event>& events)
433 {
434  return ger_batched_impl<double, 32, 8>(handle, m, n, alpha, x, incx, y, incy, A, lda, batch_count, events);
435 }
436 
437 template<>
438 sycl::event ger_batched<std::complex<float>>(sycl::queue& handle,
439  const int m,
440  const int n,
441  const std::complex<float>* alpha,
442  const std::complex<float>* const x[],
443  const int incx,
444  const std::complex<float>* const y[],
445  const int incy,
446  std::complex<float>* const A[],
447  const int lda,
448  const size_t batch_count,
449  const std::vector<sycl::event>& events)
450 {
451  return ger_batched_impl<std::complex<float>, 32, 8>(handle, m, n, alpha, x, incx, y, incy, A, lda, batch_count,
452  events);
453 }
454 
455 template<>
456 sycl::event ger_batched<std::complex<double>>(sycl::queue& handle,
457  const int m,
458  const int n,
459  const std::complex<double>* alpha,
460  const std::complex<double>* const x[],
461  const int incx,
462  const std::complex<double>* const y[],
463  const int incy,
464  std::complex<double>* const A[],
465  const int lda,
466  const size_t batch_count,
467  const std::vector<sycl::event>& events)
468 {
469  return ger_batched_impl<std::complex<double>, 32, 8>(handle, m, n, alpha, x, incx, y, incy, A, lda, batch_count,
470  events);
471 }
472 
473 //transpose
474 template<typename T1, typename T2>
475 sycl::event transpose(sycl::queue& q,
476  const T1* restrict in,
477  int m,
478  int lda,
479  T2* restrict out,
480  int n,
481  int ldb,
482  const std::vector<sycl::event>& events)
483 {
484  constexpr size_t tile_size = 16;
485  const size_t m_max = ((m + tile_size - 1) / tile_size) * tile_size;
486  const size_t n_max = ((n + tile_size - 1) / tile_size) * tile_size;
487 
488  return q.submit([&](sycl::handler& cgh) {
489  cgh.depends_on(events);
490  sycl::local_accessor<T2, 2> tile(sycl::range<2>(tile_size, tile_size + 1), cgh);
491 
492  cgh.parallel_for(sycl::nd_range<2>{{m_max, n_max}, {tile_size, tile_size}}, [=](sycl::nd_item<2> item) {
493  unsigned x = item.get_global_id(1);
494  unsigned y = item.get_global_id(0);
495  unsigned xth = item.get_local_id(1);
496  unsigned yth = item.get_local_id(0);
497 
498  if (x < n && y < m)
499  tile[yth][xth] = in[(y)*lda + x];
500  item.barrier(sycl::access::fence_space::local_space);
501 
502  x = item.get_group(0) * tile_size + xth;
503  y = item.get_group(1) * tile_size + yth;
504  if (x < m && y < n)
505  out[(y)*ldb + x] = tile[xth][yth];
506  });
507  });
508 }
509 
510 template sycl::event transpose(sycl::queue& q,
511  const float* restrict in,
512  int m,
513  int lda,
514  double* restrict out,
515  int n,
516  int ldb,
517  const std::vector<sycl::event>& events);
518 
519 template sycl::event transpose(sycl::queue& q,
520  const double* restrict in,
521  int m,
522  int lda,
523  double* restrict out,
524  int n,
525  int ldb,
526  const std::vector<sycl::event>& events);
527 
528 template sycl::event transpose(sycl::queue& q,
529  const std::complex<float>* restrict in,
530  int m,
531  int lda,
532  std::complex<double>* restrict out,
533  int n,
534  int ldb,
535  const std::vector<sycl::event>& events);
536 
537 template sycl::event transpose(sycl::queue& q,
538  const std::complex<double>* restrict in,
539  int m,
540  int lda,
541  std::complex<double>* restrict out,
542  int n,
543  int ldb,
544  const std::vector<sycl::event>& events);
545 
546 //copy_n for mixed precision
547 template<typename T1, typename T2>
548 sycl::event copy_n(sycl::queue& aq,
549  const T1* restrict VA,
550  size_t array_size,
551  T2* restrict VC,
552  const std::vector<sycl::event>& events)
553 {
554  if (array_size == 0)
555  return sycl::event();
556  constexpr size_t tile_size = 64;
557  const size_t a_max = ((array_size + tile_size - 1) / tile_size) * tile_size;
558  return aq.parallel_for(sycl::range<1>{a_max}, events, [=](sycl::id<1> id) {
559  if (id < array_size)
560  VC[id] = static_cast<T2>(VA[id]);
561  });
562 }
563 
564 template sycl::event copy_n(sycl::queue& aq,
565  const double* restrict VA,
566  size_t array_size,
567  float* restrict VC,
568  const std::vector<sycl::event>& events);
569 
570 template sycl::event copy_n(sycl::queue& aq,
571  const std::complex<double>* restrict VA,
572  size_t array_size,
573  std::complex<float>* restrict VC,
574  const std::vector<sycl::event>& events);
575 
576 } // namespace syclBLAS
577 
578 } // namespace qmcplusplus
sycl::event ger_batched< float >(sycl::queue &handle, const int m, const int n, const float *alpha, const float *const x[], const int incx, const float *const y[], const int incy, float *const A[], const int lda, const size_t batch_count, const std::vector< sycl::event > &events)
Definition: syclBLAS.cpp:404
helper functions for EinsplineSetBuilder
Definition: Configuration.h:43
sycl::event gemvT_batched_impl(sycl::queue &handle, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const size_t batch_count, const std::vector< sycl::event > &events={})
gemv trans = 'T' case.
Definition: syclBLAS.cpp:97
sycl::event gemv_batched< float >(sycl::queue &handle, const char trans, const int m, const int n, const float *alpha, const float *const A[], const int lda, const float *const x[], const int incx, const float *beta, float *const y[], const int incy, const size_t batch_count, const std::vector< sycl::event > &events)
Definition: syclBLAS.cpp:171
oneapi::mkl::transpose convertTransEnum(char trans)
Definition: syclBLAS.hpp:28
sycl::event transpose(sycl::queue &q, const T1 *restrict in, int m, int lda, T2 *restrict out, int n, int ldb, const std::vector< sycl::event > &events)
Definition: syclBLAS.cpp:475
sycl::event ger_batched< double >(sycl::queue &handle, const int m, const int n, const double *alpha, const double *const x[], const int incx, const double *const y[], const int incy, double *const A[], const int lda, const size_t batch_count, const std::vector< sycl::event > &events)
Definition: syclBLAS.cpp:421
template sycl::event gemm(sycl::queue &handle, const char tA, const char tB, const int m, const int n, const int k, const std::complex< double > alpha, const std::complex< double > *const A, const int lda, const std::complex< double > *const B, const int ldb, const std::complex< double > beta, std::complex< double > *const C, const int ldc, const std::vector< sycl::event > &events)
sycl::event ger_batched_impl(sycl::queue &handle, const int m, const int n, const T *alpha, const T *const x[], const int incx, const T *const y[], const int incy, T *const A[], const int lda, const size_t batch_count, const std::vector< sycl::event > &events)
Definition: syclBLAS.cpp:361
sycl::event gemv(sycl::queue &handle, const char trans, const int m, const int n, const T alpha, const T *const A, const int lda, const T *const x, const int incx, const T beta, T *const y, const int incy, const std::vector< sycl::event > &events)
Definition: syclBLAS.cpp:21
sycl::event gemm(sycl::queue &handle, const char tA, const char tB, const int m, const int n, const int k, const T alpha, const T *A, const int lda, const T *B, const int ldb, const T beta, T *C, const int ldc, const std::vector< sycl::event > &events)
Definition: syclBLAS.cpp:275
sycl::event copy_n(sycl::queue &aq, const T1 *restrict VA, size_t array_size, T2 *restrict VC, const std::vector< sycl::event > &events)
Definition: syclBLAS.cpp:548
sycl::event gemvN_batched_impl(sycl::queue &handle, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const size_t batch_count, const std::vector< sycl::event > &events={})
gemv trans = 'N' case.
Definition: syclBLAS.cpp:135
sycl::event gemv_batched< double >(sycl::queue &handle, const char trans, const int m, const int n, const double *alpha, const double *const A[], const int lda, const double *const x[], const int incx, const double *beta, double *const y[], const int incy, const size_t batch_count, const std::vector< sycl::event > &events)
Definition: syclBLAS.cpp:196
template sycl::event gemv(sycl::queue &handle, const char trans, const int m, const int n, const std::complex< float > alpha, const std::complex< float > *const A, const int lda, const std::complex< float > *const x, const int incx, const std::complex< float > beta, std::complex< float > *const y, const int incy, const std::vector< sycl::event > &events)
double B(double x, int k, int i, const std::vector< double > &t)