QMCPACK
test_ompBLAS.cpp
Go to the documentation of this file.
1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2021 QMCPACK developers.
6 //
7 // File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
8 //
9 // File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
10 //////////////////////////////////////////////////////////////////////////////////////
11 
12 #include "catch.hpp"
13 
14 #include <memory>
15 #include <vector>
16 #include <iostream>
18 #include "OMPTarget/ompBLAS.hpp"
19 #include <OhmmsPETE/OhmmsVector.h>
20 #include <OhmmsPETE/OhmmsMatrix.h>
21 #include <CPU/BLAS.hpp>
22 
23 namespace qmcplusplus
24 {
25 template<typename T>
26 void test_gemm(const int M, const int N, const int K, const char transa, const char transb)
27 {
28  const int a0 = transa == 'T' ? M : K;
29  const int a1 = transa == 'T' ? K : M;
30 
31  const int b0 = transb == 'T' ? K : N;
32  const int b1 = transb == 'T' ? N : K;
33 
35  using mat_t = Matrix<T, OMPallocator<T>>;
36 
38 
39  mat_t A(a0, a1); // Input matrix
40  mat_t B(b0, b1); // Input matrix
41  mat_t C(N, M); // Result matrix ompBLAS
42  mat_t D(N, M); // Result matrix BLAS
43 
44  // Fill data
45  for (int j = 0; j < a0; j++)
46  for (int i = 0; i < a1; i++)
47  A[j][i] = i * 3 + j * 4;
48 
49  for (int j = 0; j < b0; j++)
50  for (int i = 0; i < b1; i++)
51  B[j][i] = i * 4 + j * 5;
52 
53  A.updateTo();
54  B.updateTo();
55 
56  T alpha(1);
57  T alpha_half(0.5);
58  T beta(0);
59  T beta1(1);
60 
61  // U[X,Y] denotes a row-major matrix U with X rows and Y cols
62  // element U(i,j) is located at: U.data() + sizeof(U_type) * (i*ldU + j)
63  //
64  // A,B,C,D are treated as row-major matrices, but the arguments to gemm are treated as col-major
65  // so the call below to ompBLAS::gemm is equivalent to one of the following (with row-major matrices)
66  // transa/transb == 'N'/'N': C[N,M] = A[K,M] * B[N,K]; C = B * A
67  // transa/transb == 'N'/'T': C[N,M] = A[K,M] * B[K,N]; C = B^t * A
68  // transa/transb == 'T'/'N': C[N,M] = A[M,K] * B[N,K]; C = B * A^t
69  // transa/transb == 'T'/'T': C[N,M] = A[M,K] * B[K,N]; C = B^t * A^t
70 
71  // alpha 0.5, beta 0
72  ompBLAS::gemm(handle, transa, transb, M, N, K, alpha_half, A.device_data(), a1, B.device_data(), b1, beta,
73  C.device_data(), M);
74  // alpha 0.5, beta 1
75  ompBLAS::gemm(handle, transa, transb, M, N, K, alpha_half, A.device_data(), a1, B.device_data(), b1, beta1,
76  C.device_data(), M);
77  C.updateFrom();
78 
79  BLAS::gemm(transa, transb, M, N, K, alpha, A.data(), a1, B.data(), b1, beta, D.data(), M);
80 
81  for (int j = 0; j < N; j++)
82  for (int i = 0; i < M; i++)
83  {
84  CHECK(std::real(C[j][i]) == Approx(std::real(D[j][i])));
85  CHECK(std::imag(C[j][i]) == Approx(std::imag(D[j][i])));
86  }
87 
88  mat_t A2(a0, a1); // Input matrix
89  mat_t B2(b0, b1); // Input matrix
90  mat_t C2(N, M); // Result matrix ompBLAS
91  mat_t D2(N, M); // Result matrix BLAS
92 
93  // Fill data
94  for (int j = 0; j < a0; j++)
95  for (int i = 0; i < a1; i++)
96  A2[j][i] = j * 3 + i * 4;
97 
98  for (int j = 0; j < b0; j++)
99  for (int i = 0; i < b1; i++)
100  B2[j][i] = j * 4 + i * 5;
101 
102  A2.updateTo();
103  B2.updateTo();
104 
105  Vector<const T*, OMPallocator<const T*>> Aarr(2), Barr(2);
107 
108  Aarr[0] = A2.device_data();
109  Aarr[1] = A.device_data();
110  Barr[0] = B2.device_data();
111  Barr[1] = B.device_data();
112 
113  Carr[0] = C.device_data();
114  Carr[1] = C2.device_data();
115 
116  Aarr.updateTo();
117  Barr.updateTo();
118  Carr.updateTo();
119 
120  // alpha 0.5, beta 0
121  ompBLAS::gemm_batched(handle, transa, transb, M, N, K, alpha_half, Aarr.device_data(), a1, Barr.device_data(), b1,
122  beta, Carr.device_data(), M, 2);
123  // alpha 0.5, beta 1
124  ompBLAS::gemm_batched(handle, transa, transb, M, N, K, alpha_half, Aarr.device_data(), a1, Barr.device_data(), b1,
125  beta1, Carr.device_data(), M, 2);
126  C.updateFrom();
127  C2.updateFrom();
128 
129  BLAS::gemm(transa, transb, M, N, K, alpha, A2.data(), a1, B2.data(), b1, beta, D2.data(), M);
130 
131  for (int j = 0; j < N; j++)
132  for (int i = 0; i < M; i++)
133  {
134  CHECK(std::real(C2[j][i]) == Approx(std::real(D[j][i])));
135  CHECK(std::imag(C2[j][i]) == Approx(std::imag(D[j][i])));
136  }
137 
138  for (int j = 0; j < N; j++)
139  for (int i = 0; i < M; i++)
140  {
141  CHECK(std::real(C[j][i]) == Approx(std::real(D2[j][i])));
142  CHECK(std::imag(C[j][i]) == Approx(std::imag(D2[j][i])));
143  }
144 }
145 
146 TEST_CASE("ompBLAS gemm", "[OMP]")
147 {
148  const int M = 37;
149  const int N = 71;
150  const int K = 23;
151 
152  // Non-batched test
153  std::cout << "Testing NN gemm" << std::endl;
154  test_gemm<float>(M, N, K, 'N', 'N');
155  test_gemm<double>(M, N, K, 'N', 'N');
156 #if defined(QMC_COMPLEX)
157  test_gemm<std::complex<float>>(N, M, K, 'N', 'N');
158  test_gemm<std::complex<double>>(N, M, K, 'N', 'N');
159 #endif
160  std::cout << "Testing NT gemm" << std::endl;
161  test_gemm<float>(M, N, K, 'N', 'T');
162  test_gemm<double>(M, N, K, 'N', 'T');
163 #if defined(QMC_COMPLEX)
164  test_gemm<std::complex<float>>(N, M, K, 'N', 'T');
165  test_gemm<std::complex<double>>(N, M, K, 'N', 'T');
166 #endif
167  std::cout << "Testing TN gemm" << std::endl;
168  test_gemm<float>(M, N, K, 'T', 'N');
169  test_gemm<double>(M, N, K, 'T', 'N');
170 #if defined(QMC_COMPLEX)
171  test_gemm<std::complex<float>>(N, M, K, 'T', 'N');
172  test_gemm<std::complex<double>>(N, M, K, 'T', 'N');
173 #endif
174  std::cout << "Testing TT gemm" << std::endl;
175  test_gemm<float>(M, N, K, 'T', 'T');
176  test_gemm<double>(M, N, K, 'T', 'T');
177 #if defined(QMC_COMPLEX)
178  test_gemm<std::complex<float>>(N, M, K, 'T', 'T');
179  test_gemm<std::complex<double>>(N, M, K, 'T', 'T');
180 #endif
181 }
182 
183 template<typename T>
184 void test_gemv(const int M_b, const int N_b, const char trans)
185 {
186  const int M = trans == 'T' ? M_b : N_b;
187  const int N = trans == 'T' ? N_b : M_b;
188 
190  using mat_t = Matrix<T, OMPallocator<T>>;
191 
193 
194  vec_t A(N); // Input vector
195  mat_t B(M_b, N_b); // Input matrix
196  vec_t C(M); // Result vector ompBLAS
197  vec_t D(M); // Result vector BLAS
198 
199  // Fill data
200  for (int i = 0; i < N; i++)
201  A[i] = i;
202 
203  for (int j = 0; j < M_b; j++)
204  for (int i = 0; i < N_b; i++)
205  B[j][i] = i + j * 2;
206 
207  // Fill C and D with 0
208  for (int i = 0; i < M; i++)
209  C[i] = D[i] = T(0);
210 
211  A.updateTo();
212  B.updateTo();
213  C.updateTo();
214 
215  T alpha(1);
216  T beta(0);
217 
218  // in Fortran, B[M][N] is viewed as B^T
219  // when trans == 'T', the actual calculation is B * A[N] = C[M]
220  // when trans == 'N', the actual calculation is B^T * A[M] = C[N]
221  ompBLAS::gemv(handle, trans, N_b, M_b, alpha, B.device_data(), N_b, A.device_data(), 1, beta, C.device_data(), 1);
222  C.updateFrom();
223 
224  if (trans == 'T')
225  BLAS::gemv_trans(M_b, N_b, B.data(), A.data(), D.data());
226  else
227  BLAS::gemv(M_b, N_b, B.data(), A.data(), D.data());
228 
229  for (int index = 0; index < M; index++)
230  CHECK(C[index] == D[index]);
231 }
232 
233 template<typename T>
234 void test_gemv_batched(const int M_b, const int N_b, const char trans, const int batch_count)
235 {
236  const int M = trans == 'T' ? M_b : N_b;
237  const int N = trans == 'T' ? N_b : M_b;
238 
240  using mat_t = Matrix<T, OMPallocator<T>>;
241 
243 
244  // Create input vector
245  std::vector<vec_t> As;
247 
248  // Create input matrix
249  std::vector<mat_t> Bs;
251 
252  // Create output vector (ompBLAS)
253  std::vector<vec_t> Cs;
255 
256  // Create output vector (BLAS)
257  std::vector<vec_t> Ds;
259 
260  // Resize pointer vectors
261  Aptrs.resize(batch_count);
262  Bptrs.resize(batch_count);
263  Cptrs.resize(batch_count);
264  Dptrs.resize(batch_count);
265 
266  // Resize data vectors
267  As.resize(batch_count);
268  Bs.resize(batch_count);
269  Cs.resize(batch_count);
270  Ds.resize(batch_count);
271 
272  // Fill data
273  for (int batch = 0; batch < batch_count; batch++)
274  {
275  handle = batch;
276 
277  As[batch].resize(N);
278  Aptrs[batch] = As[batch].device_data();
279 
280  Bs[batch].resize(M_b, N_b);
281  Bptrs[batch] = Bs[batch].device_data();
282 
283  Cs[batch].resize(M);
284  Cptrs[batch] = Cs[batch].device_data();
285 
286  Ds[batch].resize(M);
287  Dptrs[batch] = Ds[batch].data();
288 
289  for (int i = 0; i < N; i++)
290  As[batch][i] = i;
291 
292  for (int j = 0; j < M_b; j++)
293  for (int i = 0; i < N_b; i++)
294  Bs[batch][j][i] = i + j * 2;
295 
296  for (int i = 0; i < M; i++)
297  Cs[batch][i] = Ds[batch][i] = T(0);
298 
299  As[batch].updateTo();
300  Bs[batch].updateTo();
301  }
302 
303  Aptrs.updateTo();
304  Bptrs.updateTo();
305  Cptrs.updateTo();
306 
307  // Run tests
308  Vector<T, OMPallocator<T>> alpha(batch_count);
309  Vector<T, OMPallocator<T>> beta(batch_count);
310  Vector<T, OMPallocator<T>> beta1(batch_count);
311 
312  for (int batch = 0; batch < batch_count; batch++)
313  {
314  alpha[batch] = T(0.5);
315  beta[batch] = T(0);
316  beta1[batch] = T(1);
317  }
318 
319  alpha.updateTo();
320  beta.updateTo();
321  beta1.updateTo();
322 
323  // alpha 0.5, beta 0
324  ompBLAS::gemv_batched(handle, trans, N_b, M_b, alpha.device_data(), Bptrs.device_data(), N_b, Aptrs.device_data(), 1,
325  beta.device_data(), Cptrs.device_data(), 1, batch_count);
326  // alpha 0.5, beta 1
327  ompBLAS::gemv_batched(handle, trans, N_b, M_b, alpha.device_data(), Bptrs.device_data(), N_b, Aptrs.device_data(), 1,
328  beta1.device_data(), Cptrs.device_data(), 1, batch_count);
329 
330  for (int batch = 0; batch < batch_count; batch++)
331  {
332  Cs[batch].updateFrom();
333  if (trans == 'T')
334  BLAS::gemv_trans(M_b, N_b, Bs[batch].data(), As[batch].data(), Ds[batch].data());
335  else
336  BLAS::gemv(M_b, N_b, Bs[batch].data(), As[batch].data(), Ds[batch].data());
337 
338  // Check results
339  for (int index = 0; index < M; index++)
340  CHECK(Cs[batch][index] == Ds[batch][index]);
341  }
342 }
343 
344 TEST_CASE("ompBLAS gemv", "[OMP]")
345 {
346  const int M = 137;
347  const int N = 79;
348  const int batch_count = 23;
349 
350  // Non-batched test
351  std::cout << "Testing TRANS gemv" << std::endl;
352  test_gemv<float>(M, N, 'T');
353  test_gemv<double>(M, N, 'T');
354 #if defined(QMC_COMPLEX)
355  test_gemv<std::complex<float>>(N, M, 'T');
356  test_gemv<std::complex<double>>(N, M, 'T');
357 #endif
358  // Batched Test
359  std::cout << "Testing TRANS gemv_batched" << std::endl;
360  test_gemv_batched<float>(M, N, 'T', batch_count);
361  test_gemv_batched<double>(M, N, 'T', batch_count);
362 #if defined(QMC_COMPLEX)
363  test_gemv_batched<std::complex<float>>(N, M, 'T', batch_count);
364  test_gemv_batched<std::complex<double>>(N, M, 'T', batch_count);
365 #endif
366 }
367 
368 TEST_CASE("ompBLAS gemv notrans", "[OMP]")
369 {
370  const int M = 137;
371  const int N = 79;
372  const int batch_count = 23;
373 
374  // Non-batched test
375  std::cout << "Testing NOTRANS gemv" << std::endl;
376  test_gemv<float>(M, N, 'N');
377  test_gemv<double>(M, N, 'N');
378 #if defined(QMC_COMPLEX)
379  test_gemv<std::complex<float>>(N, M, 'N');
380  test_gemv<std::complex<double>>(N, M, 'N');
381 #endif
382  // Batched Test
383  std::cout << "Testing NOTRANS gemv_batched" << std::endl;
384  test_gemv_batched<float>(M, N, 'N', batch_count);
385  test_gemv_batched<double>(M, N, 'N', batch_count);
386 #if defined(QMC_COMPLEX)
387  test_gemv_batched<std::complex<float>>(N, M, 'N', batch_count);
388  test_gemv_batched<std::complex<double>>(N, M, 'N', batch_count);
389 #endif
390 }
391 
392 template<typename T>
393 void test_ger(const int M, const int N)
394 {
396  using mat_t = Matrix<T, OMPallocator<T>>;
397 
399 
400  mat_t Ah(M, N); // Input matrix
401  mat_t Ad(M, N); // Input matrix
402  vec_t x(M); // Input vector
403  vec_t y(N); // Input vector
404 
405  // Fill data
406  for (int i = 0; i < M; i++)
407  x[i] = i;
408  for (int i = 0; i < N; i++)
409  y[i] = N - i;
410 
411  for (int j = 0; j < M; j++)
412  for (int i = 0; i < N; i++)
413  {
414  Ah[j][i] = i + j * 2;
415  Ad[j][i] = i + j * 2;
416  }
417 
418  Ad.updateTo();
419  x.updateTo();
420  y.updateTo();
421 
422  T alpha(1);
423 
424  // in Fortran, B[M][N] is viewed as B^T
425  ompBLAS::ger(handle, M, N, alpha, x.device_data(), 1, y.device_data(), 1, Ad.device_data(), M);
426  Ad.updateFrom();
427 
428  BLAS::ger(M, N, alpha, x.data(), 1, y.data(), 1, Ah.data(), M);
429 
430  for (int j = 0; j < M; j++)
431  for (int i = 0; i < N; i++)
432  CHECK(Ah[j][i] == Ad[j][i]);
433 }
434 
435 template<typename T>
436 void test_ger_batched(const int M, const int N, const int batch_count)
437 {
439  using mat_t = Matrix<T, OMPallocator<T>>;
440 
442 
443  // Create input vector
444  std::vector<vec_t> Xs;
446  std::vector<vec_t> Ys;
448 
449  // Create input matrix
450  std::vector<mat_t> Ahs;
452  std::vector<mat_t> Ads;
454 
455  // Resize pointer vectors
456  Xptrs.resize(batch_count);
457  Yptrs.resize(batch_count);
458  Ahptrs.resize(batch_count);
459  Adptrs.resize(batch_count);
460 
461  // Resize data vectors
462  Xs.resize(batch_count);
463  Ys.resize(batch_count);
464  Ahs.resize(batch_count);
465  Ads.resize(batch_count);
466 
467  // Fill data
468  for (int batch = 0; batch < batch_count; batch++)
469  {
470  handle = batch;
471 
472  Xs[batch].resize(M);
473  Xptrs[batch] = Xs[batch].device_data();
474 
475  Ys[batch].resize(N);
476  Yptrs[batch] = Ys[batch].device_data();
477 
478  Ads[batch].resize(M, N);
479  Adptrs[batch] = Ads[batch].device_data();
480 
481  Ahs[batch].resize(M, N);
482  Ahptrs[batch] = Ahs[batch].data();
483 
484  // Fill data
485  for (int i = 0; i < M; i++)
486  Xs[batch][i] = i;
487  for (int i = 0; i < N; i++)
488  Ys[batch][i] = N - i;
489 
490  for (int j = 0; j < M; j++)
491  for (int i = 0; i < N; i++)
492  {
493  Ads[batch][j][i] = i + j * 2;
494  Ahs[batch][j][i] = i + j * 2;
495  }
496 
497  Xs[batch].updateTo();
498  Ys[batch].updateTo();
499  Ads[batch].updateTo();
500  }
501 
502  Adptrs.updateTo();
503  Xptrs.updateTo();
504  Yptrs.updateTo();
505 
506  // Run tests
507  Vector<T, OMPallocator<T>> alpha(batch_count);
508 
509  for (int batch = 0; batch < batch_count; batch++)
510  {
511  alpha[batch] = T(1);
512  }
513 
514  alpha.updateTo();
515 
516  ompBLAS::ger_batched(handle, M, N, alpha.device_data(), Xptrs.device_data(), 1, Yptrs.device_data(), 1,
517  Adptrs.device_data(), M, batch_count);
518 
519  for (int batch = 0; batch < batch_count; batch++)
520  {
521  Ads[batch].updateFrom();
522  BLAS::ger(M, N, alpha[batch], Xs[batch].data(), 1, Ys[batch].data(), 1, Ahs[batch].data(), M);
523 
524  // Check results
525  for (int j = 0; j < M; j++)
526  for (int i = 0; i < N; i++)
527  CHECK(Ads[batch][j][i] == Ahs[batch][j][i]);
528  }
529 }
530 
531 TEST_CASE("ompBLAS ger", "[OMP]")
532 {
533  const int M = 137;
534  const int N = 79;
535  const int batch_count = 23;
536 
537  // Non-batched test
538  std::cout << "Testing ger" << std::endl;
539  test_ger<float>(M, N);
540  test_ger<double>(M, N);
541 #if defined(QMC_COMPLEX)
542  test_ger<std::complex<float>>(N, M);
543  test_ger<std::complex<double>>(N, M);
544 #endif
545  // Batched Test
546  std::cout << "Testing ger_batched" << std::endl;
547  test_ger_batched<float>(M, N, batch_count);
548  test_ger_batched<double>(M, N, batch_count);
549 #if defined(QMC_COMPLEX)
550  test_ger_batched<std::complex<float>>(N, M, batch_count);
551  test_ger_batched<std::complex<double>>(N, M, batch_count);
552 #endif
553 }
554 } // namespace qmcplusplus
void resize(size_type n, Type_t val=Type_t())
Resize the container.
Definition: OhmmsVector.h:166
void test_ger_batched(const int M, const int N, const int batch_count)
ompBLAS_status gemm(ompBLAS_handle &handle, const char transa, const char transb, const int M, const int N, const int K, const T &alpha, const T *const A, const int lda, const T *const B, const int ldb, const T &beta, T *const C, const int ldc)
static void gemv_trans(int n, int m, const double *restrict amat, const double *restrict x, double *restrict y)
Definition: BLAS.hpp:147
ompBLAS_status gemv(ompBLAS_handle &handle, const char trans, const int m, const int n, const T alpha, const T *const A, const int lda, const T *const x, const int incx, const T beta, T *const y, const int incy)
QMCTraits::RealType real
helper functions for EinsplineSetBuilder
Definition: Configuration.h:43
TEST_CASE("complex_helper", "[type_traits]")
pointer device_data()
Return the device_ptr matching X if this is a vector attached or owning dual space memory...
Definition: OhmmsVector.h:245
void test_ger(const int M, const int N)
static void gemv(int n, int m, const double *restrict amat, const double *restrict x, double *restrict y)
Definition: BLAS.hpp:118
float imag(const float &c)
imaginary part of a scalar. Cannot be replaced by std::imag due to AFQMC specific needs...
ompBLAS_status gemv_batched(ompBLAS_handle &handle, const char trans, const int m, const int n, const T *alpha, const T *const A[], const int lda, const T *const x[], const int incx, const T *beta, T *const y[], const int incy, const int batch_count)
static void ger(int m, int n, double alpha, const double *x, int incx, const double *y, int incy, double *a, int lda)
Definition: BLAS.hpp:437
ompBLAS_status gemm_batched(ompBLAS_handle &handle, const char transa, const char transb, const int M, const int N, const int K, const T &alpha, const T *const A[], const int lda, const T *const B[], const int ldb, const T &beta, T *const C[], const int ldc, const int batch_count)
ompBLAS_status ger_batched(ompBLAS_handle &handle, const int m, const int n, const T *alpha, const T *const x[], const int incx, const T *const y[], const int incy, T *const A[], const int lda, const int batch_count)
Declaration of Vector<T,Alloc>. Manage memory through Alloc directly and allow referencing an existing ...
CHECK(log_values[0]==ComplexApprox(std::complex< double >{ 5.603777579195571, -6.1586603331188225 }))
void test_gemm(const int M, const int N, const int K, const char transa, const char transb)
ompBLAS_status ger(ompBLAS_handle &handle, const int m, const int n, const T alpha, const T *const x, const int incx, const T *const y, const int incy, T *const A, const int lda)
void test_gemv_batched(const int M_b, const int N_b, const char trans, const int batch_count)
static void gemm(char Atrans, char Btrans, int M, int N, int K, double alpha, const double *A, int lda, const double *restrict B, int ldb, double beta, double *restrict C, int ldc)
Definition: BLAS.hpp:235
void updateTo(size_type size=0, std::ptrdiff_t offset=0)
Definition: OhmmsVector.h:263
double B(double x, int k, int i, const std::vector< double > &t)
void test_gemv(const int M_b, const int N_b, const char trans)