17 #if !defined(OPENMP_NO_COMPLEX) 42 if (M == 0 ||
N == 0 ||
K == 0)
45 if (transa ==
'T' && transb ==
'N')
47 PRAGMA_OFFLOAD(
"omp target teams distribute parallel for collapse(2) is_device_ptr(A, B, C)")
48 for (
size_t m = 0;
m < M;
m++)
49 for (
size_t n = 0;
n <
N;
n++)
52 for (
size_t k = 0; k <
K; k++)
53 sum +=
A[
lda *
m + k] *
B[ldb *
n + k];
54 C[
n * ldc +
m] = alpha * sum + (beta == T(0) ? T(0) :
C[
n * ldc +
m] * beta);
57 else if (transa ==
'T' && transb ==
'T')
59 PRAGMA_OFFLOAD(
"omp target teams distribute parallel for collapse(2) is_device_ptr(A, B, C)")
60 for (
size_t m = 0;
m < M;
m++)
61 for (
size_t n = 0;
n <
N;
n++)
64 for (
size_t k = 0; k <
K; k++)
65 sum +=
A[
lda *
m + k] *
B[ldb * k +
n];
66 C[
n * ldc +
m] = alpha * sum + (beta == T(0) ? T(0) :
C[
n * ldc +
m] * beta);
69 else if (transa ==
'N' && transb ==
'T')
71 PRAGMA_OFFLOAD(
"omp target teams distribute parallel for collapse(2) is_device_ptr(A, B, C)")
72 for (
size_t m = 0;
m < M;
m++)
73 for (
size_t n = 0;
n <
N;
n++)
76 for (
size_t k = 0; k <
K; k++)
77 sum +=
A[
lda * k +
m] *
B[ldb * k +
n];
78 C[
n * ldc +
m] = alpha * sum + (beta == T(0) ? T(0) :
C[
n * ldc +
m] * beta);
81 else if (transa ==
'N' && transb ==
'N')
83 PRAGMA_OFFLOAD(
"omp target teams distribute parallel for collapse(2) is_device_ptr(A, B, C)")
84 for (
size_t n = 0;
n <
N;
n++)
85 for (
size_t m = 0;
m < M;
m++)
88 for (
size_t k = 0; k <
K; k++)
89 sum +=
A[
lda * k +
m] *
B[ldb *
n + k];
90 C[
n * ldc +
m] = alpha * sum + (beta == T(0) ? T(0) :
C[
n * ldc +
m] * beta);
94 throw std::runtime_error(
"Error: trans=='C' not yet implemented for ompBLAS::gemm.");
107 const float*
const A,
109 const float*
const B,
115 return gemm_impl(handle, transa, transb, M,
N,
K, alpha,
A,
lda,
B, ldb, beta,
C, ldc);
126 const double*
const A,
128 const double*
const B,
134 return gemm_impl(handle, transa, transb, M,
N,
K, alpha,
A,
lda,
B, ldb, beta,
C, ldc);
137 #if !defined(OPENMP_NO_COMPLEX) 145 const std::complex<float>& alpha,
146 const std::complex<float>*
const A,
148 const std::complex<float>*
const B,
150 const std::complex<float>& beta,
151 std::complex<float>*
const C,
154 return gemm_impl(handle, transa, transb, M,
N,
K, alpha,
A,
lda,
B, ldb, beta,
C, ldc);
164 const std::complex<double>& alpha,
165 const std::complex<double>*
const A,
167 const std::complex<double>*
const B,
169 const std::complex<double>& beta,
170 std::complex<double>*
const C,
173 return gemm_impl(handle, transa, transb, M,
N,
K, alpha,
A,
lda,
B, ldb, beta,
C, ldc);
185 const T*
const Aarray[],
187 const T*
const Barray[],
192 const int batch_count)
194 if (M == 0 ||
N == 0 ||
K == 0 || batch_count == 0)
197 if (transa ==
'T' && transb ==
'N')
199 PRAGMA_OFFLOAD(
"omp target teams distribute is_device_ptr(Aarray, Barray, Carray)")
200 for (
size_t iw = 0; iw < batch_count; iw++)
205 PRAGMA_OFFLOAD(
"omp parallel for collapse(2)")
206 for (
size_t m = 0;
m < M;
m++)
207 for (
size_t n = 0;
n <
N;
n++)
210 for (
size_t k = 0; k <
K; k++)
211 sum +=
A[
lda *
m + k] *
B[ldb *
n + k];
212 C[
n * ldc +
m] = alpha * sum + (beta == T(0) ? T(0) :
C[
n * ldc +
m] * beta);
216 else if (transa ==
'T' && transb ==
'T')
218 PRAGMA_OFFLOAD(
"omp target teams distribute is_device_ptr(Aarray, Barray, Carray)")
219 for (
size_t iw = 0; iw < batch_count; iw++)
224 PRAGMA_OFFLOAD(
"omp parallel for collapse(2)")
225 for (
size_t m = 0;
m < M;
m++)
226 for (
size_t n = 0;
n <
N;
n++)
229 for (
size_t k = 0; k <
K; k++)
230 sum +=
A[
lda *
m + k] *
B[ldb * k +
n];
231 C[
n * ldc +
m] = alpha * sum + (beta == T(0) ? T(0) :
C[
n * ldc +
m] * beta);
235 else if (transa ==
'N' && transb ==
'T')
237 PRAGMA_OFFLOAD(
"omp target teams distribute is_device_ptr(Aarray, Barray, Carray)")
238 for (
size_t iw = 0; iw < batch_count; iw++)
243 PRAGMA_OFFLOAD(
"omp parallel for collapse(2)")
244 for (
size_t m = 0;
m < M;
m++)
245 for (
size_t n = 0;
n <
N;
n++)
248 for (
size_t k = 0; k <
K; k++)
249 sum +=
A[
lda * k +
m] *
B[ldb * k +
n];
250 C[
n * ldc +
m] = alpha * sum + (beta == T(0) ? T(0) :
C[
n * ldc +
m] * beta);
254 else if (transa ==
'N' && transb ==
'N')
256 PRAGMA_OFFLOAD(
"omp target teams distribute is_device_ptr(Aarray, Barray, Carray)")
257 for (
size_t iw = 0; iw < batch_count; iw++)
262 PRAGMA_OFFLOAD(
"omp parallel for collapse(2)")
263 for (
size_t n = 0;
n <
N;
n++)
264 for (
size_t m = 0;
m < M;
m++)
267 for (
size_t k = 0; k <
K; k++)
268 sum +=
A[
lda * k +
m] *
B[ldb *
n + k];
269 C[
n * ldc +
m] = alpha * sum + (beta == T(0) ? T(0) :
C[
n * ldc +
m] * beta);
274 throw std::runtime_error(
"Error: trans=='C' not yet implemented for ompBLAS::gemm.");
287 const float*
const A[],
289 const float*
const B[],
294 const int batch_count)
296 return gemm_batched_impl(handle, transa, transb, M,
N,
K, alpha,
A,
lda,
B, ldb, beta,
C, ldc, batch_count);
307 const double*
const A[],
309 const double*
const B[],
314 const int batch_count)
316 return gemm_batched_impl(handle, transa, transb, M,
N,
K, alpha,
A,
lda,
B, ldb, beta,
C, ldc, batch_count);
319 #if !defined(OPENMP_NO_COMPLEX) 327 const std::complex<float>& alpha,
328 const std::complex<float>*
const A[],
330 const std::complex<float>*
const B[],
332 const std::complex<float>& beta,
333 std::complex<float>*
const C[],
335 const int batch_count)
337 return gemm_batched_impl(handle, transa, transb, M,
N,
K, alpha,
A,
lda,
B, ldb, beta,
C, ldc, batch_count);
347 const std::complex<double>& alpha,
348 const std::complex<double>*
const A[],
350 const std::complex<double>*
const B[],
352 const std::complex<double>& beta,
353 std::complex<double>*
const C[],
355 const int batch_count)
357 return gemm_batched_impl(handle, transa, transb, M,
N,
K, alpha,
A,
lda,
B, ldb, beta,
C, ldc, batch_count);
375 if (
m == 0 ||
n == 0)
380 if (incx != 1 || incy != 1)
381 throw std::runtime_error(
"incx!=1 or incy!=1 are not implemented in ompBLAS::gemv_impl trans='T'!");
383 PRAGMA_OFFLOAD(
"omp target teams distribute num_teams(n) is_device_ptr(A, x, y)")
384 for (uint32_t i = 0; i <
n; i++)
387 PRAGMA_OFFLOAD(
"omp parallel for simd reduction(+: dot_sum)")
388 for (uint32_t j = 0; j <
m; j++)
389 dot_sum += x[j] *
A[i *
lda + j];
391 y[i] = alpha * dot_sum;
393 y[i] = alpha * dot_sum + beta * y[i];
397 else if (trans ==
'N')
399 if (incx != 1 || incy != 1)
400 throw std::runtime_error(
"incx !=1 or incy != 1 are not implemented in ompBLAS::gemv_impl trans='N'!");
402 PRAGMA_OFFLOAD(
"omp target teams distribute num_teams(m) is_device_ptr(A, x, y)")
403 for (uint32_t i = 0; i <
m; i++)
406 PRAGMA_OFFLOAD(
"omp parallel for simd reduction(+: dot_sum)")
407 for (uint32_t j = 0; j <
n; j++)
408 dot_sum += x[j] *
A[j *
lda + i];
410 y[i] = alpha * dot_sum;
412 y[i] = alpha * dot_sum + beta * y[i];
417 throw std::runtime_error(
"Error: trans=='C' not yet implemented for ompBLAS::gemv_impl.");
426 const float*
const A,
428 const float*
const x,
434 return gemv_impl(handle, trans,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy);
443 const double*
const A,
445 const double*
const x,
451 return gemv_impl(handle, trans,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy);
454 #if !defined(OPENMP_NO_COMPLEX) 460 const std::complex<float> alpha,
461 const std::complex<float>*
const A,
463 const std::complex<float>*
const x,
465 const std::complex<float> beta,
466 std::complex<float>*
const y,
469 return gemv_impl(handle, trans,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy);
477 const std::complex<double> alpha,
478 const std::complex<double>*
const A,
480 const std::complex<double>*
const x,
482 const std::complex<double> beta,
483 std::complex<double>*
const y,
486 return gemv_impl(handle, trans,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy);
504 const int batch_count)
506 if (
m == 0 ||
n == 0 || batch_count == 0)
512 throw std::runtime_error(
"incx!=1 are not implemented in ompBLAS::gemv_batched_impl trans='T'!");
514 PRAGMA_OFFLOAD(
"omp target teams distribute collapse(2) num_teams(batch_count * n) \ 515 is_device_ptr(A, x, y, alpha, beta)")
516 for (uint32_t ib = 0; ib < batch_count; ib++)
517 for (uint32_t i = 0; i <
n; i++)
520 PRAGMA_OFFLOAD(
"omp parallel for simd reduction(+: dot_sum)")
521 for (uint32_t j = 0; j <
m; j++)
522 dot_sum += x[ib][j] *
A[ib][i *
lda + j];
523 if (beta[ib] == T(0))
524 y[ib][i * incy] = alpha[ib] * dot_sum;
526 y[ib][i * incy] = alpha[ib] * dot_sum + beta[ib] * y[ib][i * incy];
530 else if (trans ==
'N')
533 throw std::runtime_error(
"incx!=1 are not implemented in ompBLAS::gemv_batched_impl trans='N'!");
535 PRAGMA_OFFLOAD(
"omp target teams distribute collapse(2) num_teams(batch_count * n) \ 536 is_device_ptr(A, x, y, alpha, beta)")
537 for (uint32_t ib = 0; ib < batch_count; ib++)
538 for (uint32_t i = 0; i <
m; i++)
541 PRAGMA_OFFLOAD(
"omp parallel for simd reduction(+: dot_sum)")
542 for (uint32_t j = 0; j <
n; j++)
543 dot_sum += x[ib][j] *
A[ib][j *
lda + i];
544 if (beta[ib] == T(0))
545 y[ib][i * incy] = alpha[ib] * dot_sum;
547 y[ib][i * incy] = alpha[ib] * dot_sum + beta[ib] * y[ib][i * incy];
552 throw std::runtime_error(
"Error: trans=='C' not yet implemented for ompBLAS::gemv_impl.");
561 const float*
const A[],
563 const float*
const x[],
568 const int batch_count)
570 return gemv_batched_impl(handle, trans,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy, batch_count);
579 const double*
const A[],
581 const double*
const x[],
586 const int batch_count)
588 return gemv_batched_impl(handle, trans,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy, batch_count);
591 #if !defined(OPENMP_NO_COMPLEX) 597 const std::complex<float>* alpha,
598 const std::complex<float>*
const A[],
600 const std::complex<float>*
const x[],
602 const std::complex<float>* beta,
603 std::complex<float>*
const y[],
605 const int batch_count)
607 return gemv_batched_impl(handle, trans,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy, batch_count);
615 const std::complex<double>* alpha,
616 const std::complex<double>*
const A[],
618 const std::complex<double>*
const x[],
620 const std::complex<double>* beta,
621 std::complex<double>*
const y[],
623 const int batch_count)
625 return gemv_batched_impl(handle, trans,
m,
n, alpha,
A,
lda, x, incx, beta, y, incy, batch_count);
642 if (
m == 0 ||
n == 0)
645 if (incx != 1 || incy != 1)
646 throw std::runtime_error(
"incx !=1 or incy != 1 are not implemented in ompBLAS::ger_impl!");
649 PRAGMA_OFFLOAD(
"omp target teams distribute parallel for collapse(2) is_device_ptr(A, x, y)")
650 for (uint32_t i = 0; i <
n; i++)
651 for (uint32_t j = 0; j <
m; j++)
652 A[i *
lda + j] += alpha * x[j] * y[i];
661 const float*
const x,
663 const float*
const y,
676 const double*
const x,
678 const double*
const y,
686 #if !defined(OPENMP_NO_COMPLEX) 691 const std::complex<float> alpha,
692 const std::complex<float>*
const x,
694 const std::complex<float>*
const y,
696 std::complex<float>*
const A,
706 const std::complex<double> alpha,
707 const std::complex<double>*
const x,
709 const std::complex<double>*
const y,
711 std::complex<double>*
const A,
730 const int batch_count)
732 if (
m == 0 ||
n == 0 || batch_count == 0)
737 throw std::runtime_error(
"incx!=1 are not implemented in ompBLAS::ger_batched_impl!");
739 PRAGMA_OFFLOAD(
"omp target teams distribute parallel for collapse(3) is_device_ptr(A, x, y, alpha)")
740 for (uint32_t ib = 0; ib < batch_count; ib++)
741 for (uint32_t i = 0; i <
n; i++)
742 for (uint32_t j = 0; j <
m; j++)
743 A[ib][i *
lda + j] += alpha[ib] * x[ib][j] * y[ib][i * incy];
752 const float*
const x[],
754 const float*
const y[],
758 const int batch_count)
760 return ger_batched_impl(handle,
m,
n, alpha, x, incx, y, incy,
A,
lda, batch_count);
768 const double*
const x[],
770 const double*
const y[],
774 const int batch_count)
776 return ger_batched_impl(handle,
m,
n, alpha, x, incx, y, incy,
A,
lda, batch_count);
779 #if !defined(OPENMP_NO_COMPLEX) 784 const std::complex<float>* alpha,
785 const std::complex<float>*
const x[],
787 const std::complex<float>*
const y[],
789 std::complex<float>*
const A[],
791 const int batch_count)
793 return ger_batched_impl(handle,
m,
n, alpha, x, incx, y, incy,
A,
lda, batch_count);
800 const std::complex<double>* alpha,
801 const std::complex<double>*
const x[],
803 const std::complex<double>*
const y[],
805 std::complex<double>*
const A[],
807 const int batch_count)
809 return ger_batched_impl(handle,
m,
n, alpha, x, incx, y, incy,
A,
lda, batch_count);
821 const int batch_count)
823 if (
n == 0 || batch_count == 0)
826 PRAGMA_OFFLOAD(
"omp target teams distribute parallel for collapse(2) is_device_ptr(x, y)")
827 for (uint32_t ib = 0; ib < batch_count; ib++)
828 for (uint32_t i = 0; i <
n; i++)
829 y[ib][i * incy] = x[ib][i * incx];
836 const float*
const x[],
840 const int batch_count)
848 const double*
const x[],
852 const int batch_count)
857 #if !defined(OPENMP_NO_COMPLEX) 861 const std::complex<float>*
const x[],
863 std::complex<float>*
const y[],
865 const int batch_count)
873 const std::complex<double>*
const x[],
875 std::complex<double>*
const y[],
877 const int batch_count)
892 const int batch_count)
894 if (
n == 0 || batch_count == 0)
897 PRAGMA_OFFLOAD(
"omp target teams distribute parallel for collapse(2) is_device_ptr(x, y)")
898 for (uint32_t ib = 0; ib < batch_count; ib++)
899 for (uint32_t i = 0; i <
n; i++)
900 y[ib][y_offset + i * incy] = x[ib][x_offset + i * incx];
907 const float*
const x[],
913 const int batch_count)
921 const double*
const x[],
927 const int batch_count)
932 #if !defined(OPENMP_NO_COMPLEX) 936 const std::complex<float>*
const x[],
939 std::complex<float>*
const y[],
942 const int batch_count)
950 const std::complex<double>*
const x[],
953 std::complex<double>*
const y[],
956 const int batch_count)
972 PRAGMA_OFFLOAD(
"omp target teams distribute parallel for is_device_ptr(x, y)")
973 for (
size_t i = 0; i <
n; i++)
974 y[i * incy] = x[i * incx];
981 const float*
const x,
986 return copy_impl(handle,
n, x, incx, y, incy);
992 const double*
const x,
997 return copy_impl(handle,
n, x, incx, y, incy);
1003 const std::complex<float>*
const x,
1005 std::complex<float>*
const y,
1008 return copy_impl(handle,
n, x, incx, y, incy);
1014 const std::complex<double>*
const x,
1016 std::complex<double>*
const y,
1019 return copy_impl(handle,
n, x, incx, y, incy);
ompBLAS_status ger_batched< float >(ompBLAS_handle &handle, const int m, const int n, const float *alpha, const float *const x[], const int incx, const float *const y[], const int incy, float *const A[], const int lda, const int batch_count)
ompBLAS_status copy_batched_offset< float >(ompBLAS_handle &handle, const int n, const float *const x[], const int x_offset, const int incx, float *const y[], const int y_offset, const int incy, const int batch_count)
ompBLAS_status copy_batched< double >(ompBLAS_handle &handle, const int n, const double *const x[], const int incx, double *const y[], const int incy, const int batch_count)
ompBLAS_status gemm_batched_impl(ompBLAS_handle &handle, const char transa, const char transb, const int M, const int N, const int K, const T alpha, const T *const Aarray[], const int lda, const T *const Barray[], const int ldb, const T beta, T *const Carray[], const int ldc, const int batch_count)
helper functions for EinsplineSetBuilder
ompBLAS_status ger_impl(ompBLAS_handle &handle, const int m, const int n, const T alpha, const T *const x, const int incx, const T *const y, const int incy, T *const A, const int lda)
ompBLAS_status gemm_impl(ompBLAS_handle &handle, const char transa, const char transb, const int M, const int N, const int K, const T &alpha, const T *const A, const int lda, const T *const B, const int ldb, const T &beta, T *const C, const int ldc)
ompBLAS_status copy_batched_impl(ompBLAS_handle &handle, const int n, const T *const x[], const int incx, T *const y[], const int incy, const int batch_count)
ompBLAS_status copy< double >(ompBLAS_handle &handle, const int n, const double *const x, const int incx, double *const y, const int incy)
ompBLAS_status ger_batched_impl(ompBLAS_handle &handle, const int m, const int n, const T *alpha, const T *const x[], const int incx, const T *const y[], const int incy, T *const A[], const int lda, const int batch_count)
ompBLAS_status copy_batched_offset_impl(ompBLAS_handle &handle, const int n, const T *const x[], const int x_offset, const int incx, T *const y[], const int y_offset, const int incy, const int batch_count)
ompBLAS_status gemv_batched< float >(ompBLAS_handle &handle, const char trans, const int m, const int n, const float *alpha, const float *const A[], const int lda, const float *const x[], const int incx, const float *beta, float *const y[], const int incy, const int batch_count)
ompBLAS_status gemm< float >(ompBLAS_handle &handle, const char transa, const char transb, const int M, const int N, const int K, const float &alpha, const float *const A, const int lda, const float *const B, const int ldb, const float &beta, float *const C, const int ldc)
ompBLAS_status gemv< double >(ompBLAS_handle &handle, const char trans, const int m, const int n, const double alpha, const double *const A, const int lda, const double *const x, const int incx, const double beta, double *const y, const int incy)
ompBLAS_status copy< float >(ompBLAS_handle &handle, const int n, const float *const x, const int incx, float *const y, const int incy)
ompBLAS_status gemv_impl(ompBLAS_handle &handle, const char trans, const int m, const int n, const T alpha, const T *const A, const int lda, const T *const x, const int incx, const T beta, T *const y, const int incy)
ompBLAS_status copy_impl(ompBLAS_handle &handle, const int n, const T *const x, const int incx, T *const y, const int incy)
ompBLAS_status gemm_batched< double >(ompBLAS_handle &handle, const char transa, const char transb, const int M, const int N, const int K, const double &alpha, const double *const A[], const int lda, const double *const B[], const int ldb, const double &beta, double *const C[], const int ldc, const int batch_count)
ompBLAS_status gemv_batched_impl(ompBLAS_handle &handle, const char trans, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const int batch_count)
ompBLAS_status ger< float >(ompBLAS_handle &handle, const int m, const int n, const float alpha, const float *const x, const int incx, const float *const y, const int incy, float *const A, const int lda)
ompBLAS_status ger_batched< double >(ompBLAS_handle &handle, const int m, const int n, const double *alpha, const double *const x[], const int incx, const double *const y[], const int incy, double *const A[], const int lda, const int batch_count)
ompBLAS_status gemm_batched< float >(ompBLAS_handle &handle, const char transa, const char transb, const int M, const int N, const int K, const float &alpha, const float *const A[], const int lda, const float *const B[], const int ldb, const float &beta, float *const C[], const int ldc, const int batch_count)
ompBLAS_status ger< double >(ompBLAS_handle &handle, const int m, const int n, const double alpha, const double *const x, const int incx, const double *const y, const int incy, double *const A, const int lda)
ompBLAS_status gemm< double >(ompBLAS_handle &handle, const char transa, const char transb, const int M, const int N, const int K, const double &alpha, const double *const A, const int lda, const double *const B, const int ldb, const double &beta, double *const C, const int ldc)
ompBLAS_status copy_batched< float >(ompBLAS_handle &handle, const int n, const float *const x[], const int incx, float *const y[], const int incy, const int batch_count)
double B(double x, int k, int i, const std::vector< double > &t)
ompBLAS_status copy_batched_offset< double >(ompBLAS_handle &handle, const int n, const double *const x[], const int x_offset, const int incx, double *const y[], const int y_offset, const int incy, const int batch_count)
ompBLAS_status gemv< float >(ompBLAS_handle &handle, const char trans, const int m, const int n, const float alpha, const float *const A, const int lda, const float *const x, const int incx, const float beta, float *const y, const int incy)
ompBLAS_status gemv_batched< double >(ompBLAS_handle &handle, const char trans, const int m, const int n, const double *alpha, const double *const A[], const int lda, const double *const x[], const int incx, const double *beta, double *const y[], const int incy, const int batch_count)