QMCPACK
hipBLAS.cpp
Go to the documentation of this file.
1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2021 QMCPACK developers.
6 // Copyright(C) 2021 Advanced Micro Devices, Inc. All rights reserved.
7 //
8 // File developed by: Jakub Kurzak, jakurzak@amd.com, Advanced Micro Devices, Inc.
9 // Ye Luo, yeluo@anl.gov, Argonne National Laboratory
10 //
11 // File created by: Jakub Kurzak, jakurzak@amd.com, Advanced Micro Devices, Inc.
12 //////////////////////////////////////////////////////////////////////////////////////
13 
14 
15 #include "hipBLAS.hpp"
16 #include <stdexcept>
17 #include <rocsolver/rocsolver.h>
18 
19 //------------------------------------------------------------------------------
20 hipblasStatus_t hipblasCgemmBatched(hipblasHandle_t handle,
21  hipblasOperation_t transa,
22  hipblasOperation_t transb,
23  int m,
24  int n,
25  int k,
26  const hipComplex* alpha,
27  const hipComplex* const Aarray[],
28  int lda,
29  const hipComplex* const Barray[],
30  int ldb,
31  const hipComplex* beta,
32  hipComplex* const Carray[],
33  int ldc,
34  int batchCount)
35 {
36  return hipblasCgemmBatched(handle, transa, transb, m, n, k, (const hipblasComplex*)alpha,
37  (const hipblasComplex* const*)Aarray, lda, (const hipblasComplex* const*)Barray, ldb,
38  (const hipblasComplex*)beta, (hipblasComplex* const*)Carray, ldc, batchCount);
39 }
40 
41 hipblasStatus_t hipblasZgemmBatched(hipblasHandle_t handle,
42  hipblasOperation_t transa,
43  hipblasOperation_t transb,
44  int m,
45  int n,
46  int k,
47  const hipDoubleComplex* alpha,
48  const hipDoubleComplex* const Aarray[],
49  int lda,
50  const hipDoubleComplex* const Barray[],
51  int ldb,
52  const hipDoubleComplex* beta,
53  hipDoubleComplex* const Carray[],
54  int ldc,
55  int batchCount)
56 {
57  return hipblasZgemmBatched(handle, transa, transb, m, n, k, (const hipblasDoubleComplex*)alpha,
58  (const hipblasDoubleComplex* const*)Aarray, lda,
59  (const hipblasDoubleComplex* const*)Barray, ldb, (const hipblasDoubleComplex*)beta,
60  (hipblasDoubleComplex* const*)Carray, ldc, batchCount);
61 }
62 
63 //------------------------------------------------------------------------------
64 hipblasStatus_t hipblasSgetrfBatched_(hipblasHandle_t handle,
65  int n,
66  float* const A[],
67  int lda,
68  int* P,
69  int* info,
70  int batchSize)
71 {
72  if (!P)
73  throw std::runtime_error("hipblasXgetrfBatched_ pivot array cannot be a null pointer!");
74  return (hipblasStatus_t)rocsolver_sgetrf_batched((rocblas_handle)handle, (const rocblas_int)n, (const rocblas_int)n,
75  (float* const*)A, (const rocblas_int)lda, (rocblas_int*)P,
76  (const rocblas_stride)n, (rocblas_int*)info,
77  (const rocblas_int)batchSize);
78 }
79 
80 hipblasStatus_t hipblasDgetrfBatched_(hipblasHandle_t handle,
81  int n,
82  double* const A[],
83  int lda,
84  int* P,
85  int* info,
86  int batchSize)
87 {
88  if (!P)
89  throw std::runtime_error("hipblasXgetrfBatched_ pivot array cannot be a null pointer!");
90  return (hipblasStatus_t)rocsolver_dgetrf_batched((rocblas_handle)handle, (const rocblas_int)n, (const rocblas_int)n,
91  (double* const*)A, (const rocblas_int)lda, (rocblas_int*)P,
92  (const rocblas_stride)n, (rocblas_int*)info,
93  (const rocblas_int)batchSize);
94 }
95 
96 hipblasStatus_t hipblasCgetrfBatched_(hipblasHandle_t handle,
97  int n,
98  hipComplex* const A[],
99  int lda,
100  int* P,
101  int* info,
102  int batchSize)
103 {
104  if (!P)
105  throw std::runtime_error("hipblasXgetrfBatched_ pivot array cannot be a null pointer!");
106  return (hipblasStatus_t)rocsolver_cgetrf_batched((rocblas_handle)handle, (const rocblas_int)n, (const rocblas_int)n,
107  (rocblas_float_complex* const*)A, (const rocblas_int)lda,
108  (rocblas_int*)P, (const rocblas_stride)n, (rocblas_int*)info,
109  (const rocblas_int)batchSize);
110 }
111 
112 hipblasStatus_t hipblasZgetrfBatched_(hipblasHandle_t handle,
113  int n,
114  hipDoubleComplex* const A[],
115  int lda,
116  int* P,
117  int* info,
118  int batchSize)
119 {
120  if (!P)
121  throw std::runtime_error("hipblasXgetrfBatched_ pivot array cannot be a null pointer!");
122  return (hipblasStatus_t)rocsolver_zgetrf_batched((rocblas_handle)handle, (const rocblas_int)n, (const rocblas_int)n,
123  (rocblas_double_complex* const*)A, (const rocblas_int)lda,
124  (rocblas_int*)P, (const rocblas_stride)n, (rocblas_int*)info,
125  (const rocblas_int)batchSize);
126 }
127 
128 //------------------------------------------------------------------------------
129 hipblasStatus_t hipblasSgetriBatched_(hipblasHandle_t handle,
130  int n,
131  const float* const A[],
132  int lda,
133  const int* P,
134  float* const C[],
135  int ldc,
136  int* info,
137  int batchSize)
138 {
139  if (!P)
140  throw std::runtime_error("hipblasXgetriBatched_ pivot array cannot be a null pointer!");
141  return hipblasSgetriBatched(handle, n, (float* const*)A, lda, (int*)P, (float* const*)C, ldc, info, batchSize);
142 }
143 
144 hipblasStatus_t hipblasDgetriBatched_(hipblasHandle_t handle,
145  int n,
146  const double* const A[],
147  int lda,
148  const int* P,
149  double* const C[],
150  int ldc,
151  int* info,
152  int batchSize)
153 {
154  if (!P)
155  throw std::runtime_error("hipblasXgetriBatched_ pivot array cannot be a null pointer!");
156  return hipblasDgetriBatched(handle, n, (double* const*)A, lda, (int*)P, (double* const*)C, ldc, info, batchSize);
157 }
158 
159 hipblasStatus_t hipblasCgetriBatched_(hipblasHandle_t handle,
160  int n,
161  const hipComplex* const A[],
162  int lda,
163  const int* P,
164  hipComplex* const C[],
165  int ldc,
166  int* info,
167  int batchSize)
168 {
169  if (!P)
170  throw std::runtime_error("hipblasXgetriBatched_ pivot array cannot be a null pointer!");
171  return hipblasCgetriBatched(handle, n, (hipblasComplex* const*)A, lda, (int*)P, (hipblasComplex* const*)C, ldc, info,
172  batchSize);
173 }
174 
175 hipblasStatus_t hipblasZgetriBatched_(hipblasHandle_t handle,
176  int n,
177  const hipDoubleComplex* const A[],
178  int lda,
179  const int* P,
180  hipDoubleComplex* const C[],
181  int ldc,
182  int* info,
183  int batchSize)
184 {
185  if (!P)
186  throw std::runtime_error("hipblasXgetriBatched_ pivot array cannot be a null pointer!");
187  return hipblasZgetriBatched(handle, n, (hipblasDoubleComplex* const*)A, lda, (int*)P, (hipblasDoubleComplex* const*)C,
188  ldc, info, batchSize);
189 }
hipblasStatus_t hipblasSgetriBatched_(hipblasHandle_t handle, int n, const float *const A[], int lda, const int *P, float *const C[], int ldc, int *info, int batchSize)
Definition: hipBLAS.cpp:129
hipblasStatus_t hipblasDgetrfBatched_(hipblasHandle_t handle, int n, double *const A[], int lda, int *P, int *info, int batchSize)
Definition: hipBLAS.cpp:80
hipblasStatus_t hipblasZgetriBatched_(hipblasHandle_t handle, int n, const hipDoubleComplex *const A[], int lda, const int *P, hipDoubleComplex *const C[], int ldc, int *info, int batchSize)
Definition: hipBLAS.cpp:175
hipblasStatus_t hipblasSgetrfBatched_(hipblasHandle_t handle, int n, float *const A[], int lda, int *P, int *info, int batchSize)
Definition: hipBLAS.cpp:64
hipblasStatus_t hipblasCgetriBatched_(hipblasHandle_t handle, int n, const hipComplex *const A[], int lda, const int *P, hipComplex *const C[], int ldc, int *info, int batchSize)
Definition: hipBLAS.cpp:159
hipblasStatus_t hipblasZgemmBatched(hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb, int m, int n, int k, const hipDoubleComplex *alpha, const hipDoubleComplex *const Aarray[], int lda, const hipDoubleComplex *const Barray[], int ldb, const hipDoubleComplex *beta, hipDoubleComplex *const Carray[], int ldc, int batchCount)
Definition: hipBLAS.cpp:41
hipblasStatus_t hipblasDgetriBatched_(hipblasHandle_t handle, int n, const double *const A[], int lda, const int *P, double *const C[], int ldc, int *info, int batchSize)
Definition: hipBLAS.cpp:144
hipblasStatus_t hipblasCgetrfBatched_(hipblasHandle_t handle, int n, hipComplex *const A[], int lda, int *P, int *info, int batchSize)
Definition: hipBLAS.cpp:96
hipblasStatus_t hipblasZgetrfBatched_(hipblasHandle_t handle, int n, hipDoubleComplex *const A[], int lda, int *P, int *info, int batchSize)
Definition: hipBLAS.cpp:112
hipblasStatus_t hipblasCgemmBatched(hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb, int m, int n, int k, const hipComplex *alpha, const hipComplex *const Aarray[], int lda, const hipComplex *const Barray[], int ldb, const hipComplex *beta, hipComplex *const Carray[], int ldc, int batchCount)
Definition: hipBLAS.cpp:20