22 template<PlatformKind PL,
typename T>
23 void test_one_gemm(
const int M,
const int N,
const int K,
const char transa,
const char transb)
25 const int a0 = transa ==
'T' ? M :
K;
26 const int a1 = transa ==
'T' ?
K : M;
28 const int b0 = transb ==
'T' ?
K :
N;
29 const int b1 = transb ==
'T' ?
N :
K;
40 for (
int j = 0; j < a0; j++)
41 for (
int i = 0; i < a1; i++)
42 A[j][i] = i * 3 + j * 4;
44 for (
int j = 0; j < b0; j++)
45 for (
int i = 0; i < b1; i++)
46 B[j][i] = i * 4 + j * 5;
67 compute::BLAS::gemm(h_blas, transa, transb, M,
N,
K, alpha,
A.device_data(), a1,
B.device_data(), b1, beta,
73 BLAS::gemm(transa, transb, M,
N,
K, alpha,
A.data(), a1,
B.data(), b1, beta, D.data(), M);
75 for (
int j = 0; j <
N; j++)
76 for (
int i = 0; i < M; i++)
88 for (
int j = 0; j < a0; j++)
89 for (
int i = 0; i < a1; i++)
90 A2[j][i] = j * 3 + i * 4;
92 for (
int j = 0; j < b0; j++)
93 for (
int i = 0; i < b1; i++)
94 B2[j][i] = j * 4 + i * 5;
102 Aarr[0] = A2.device_data();
103 Aarr[1] =
A.device_data();
105 Barr[1] =
B.device_data();
107 Carr[0] =
C.device_data();
114 compute::BLAS::gemm_batched(h_blas, transa, transb, M,
N,
K, alpha, Aarr.device_data(), a1, Barr.
device_data(), b1,
121 BLAS::gemm(transa, transb, M,
N,
K, alpha, A2.data(), a1, B2.data(), b1, beta, D2.data(), M);
123 for (
int j = 0; j <
N; j++)
124 for (
int i = 0; i < M; i++)
130 for (
int j = 0; j <
N; j++)
131 for (
int i = 0; i < M; i++)
138 template<PlatformKind PL>
146 std::cout <<
"Testing NN gemm" << std::endl;
147 test_one_gemm<PL, float>(M,
N,
K,
'N',
'N');
148 test_one_gemm<PL, double>(M,
N,
K,
'N',
'N');
149 #if defined(QMC_COMPLEX) 150 test_one_gemm<PL, std::complex<float>>(
N, M,
K,
'N',
'N');
151 test_one_gemm<PL, std::complex<double>>(
N, M,
K,
'N',
'N');
153 std::cout <<
"Testing NT gemm" << std::endl;
154 test_one_gemm<PL, float>(M,
N,
K,
'N',
'T');
155 test_one_gemm<PL, double>(M,
N,
K,
'N',
'T');
156 #if defined(QMC_COMPLEX) 157 test_one_gemm<PL, std::complex<float>>(
N, M,
K,
'N',
'T');
158 test_one_gemm<PL, std::complex<double>>(
N, M,
K,
'N',
'T');
160 std::cout <<
"Testing TN gemm" << std::endl;
161 test_one_gemm<PL, float>(M,
N,
K,
'T',
'N');
162 test_one_gemm<PL, double>(M,
N,
K,
'T',
'N');
163 #if defined(QMC_COMPLEX) 164 test_one_gemm<PL, std::complex<float>>(
N, M,
K,
'T',
'N');
165 test_one_gemm<PL, std::complex<double>>(
N, M,
K,
'T',
'N');
167 std::cout <<
"Testing TT gemm" << std::endl;
168 test_one_gemm<PL, float>(M,
N,
K,
'T',
'T');
169 test_one_gemm<PL, double>(M,
N,
K,
'T',
'T');
170 #if defined(QMC_COMPLEX) 171 test_one_gemm<PL, std::complex<float>>(
N, M,
K,
'T',
'T');
172 test_one_gemm<PL, std::complex<double>>(
N, M,
K,
'T',
'T');
176 template<PlatformKind PL,
typename T>
179 const int M = trans ==
'T' ? M_b : N_b;
180 const int N = trans ==
'T' ? N_b : M_b;
191 for (
int i = 0; i <
N; i++)
194 for (
int j = 0; j < M_b; j++)
195 for (
int i = 0; i < N_b; i++)
199 for (
int i = 0; i < M; i++)
214 compute::BLAS::gemv(h_blas, trans, N_b, M_b, alpha,
B.device_data(), N_b,
A.device_data(), 1, beta,
C.device_data(),
225 for (
int index = 0; index < M; index++)
226 CHECK(
C[index] == D[index] * alpha);
234 for (
int i = 0; i <
N; i++)
237 for (
int j = 0; j < M_b; j++)
238 for (
int i = 0; i < N_b; i++)
239 B2[j][i] = i * 2 + j;
242 for (
int i = 0; i < M; i++)
243 C2[i] = D2[i] = T(0);
253 Aarr[0] = A2.device_data();
254 Aarr[1] =
A.device_data();
256 Barr[1] =
B.device_data();
258 Carr[0] = C2.device_data();
259 Carr[1] =
C.device_data();
269 alpha_arr.updateTo();
287 for (
int index = 0; index < M; index++)
289 CHECK(
C[index] == D[index]);
290 CHECK(C2[index] == D2[index] * alpha_arr[0]);
294 template<PlatformKind PL>
300 std::cout <<
"Testing NOTRANS gemv" << std::endl;
301 test_one_gemv<PL, float>(M,
N,
'N');
302 test_one_gemv<PL, double>(M,
N,
'N');
303 #if defined(QMC_COMPLEX) 304 test_one_gemv<PL, std::complex<float>>(
N, M,
'N');
305 test_one_gemv<PL, std::complex<double>>(
N, M,
'N');
307 std::cout <<
"Testing TRANS gemv" << std::endl;
308 test_one_gemv<PL, float>(M,
N,
'T');
309 test_one_gemv<PL, double>(M,
N,
'T');
310 #if defined(QMC_COMPLEX) 311 test_one_gemv<PL, std::complex<float>>(
N, M,
'T');
312 test_one_gemv<PL, std::complex<double>>(
N, M,
'T');
316 template<PlatformKind PL,
typename T>
328 for (
int i = 0; i < M; i++)
330 for (
int i = 0; i <
N; i++)
333 for (
int j = 0; j < M; j++)
334 for (
int i = 0; i <
N; i++)
336 Ah[j][i] = i + j * 2;
337 Ad[j][i] = i + j * 2;
349 compute::BLAS::ger(h_blas, M,
N, alpha, x.device_data(), 1, y.device_data(), 1, Ad.device_data(), M);
355 for (
int j = 0; j < M; j++)
356 for (
int i = 0; i <
N; i++)
357 CHECK(Ah[j][i] == Ad[j][i]);
365 for (
int i = 0; i < M; i++)
367 for (
int i = 0; i <
N; i++)
370 for (
int j = 0; j < M; j++)
371 for (
int i = 0; i <
N; i++)
373 Ah2[j][i] = j + i * 2;
374 Ad2[j][i] = j + i * 2;
387 Xarr[0] = x2.device_data();
388 Xarr[1] = x.device_data();
389 Yarr[0] = y2.device_data();
390 Yarr[1] = y.device_data();
410 for (
int j = 0; j < M; j++)
411 for (
int i = 0; i <
N; i++)
413 CHECK(Ah[j][i] == Ad[j][i]);
414 CHECK(Ah2[j][i] == Ad2[j][i]);
418 template<PlatformKind PL>
425 std::cout <<
"Testing ger_batched" << std::endl;
426 test_one_ger<PL, float>(M,
N);
427 test_one_ger<PL, double>(M,
N);
428 #if defined(QMC_COMPLEX) 429 test_one_ger<PL, std::complex<float>>(
N, M);
430 test_one_ger<PL, std::complex<double>>(
N, M);
438 #if defined(ENABLE_CUDA) 439 std::cout <<
"Testing gemm<PlatformKind::CUDA>" << std::endl;
440 test_gemm_cases<PlatformKind::CUDA>();
442 #if defined(ENABLE_SYCL) 443 std::cout <<
"Testing gemm<PlatformKind::SYCL>" << std::endl;
444 test_gemm_cases<PlatformKind::SYCL>();
446 #if defined(ENABLE_OFFLOAD) 447 std::cout <<
"Testing gemm<PlatformKind::OMPTARGET>" << std::endl;
448 test_gemm_cases<PlatformKind::OMPTARGET>();
454 #if defined(ENABLE_CUDA) 455 std::cout <<
"Testing gemm<PlatformKind::CUDA>" << std::endl;
456 test_gemv_cases<PlatformKind::CUDA>();
458 #if defined(ENABLE_SYCL) 459 std::cout <<
"Testing gemm<PlatformKind::SYCL>" << std::endl;
460 test_gemv_cases<PlatformKind::SYCL>();
462 #if defined(ENABLE_OFFLOAD) 463 std::cout <<
"Testing gemm<PlatformKind::OMPTARGET>" << std::endl;
464 test_gemv_cases<PlatformKind::OMPTARGET>();
470 #if defined(ENABLE_CUDA) 471 std::cout <<
"Testing ger<PlatformKind::CUDA>" << std::endl;
472 test_ger_cases<PlatformKind::CUDA>();
474 #if defined(ENABLE_SYCL) 475 std::cout <<
"Testing ger<PlatformKind::SYCL>" << std::endl;
476 test_ger_cases<PlatformKind::SYCL>();
478 #if defined(ENABLE_OFFLOAD) 479 std::cout <<
"Testing ger<PlatformKind::OMPTARGET>" << std::endl;
480 test_ger_cases<PlatformKind::OMPTARGET>();
void gemm(BLASHandle< PlatformKind::CUDA > &handle, const char transa, const char transb, int m, int n, int k, const float &alpha, const float *A, int lda, const float *B, int ldb, const float &beta, float *C, int ldc)
static void gemv_trans(int n, int m, const double *restrict amat, const double *restrict x, double *restrict y)
helper functions for EinsplineSetBuilder
TEST_CASE("complex_helper", "[type_traits]")
pointer device_data()
Return the device_ptr matching X if this is a vector attached or owning dual space memory...
static void gemv(int n, int m, const double *restrict amat, const double *restrict x, double *restrict y)
float imag(const float &c)
imaginary part of a scalar. Cannot be replaced by std::imag due to AFQMC specific needs...
void ger(BLASHandle< PlatformKind::CUDA > &handle, const int m, const int n, const float &alpha, const float *const x, const int incx, const float *const y, const int incy, float *const A, const int lda)
void gemv_batched(BLASHandle< PlatformKind::CUDA > &handle, const char trans, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const int batch_count)
These allocators are to make code that should be generic with respect to accelerator code flavor ...
static void ger(int m, int n, double alpha, const double *x, int incx, const double *y, int incy, double *a, int lda)
void gemm_batched(BLASHandle< PlatformKind::CUDA > &handle, const char transa, const char transb, int m, int n, int k, const float &alpha, const float *const A[], int lda, const float *const B[], int ldb, const float &beta, float *const C[], int ldc, int batchCount)
void test_one_gemv(const int M_b, const int N_b, const char trans)
void test_one_gemm(const int M, const int N, const int K, const char transa, const char transb)
Declaration of Vector<T,Alloc>. Manage memory through Alloc directly and allow referencing an existing ...
CHECK(log_values[0]==ComplexApprox(std::complex< double >{ 5.603777579195571, -6.1586603331188225 }))
void ger_batched(BLASHandle< PlatformKind::CUDA > &handle, const int m, const int n, const T *alpha, const T *const x[], const int incx, const T *const y[], const int incy, T *const A[], const int lda, const int batch_count)
static void gemm(char Atrans, char Btrans, int M, int N, int K, double alpha, const double *A, int lda, const double *restrict B, int ldb, double beta, double *restrict C, int ldc)
void updateTo(size_type size=0, std::ptrdiff_t offset=0)
double B(double x, int k, int i, const std::vector< double > &t)
void gemv(BLASHandle< PlatformKind::CUDA > &handle, const char trans, const int m, const int n, const float &alpha, const float *const A, const int lda, const float *const x, const int incx, const float &beta, float *const y, const int incy)
void test_one_ger(const int M, const int N)