26 void test_gemm(
const int M,
const int N,
const int K,
const char transa,
const char transb)
28 const int a0 = transa ==
'T' ? M :
K;
29 const int a1 = transa ==
'T' ?
K : M;
31 const int b0 = transb ==
'T' ?
K :
N;
32 const int b1 = transb ==
'T' ?
N :
K;
45 for (
int j = 0; j < a0; j++)
46 for (
int i = 0; i < a1; i++)
47 A[j][i] = i * 3 + j * 4;
49 for (
int j = 0; j < b0; j++)
50 for (
int i = 0; i < b1; i++)
51 B[j][i] = i * 4 + j * 5;
72 ompBLAS::gemm(handle, transa, transb, M,
N,
K, alpha_half,
A.device_data(), a1,
B.device_data(), b1, beta,
75 ompBLAS::gemm(handle, transa, transb, M,
N,
K, alpha_half,
A.device_data(), a1,
B.device_data(), b1, beta1,
79 BLAS::gemm(transa, transb, M,
N,
K, alpha,
A.data(), a1,
B.data(), b1, beta, D.data(), M);
81 for (
int j = 0; j <
N; j++)
82 for (
int i = 0; i < M; i++)
94 for (
int j = 0; j < a0; j++)
95 for (
int i = 0; i < a1; i++)
96 A2[j][i] = j * 3 + i * 4;
98 for (
int j = 0; j < b0; j++)
99 for (
int i = 0; i < b1; i++)
100 B2[j][i] = j * 4 + i * 5;
108 Aarr[0] = A2.device_data();
109 Aarr[1] =
A.device_data();
111 Barr[1] =
B.device_data();
113 Carr[0] =
C.device_data();
121 ompBLAS::gemm_batched(handle, transa, transb, M,
N,
K, alpha_half, Aarr.device_data(), a1, Barr.
device_data(), b1,
124 ompBLAS::gemm_batched(handle, transa, transb, M,
N,
K, alpha_half, Aarr.device_data(), a1, Barr.
device_data(), b1,
129 BLAS::gemm(transa, transb, M,
N,
K, alpha, A2.data(), a1, B2.data(), b1, beta, D2.data(), M);
131 for (
int j = 0; j <
N; j++)
132 for (
int i = 0; i < M; i++)
138 for (
int j = 0; j <
N; j++)
139 for (
int i = 0; i < M; i++)
153 std::cout <<
"Testing NN gemm" << std::endl;
154 test_gemm<float>(M,
N,
K,
'N',
'N');
155 test_gemm<double>(M,
N,
K,
'N',
'N');
156 #if defined(QMC_COMPLEX) 157 test_gemm<std::complex<float>>(
N, M,
K,
'N',
'N');
158 test_gemm<std::complex<double>>(
N, M,
K,
'N',
'N');
160 std::cout <<
"Testing NT gemm" << std::endl;
161 test_gemm<float>(M,
N,
K,
'N',
'T');
162 test_gemm<double>(M,
N,
K,
'N',
'T');
163 #if defined(QMC_COMPLEX) 164 test_gemm<std::complex<float>>(
N, M,
K,
'N',
'T');
165 test_gemm<std::complex<double>>(
N, M,
K,
'N',
'T');
167 std::cout <<
"Testing TN gemm" << std::endl;
168 test_gemm<float>(M,
N,
K,
'T',
'N');
169 test_gemm<double>(M,
N,
K,
'T',
'N');
170 #if defined(QMC_COMPLEX) 171 test_gemm<std::complex<float>>(
N, M,
K,
'T',
'N');
172 test_gemm<std::complex<double>>(
N, M,
K,
'T',
'N');
174 std::cout <<
"Testing TT gemm" << std::endl;
175 test_gemm<float>(M,
N,
K,
'T',
'T');
176 test_gemm<double>(M,
N,
K,
'T',
'T');
177 #if defined(QMC_COMPLEX) 178 test_gemm<std::complex<float>>(
N, M,
K,
'T',
'T');
179 test_gemm<std::complex<double>>(
N, M,
K,
'T',
'T');
184 void test_gemv(
const int M_b,
const int N_b,
const char trans)
186 const int M = trans ==
'T' ? M_b : N_b;
187 const int N = trans ==
'T' ? N_b : M_b;
200 for (
int i = 0; i <
N; i++)
203 for (
int j = 0; j < M_b; j++)
204 for (
int i = 0; i < N_b; i++)
208 for (
int i = 0; i < M; i++)
221 ompBLAS::gemv(handle, trans, N_b, M_b, alpha,
B.device_data(), N_b,
A.device_data(), 1, beta,
C.device_data(), 1);
229 for (
int index = 0; index < M; index++)
230 CHECK(
C[index] == D[index]);
236 const int M = trans ==
'T' ? M_b : N_b;
237 const int N = trans ==
'T' ? N_b : M_b;
245 std::vector<vec_t> As;
249 std::vector<mat_t> Bs;
253 std::vector<vec_t> Cs;
257 std::vector<vec_t> Ds;
261 Aptrs.
resize(batch_count);
262 Bptrs.
resize(batch_count);
263 Cptrs.
resize(batch_count);
264 Dptrs.
resize(batch_count);
267 As.resize(batch_count);
268 Bs.resize(batch_count);
269 Cs.resize(batch_count);
270 Ds.resize(batch_count);
273 for (
int batch = 0; batch < batch_count; batch++)
280 Bs[batch].resize(M_b, N_b);
287 Dptrs[batch] = Ds[batch].
data();
289 for (
int i = 0; i <
N; i++)
292 for (
int j = 0; j < M_b; j++)
293 for (
int i = 0; i < N_b; i++)
294 Bs[batch][j][i] = i + j * 2;
296 for (
int i = 0; i < M; i++)
297 Cs[batch][i] = Ds[batch][i] = T(0);
299 As[batch].updateTo();
300 Bs[batch].updateTo();
312 for (
int batch = 0; batch < batch_count; batch++)
314 alpha[batch] = T(0.5);
330 for (
int batch = 0; batch < batch_count; batch++)
332 Cs[batch].updateFrom();
334 BLAS::gemv_trans(M_b, N_b, Bs[batch].data(), As[batch].data(), Ds[batch].data());
336 BLAS::gemv(M_b, N_b, Bs[batch].data(), As[batch].data(), Ds[batch].data());
339 for (
int index = 0; index < M; index++)
340 CHECK(Cs[batch][index] == Ds[batch][index]);
348 const int batch_count = 23;
351 std::cout <<
"Testing TRANS gemv" << std::endl;
352 test_gemv<float>(M,
N,
'T');
353 test_gemv<double>(M,
N,
'T');
354 #if defined(QMC_COMPLEX) 355 test_gemv<std::complex<float>>(
N, M,
'T');
356 test_gemv<std::complex<double>>(
N, M,
'T');
359 std::cout <<
"Testing TRANS gemv_batched" << std::endl;
360 test_gemv_batched<float>(M,
N,
'T', batch_count);
361 test_gemv_batched<double>(M,
N,
'T', batch_count);
362 #if defined(QMC_COMPLEX) 363 test_gemv_batched<std::complex<float>>(
N, M,
'T', batch_count);
364 test_gemv_batched<std::complex<double>>(
N, M,
'T', batch_count);
372 const int batch_count = 23;
375 std::cout <<
"Testing NOTRANS gemv" << std::endl;
376 test_gemv<float>(M,
N,
'N');
377 test_gemv<double>(M,
N,
'N');
378 #if defined(QMC_COMPLEX) 379 test_gemv<std::complex<float>>(
N, M,
'N');
380 test_gemv<std::complex<double>>(
N, M,
'N');
383 std::cout <<
"Testing NOTRANS gemv_batched" << std::endl;
384 test_gemv_batched<float>(M,
N,
'N', batch_count);
385 test_gemv_batched<double>(M,
N,
'N', batch_count);
386 #if defined(QMC_COMPLEX) 387 test_gemv_batched<std::complex<float>>(
N, M,
'N', batch_count);
388 test_gemv_batched<std::complex<double>>(
N, M,
'N', batch_count);
406 for (
int i = 0; i < M; i++)
408 for (
int i = 0; i <
N; i++)
411 for (
int j = 0; j < M; j++)
412 for (
int i = 0; i <
N; i++)
414 Ah[j][i] = i + j * 2;
415 Ad[j][i] = i + j * 2;
425 ompBLAS::ger(handle, M,
N, alpha, x.device_data(), 1, y.device_data(), 1, Ad.device_data(), M);
430 for (
int j = 0; j < M; j++)
431 for (
int i = 0; i <
N; i++)
432 CHECK(Ah[j][i] == Ad[j][i]);
444 std::vector<vec_t> Xs;
446 std::vector<vec_t> Ys;
450 std::vector<mat_t> Ahs;
452 std::vector<mat_t> Ads;
456 Xptrs.
resize(batch_count);
457 Yptrs.
resize(batch_count);
458 Ahptrs.
resize(batch_count);
459 Adptrs.
resize(batch_count);
462 Xs.resize(batch_count);
463 Ys.resize(batch_count);
464 Ahs.resize(batch_count);
465 Ads.resize(batch_count);
468 for (
int batch = 0; batch < batch_count; batch++)
478 Ads[batch].resize(M,
N);
481 Ahs[batch].resize(M,
N);
482 Ahptrs[batch] = Ahs[batch].
data();
485 for (
int i = 0; i < M; i++)
487 for (
int i = 0; i <
N; i++)
488 Ys[batch][i] =
N - i;
490 for (
int j = 0; j < M; j++)
491 for (
int i = 0; i <
N; i++)
493 Ads[batch][j][i] = i + j * 2;
494 Ahs[batch][j][i] = i + j * 2;
497 Xs[batch].updateTo();
498 Ys[batch].updateTo();
499 Ads[batch].updateTo();
509 for (
int batch = 0; batch < batch_count; batch++)
519 for (
int batch = 0; batch < batch_count; batch++)
521 Ads[batch].updateFrom();
522 BLAS::ger(M,
N, alpha[batch], Xs[batch].data(), 1, Ys[batch].data(), 1, Ahs[batch].data(), M);
525 for (
int j = 0; j < M; j++)
526 for (
int i = 0; i <
N; i++)
527 CHECK(Ads[batch][j][i] == Ahs[batch][j][i]);
535 const int batch_count = 23;
538 std::cout <<
"Testing ger" << std::endl;
539 test_ger<float>(M,
N);
540 test_ger<double>(M,
N);
541 #if defined(QMC_COMPLEX) 542 test_ger<std::complex<float>>(
N, M);
543 test_ger<std::complex<double>>(
N, M);
546 std::cout <<
"Testing ger_batched" << std::endl;
547 test_ger_batched<float>(M,
N, batch_count);
548 test_ger_batched<double>(M,
N, batch_count);
549 #if defined(QMC_COMPLEX) 550 test_ger_batched<std::complex<float>>(
N, M, batch_count);
551 test_ger_batched<std::complex<double>>(
N, M, batch_count);
void resize(size_type n, Type_t val=Type_t())
Resize the container.
void test_ger_batched(const int M, const int N, const int batch_count)
ompBLAS_status gemm(ompBLAS_handle &handle, const char transa, const char transb, const int M, const int N, const int K, const T &alpha, const T *const A, const int lda, const T *const B, const int ldb, const T &beta, T *const C, const int ldc)
static void gemv_trans(int n, int m, const double *restrict amat, const double *restrict x, double *restrict y)
ompBLAS_status gemv(ompBLAS_handle &handle, const char trans, const int m, const int n, const T alpha, const T *const A, const int lda, const T *const x, const int incx, const T beta, T *const y, const int incy)
helper functions for EinsplineSetBuilder
TEST_CASE("complex_helper", "[type_traits]")
pointer device_data()
Return the device_ptr matching X if this is a vector attached or owning dual space memory...
void test_ger(const int M, const int N)
static void gemv(int n, int m, const double *restrict amat, const double *restrict x, double *restrict y)
float imag(const float &c)
imaginary part of a scalar. Cannot be replaced by std::imag due to AFQMC specific needs...
ompBLAS_status gemv_batched(ompBLAS_handle &handle, const char trans, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const int batch_count)
static void ger(int m, int n, double alpha, const double *x, int incx, const double *y, int incy, double *a, int lda)
ompBLAS_status gemm_batched(ompBLAS_handle &handle, const char transa, const char transb, const int M, const int N, const int K, const T &alpha, const T *const A[], const int lda, const T *const B[], const int ldb, const T &beta, T *const C[], const int ldc, const int batch_count)
ompBLAS_status ger_batched(ompBLAS_handle &handle, const int m, const int n, const T *alpha, const T *const x[], const int incx, const T *const y[], const int incy, T *const A[], const int lda, const int batch_count)
Declaration of Vector<T,Alloc> Manage memory through Alloc directly and allow referencing an existing ...
CHECK(log_values[0]==ComplexApprox(std::complex< double >{ 5.603777579195571, -6.1586603331188225 }))
void test_gemm(const int M, const int N, const int K, const char transa, const char transb)
ompBLAS_status ger(ompBLAS_handle &handle, const int m, const int n, const T alpha, const T *const x, const int incx, const T *const y, const int incy, T *const A, const int lda)
void test_gemv_batched(const int M_b, const int N_b, const char trans, const int batch_count)
static void gemm(char Atrans, char Btrans, int M, int N, int K, double alpha, const double *A, int lda, const double *restrict B, int ldb, double beta, double *restrict C, int ldc)
void updateTo(size_type size=0, std::ptrdiff_t offset=0)
double B(double x, int k, int i, const std::vector< double > &t)
void test_gemv(const int M_b, const int N_b, const char trans)