QMCPACK
rocsolver.hpp
Go to the documentation of this file.
1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2022 QMCPACK developers.
6 //
7 // File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
8 // Mark Dewing, mdewing@anl.gov, Argonne National Laboratory
9 //
10 // File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
11 //////////////////////////////////////////////////////////////////////////////////////
12 
13 #ifndef QMCPLUSPLUS_ROCSOLVER_H
14 #define QMCPLUSPLUS_ROCSOLVER_H
15 
16 // Interface to rocSOLVER linear algebra library.
17 // File copied and modified from CUDA/cusolver.hpp
18 
#include <rocsolver/rocsolver.h>
#include <complex>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
24 
/** Check the rocblas_status produced by a rocSOLVER call and report failures.
 *  @param ans   expression yielding a rocblas_status; evaluated exactly once
 *  @param cause description of the call, forwarded to rocsolverAssert
 *  The expansion is wrapped in do { } while (0) so that the macro behaves like
 *  a single statement, e.g. it can be the body of an unbraced if/else.
 */
#define rocsolverErrorCheck(ans, cause)                   \
  do                                                      \
  {                                                       \
    rocsolverAssert((ans), (cause), __FILE__, __LINE__);  \
  } while (0)
29 /// prints rocsolver error messages. Always use rocsolverErrorCheck macro.
30 inline void rocsolverAssert(rocblas_status code,
31  const std::string& cause,
32  const char* file,
33  int line,
34  bool abort = true)
35 {
36  if (code != rocblas_status_success)
37  {
38  std::string rocsolver_error;
39  switch (code)
40  {
41  case rocblas_status_invalid_handle:
42  rocsolver_error = "rocblas_status_invalid_handle";
43  break;
44  case rocblas_status_not_implemented:
45  rocsolver_error = "rocblas_status_not_implemented";
46  break;
47  case rocblas_status_invalid_pointer:
48  rocsolver_error = "rocblas_status_invalid_pointer";
49  break;
50  case rocblas_status_invalid_size:
51  rocsolver_error = "rocblas_status_invalid_size";
52  break;
53  case rocblas_status_memory_error:
54  rocsolver_error = "rocblas_status_memory_error";
55  break;
56  case rocblas_status_internal_error:
57  rocsolver_error = "rocblas_status_internal_error";
58  break;
59  case rocblas_status_perf_degraded:
60  rocsolver_error = "rocblas_status_perf_degraded";
61  break;
62  case rocblas_status_size_query_mismatch:
63  rocsolver_error = "rocblas_status_size_query_mismatch";
64  break;
65  case rocblas_status_size_increased:
66  rocsolver_error = "rocblas_status_size_increased";
67  break;
68  case rocblas_status_size_unchanged:
69  rocsolver_error = "rocblas_status_size_unchanged";
70  break;
71  case rocblas_status_invalid_value:
72  rocsolver_error = "rocblas_status_invalid_value";
73  break;
74  case rocblas_status_continue:
75  rocsolver_error = "rocblas_status_continue";
76  break;
77  case rocblas_status_check_numerics_fail:
78  rocsolver_error = "rocblas_status_check_numerics_fail";
79  break;
80  default:
81  rocsolver_error = "<unknown>";
82  }
83 
84  std::ostringstream err;
85  err << "rocsolverAssert: " << rocsolver_error << ", file " << file << " , line " << line << std::endl
86  << cause << std::endl;
87  std::cerr << err.str();
88  //if (abort) exit(code);
89  throw std::runtime_error(cause);
90  }
91 }
92 
93 namespace qmcplusplus
94 {
95 /** interface to rocsolver calls for different data types S/C/D/Z
96  */
97 namespace rocsolver
98 {
99 
100 
101 inline rocblas_status getrf(rocblas_handle& handle, int m, int n, double* A, int lda, int* ipiv, int* info)
102 {
103  return rocsolver_dgetrf(handle, m, n, A, lda, ipiv, info);
104 }
105 
106 inline rocblas_status getrf(rocblas_handle& handle,
107  int m,
108  int n,
109  std::complex<double>* A,
110  int lda,
111  int* ipiv,
112  int* info)
113 {
114  return rocsolver_zgetrf(handle, m, n, (rocblas_double_complex*)A, lda, ipiv, info);
115 }
116 
117 inline rocblas_status getrs(rocblas_handle& handle,
118  const rocblas_operation& transa,
119  int m,
120  int n,
121  double* A,
122  int lda,
123  int* ipiv,
124  double* B,
125  int ldb)
126 {
127  return rocsolver_dgetrs(handle, transa, m, n, A, lda, ipiv, B, ldb);
128 }
129 
130 inline rocblas_status getrs(rocblas_handle& handle,
131  const rocblas_operation& transa,
132  int m,
133  int n,
134  std::complex<double>* A,
135  int lda,
136  int* ipiv,
137  std::complex<double>* B,
138  int ldb)
139 {
140  return rocsolver_zgetrs(handle, transa, m, n, (rocblas_double_complex*)A, lda, ipiv, (rocblas_double_complex*)B, ldb);
141 }
142 
143 inline rocblas_status getri(rocblas_handle& handle, int n, double* A, int lda, int* ipiv, int* info)
144 {
145  return rocsolver_dgetri(handle, n, A, lda, ipiv, info);
146 }
147 
148 inline rocblas_status getri(rocblas_handle& handle, int n, std::complex<double>* A, int lda, int* ipiv, int* info)
149 {
150  return rocsolver_zgetri(handle, n, (rocblas_double_complex*)A, lda, ipiv, info);
151 }
152 } // namespace rocsolver
153 
154 } // namespace qmcplusplus
155 #endif // QMCPLUSPLUS_ROCSOLVER_H
rocblas_status getrs(rocblas_handle &handle, const rocblas_operation &transa, int m, int n, double *A, int lda, int *ipiv, double *B, int ldb)
Definition: rocsolver.hpp:117
rocblas_status getrf(rocblas_handle &handle, int m, int n, double *A, int lda, int *ipiv, int *info)
Definition: rocsolver.hpp:101
helper functions for EinsplineSetBuilder
Definition: Configuration.h:43
rocblas_status getri(rocblas_handle &handle, int n, double *A, int lda, int *ipiv, int *info)
Definition: rocsolver.hpp:143
double B(double x, int k, int i, const std::vector< double > &t)
void rocsolverAssert(rocblas_status code, const std::string &cause, const char *file, int line, bool abort=true)
prints rocsolver error messages. Always use rocsolverErrorCheck macro.
Definition: rocsolver.hpp:30