QMCPACK
CUDATypeMapping.hpp
Go to the documentation of this file.
1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2021 QMCPACK developers.
6 //
7 // File developed by: Peter Doak, doakpw@ornl.gov, Oak Ridge National Laboratory
8 //
9 // File created by: Peter Doak, doakpw@ornl.gov, Oak Ridge National Laboratory
10 //////////////////////////////////////////////////////////////////////////////////////
11 
12 
13 #ifndef QMCPLUSPLUS_CUDA_TYPE_MAPPING_HPP
14 #define QMCPLUSPLUS_CUDA_TYPE_MAPPING_HPP
15 
#include <complex>
#include <cstring>
#include <type_traits>
#include "config.h"
// NOTE(review): OnTypesEqual / default_type and, under QMC_CUDA2HIP, the cuComplex /
// cuDoubleComplex macros (cuda2hip.h) come from project headers whose include lines
// are not visible in this listing -- confirm they are included.
#ifndef QMC_CUDA2HIP
#include <cuComplex.h>
#else
#include <hip/hip_complex.h>
#endif
25 
26 namespace qmcplusplus
27 {
28 
29 // This saves us writing specific overloads with reinterpret casts for different std::complex to cuComplex types.
30 template<typename T>
31 using CUDATypeMap =
32  typename std::disjunction<OnTypesEqual<T, float, float>,
33  OnTypesEqual<T, double, double>,
34  OnTypesEqual<T, float*, float*>,
35  OnTypesEqual<T, double*, double*>,
36  OnTypesEqual<T, float**, float**>,
37  OnTypesEqual<T, double**, double**>,
38  OnTypesEqual<T, std::complex<double>, cuDoubleComplex>,
39  OnTypesEqual<T, std::complex<float>, cuComplex>,
40  OnTypesEqual<T, std::complex<double>*, cuDoubleComplex*>,
41  OnTypesEqual<T, std::complex<float>**, cuComplex**>,
42  OnTypesEqual<T, std::complex<double>**, cuDoubleComplex**>,
43  OnTypesEqual<T, std::complex<float>*, cuComplex*>,
44  OnTypesEqual<T, const std::complex<double>*, const cuDoubleComplex*>,
45  OnTypesEqual<T, const std::complex<float>*, const cuComplex*>,
46  OnTypesEqual<T, const std::complex<float>**, const cuComplex**>,
47  OnTypesEqual<T, const std::complex<double>**, const cuDoubleComplex**>,
48  OnTypesEqual<T, const std::complex<float>* const*, const cuComplex* const*>,
49  OnTypesEqual<T, const std::complex<double>* const*, const cuDoubleComplex* const*>,
51 
52 template<typename T>
54 {
55  return reinterpret_cast<CUDATypeMap<T>>(var);
56 }
57 
58 } // namespace qmcplusplus
59 
60 #endif // QMCPLUSPLUS_CUDA_TYPE_MAPPING_HPP
helper functions for EinsplineSetBuilder
Definition: Configuration.h:43
typename std::disjunction< OnTypesEqual< T, float, float >, OnTypesEqual< T, double, double >, OnTypesEqual< T, float *, float * >, OnTypesEqual< T, double *, double * >, OnTypesEqual< T, float **, float ** >, OnTypesEqual< T, double **, double ** >, OnTypesEqual< T, std::complex< double >, cuDoubleComplex >, OnTypesEqual< T, std::complex< float >, cuComplex >, OnTypesEqual< T, std::complex< double > *, cuDoubleComplex * >, OnTypesEqual< T, std::complex< float > **, cuComplex ** >, OnTypesEqual< T, std::complex< double > **, cuDoubleComplex ** >, OnTypesEqual< T, std::complex< float > *, cuComplex * >, OnTypesEqual< T, const std::complex< double > *, const cuDoubleComplex * >, OnTypesEqual< T, const std::complex< float > *, const cuComplex * >, OnTypesEqual< T, const std::complex< float > **, const cuComplex ** >, OnTypesEqual< T, const std::complex< double > **, const cuDoubleComplex ** >, OnTypesEqual< T, const std::complex< float > *const *, const cuComplex *const * >, OnTypesEqual< T, const std::complex< double > *const *, const cuDoubleComplex *const * >, default_type< void > >::type CUDATypeMap
#define cuComplex
Definition: cuda2hip.h:72
#define cuDoubleComplex
Definition: cuda2hip.h:73
CUDATypeMap< T > castCUDAType(T var)