QMCPACK
cusolver.hpp
Go to the documentation of this file.
1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2019 QMCPACK developers.
6 //
7 // File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
8 //
9 // File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
10 //////////////////////////////////////////////////////////////////////////////////////
11 
12 #ifndef QMCPLUSPLUS_CUSOLVER_H
13 #define QMCPLUSPLUS_CUSOLVER_H
14 
#include <cusolverDn.h>
#include <complex>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
20 
/** Check the cusolverStatus_t produced by \p ans and forward it, together with \p cause and the
 *  call site's __FILE__/__LINE__, to cusolverAssert.
 *  The body is wrapped in do { } while (0) so the macro expands to exactly one statement:
 *  it is safe inside an unbraced if/else and requires a trailing semicolon at the call site.
 */
#define cusolverErrorCheck(ans, cause)                \
  do                                                  \
  {                                                   \
    cusolverAssert((ans), cause, __FILE__, __LINE__); \
  } while (0)
25 /// prints cusolver error messages. Always use cusolverErrorCheck macro.
26 inline void cusolverAssert(cusolverStatus_t code,
27  const std::string& cause,
28  const char* file,
29  int line,
30  bool abort = true)
31 {
32  if (code != CUSOLVER_STATUS_SUCCESS)
33  {
34  std::string cusolver_error;
35  switch (code)
36  {
37  case CUSOLVER_STATUS_NOT_INITIALIZED:
38  cusolver_error = "CUSOLVER_STATUS_NOT_INITIALIZED";
39  break;
40  case CUSOLVER_STATUS_ALLOC_FAILED:
41  cusolver_error = "CUSOLVER_STATUS_ALLOC_FAILED";
42  break;
43  case CUSOLVER_STATUS_INVALID_VALUE:
44  cusolver_error = "CUSOLVER_STATUS_INVALID_VALUE";
45  break;
46  case CUSOLVER_STATUS_ARCH_MISMATCH:
47  cusolver_error = "CUSOLVER_STATUS_ARCH_MISMATCH";
48  break;
49  case CUSOLVER_STATUS_EXECUTION_FAILED:
50  cusolver_error = "CUSOLVER_STATUS_EXECUTION_FAILED";
51  break;
52  case CUSOLVER_STATUS_INTERNAL_ERROR:
53  cusolver_error = "CUSOLVER_STATUS_INTERNAL_ERROR";
54  break;
55  case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
56  cusolver_error = "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
57  break;
58  default:
59  cusolver_error = "<unknown>";
60  }
61 
62  std::ostringstream err;
63  err << "cusolverAssert: " << cusolver_error << ", file " << file << " , line " << line << std::endl
64  << cause << std::endl;
65  std::cerr << err.str();
66  //if (abort) exit(code);
67  throw std::runtime_error(cause);
68  }
69 }
70 
71 namespace qmcplusplus
72 {
73 /** interface to cusolver calls for different data types S/C/D/Z
74  */
75 namespace cusolver
76 {
77 inline cusolverStatus_t getrf_bufferSize(cusolverDnHandle_t& handle, int m, int n, double* A, int lda, int* lwork)
78 {
79  return cusolverDnDgetrf_bufferSize(handle, m, n, A, lda, lwork);
80 }
81 
82 inline cusolverStatus_t getrf_bufferSize(cusolverDnHandle_t& handle,
83  int m,
84  int n,
85  std::complex<double>* A,
86  int lda,
87  int* lwork)
88 {
89  return cusolverDnZgetrf_bufferSize(handle, m, n, (cuDoubleComplex*)A, lda, lwork);
90 }
91 
92 inline cusolverStatus_t getrf(cusolverDnHandle_t& handle,
93  int m,
94  int n,
95  double* A,
96  int lda,
97  double* work,
98  int* ipiv,
99  int* info)
100 {
101  return cusolverDnDgetrf(handle, m, n, A, lda, work, ipiv, info);
102 }
103 
104 inline cusolverStatus_t getrf(cusolverDnHandle_t& handle,
105  int m,
106  int n,
107  std::complex<double>* A,
108  int lda,
109  std::complex<double>* work,
110  int* ipiv,
111  int* info)
112 {
113  return cusolverDnZgetrf(handle, m, n, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)work, ipiv, info);
114 }
115 
116 inline cusolverStatus_t getrs(cusolverDnHandle_t& handle,
117  const cublasOperation_t& transa,
118  int m,
119  int n,
120  const double* A,
121  int lda,
122  int* ipiv,
123  double* B,
124  int ldb,
125  int* info)
126 {
127  return cusolverDnDgetrs(handle, transa, m, n, A, lda, ipiv, B, ldb, info);
128 }
129 
130 inline cusolverStatus_t getrs(cusolverDnHandle_t& handle,
131  const cublasOperation_t& transa,
132  int m,
133  int n,
134  const std::complex<double>* A,
135  int lda,
136  int* ipiv,
137  std::complex<double>* B,
138  int ldb,
139  int* info)
140 {
141  return cusolverDnZgetrs(handle, transa, m, n, (const cuDoubleComplex*)A, lda, ipiv, (cuDoubleComplex*)B, ldb, info);
142 }
143 } // namespace cusolver
144 
145 } // namespace qmcplusplus
146 #endif // QMCPLUSPLUS_CUSOLVER_H
cusolverStatus_t getrs(cusolverDnHandle_t &handle, const cublasOperation_t &transa, int m, int n, const double *A, int lda, int *ipiv, double *B, int ldb, int *info)
Definition: cusolver.hpp:116
helper functions for EinsplineSetBuilder
Definition: Configuration.h:43
#define cuDoubleComplex
Definition: cuda2hip.h:73
#define cublasOperation_t
Definition: cuda2hip.h:41
cusolverStatus_t getrf_bufferSize(cusolverDnHandle_t &handle, int m, int n, double *A, int lda, int *lwork)
Definition: cusolver.hpp:77
double B(double x, int k, int i, const std::vector< double > &t)
void cusolverAssert(cusolverStatus_t code, const std::string &cause, const char *file, int line, bool abort=true)
prints cusolver error messages. Always use cusolverErrorCheck macro.
Definition: cusolver.hpp:26
cusolverStatus_t getrf(cusolverDnHandle_t &handle, int m, int n, double *A, int lda, double *work, int *ipiv, int *info)
Definition: cusolver.hpp:92