// Fragment of copyAinvRow_saveGL_batched. The head of the signature, the braces and
// several statements are missing from this listing (the embedded original line numbers
// skip), so only the visible lines are described.
// Visible behaviour: for each batch entry, elements of row rowchanged of Ainv are copied
// into rcopy, and values read from phi_vgl_in at 1..4 multiples of phi_vgl_stride are
// stored into dphi_out (three per column) and d2phi_out (one per column).
24 const T*
const Ainv[],
28 const T*
const phi_vgl_in[],
29 const int phi_vgl_stride,
32 const int batch_count,
33 const std::vector<sycl::event>& dependencies)
// Work-group size used by the launch below.
35 constexpr
int COLBS = 128;
// Global range is batch_count * COLBS with local range COLBS, i.e. one work-group per
// batch entry. The dependencies vector is handed to parallel_for as the events to wait on.
38 .parallel_for(sycl::nd_range<1>{{
static_cast<size_t>(batch_count * COLBS)}, {
static_cast<size_t>(COLBS)}},
39 dependencies, [=](sycl::nd_item<1> item) {
// The work-group index selects which batch entry's pointers are used.
40 const int iw = item.get_group(0);
41 const T* __restrict__ Ainv_iw = Ainv[iw];
42 T* __restrict__ temp_iw = temp[iw];
43 T* __restrict__ rcopy_iw = rcopy[iw];
44 const T* __restrict__ phi_in_iw = phi_vgl_in[iw];
45 T* __restrict__ dphi_out_iw = dphi_out[iw];
46 T* __restrict__ d2phi_out_iw = d2phi_out[iw];
// Index of this work-item inside its work-group.
48 const int tid = item.get_local_id(0);
// Subtracts one from entry rowchanged of temp.
// NOTE(review): original line 49 is missing here; presumably it restricts this update
// to a single work-item - confirm against the full file.
50 temp_iw[rowchanged] -= T(1);
// Ceiling division: number of COLBS-wide chunks needed to cover n columns.
52 const int num_col_blocks = (
n + COLBS - 1) / COLBS;
53 for (
int ib = 0; ib < num_col_blocks; ib++)
// Column handled by this work-item in chunk ib.
// NOTE(review): any col_id < n bounds check is not visible in this excerpt - confirm.
55 const int col_id = ib * COLBS + tid;
// Copy element col_id of row rowchanged, with rows of Ainv spaced lda elements apart.
58 rcopy_iw[col_id] = Ainv_iw[rowchanged *
lda + col_id];
// dphi_out holds three consecutive values per column, taken from phi_vgl_in at
// 1, 2 and 3 strides past col_id; d2phi_out takes the value at 4 strides.
62 dphi_out_iw[col_id * 3] = phi_in_iw[col_id + phi_vgl_stride];
63 dphi_out_iw[col_id * 3 + 1] = phi_in_iw[col_id + phi_vgl_stride * 2];
64 dphi_out_iw[col_id * 3 + 2] = phi_in_iw[col_id + phi_vgl_stride * 3];
65 d2phi_out_iw[col_id] = phi_in_iw[col_id + phi_vgl_stride * 4];
// Parameter-list fragments of four declarations that repeat the argument list used by
// the kernel above (Ainv, temp, rcopy, phi_vgl_in, phi_vgl_stride, dphi_out, d2phi_out,
// batch_count, dependencies), one per element type.
// NOTE(review): the lines naming the declared function are missing from this listing;
// presumably these are explicit instantiations of copyAinvRow_saveGL_batched - confirm.
// Element type: float.
74 const float*
const Ainv[],
78 const float*
const phi_vgl_in[],
79 const int phi_vgl_stride,
80 float*
const dphi_out[],
81 float*
const d2phi_out[],
82 const int batch_count,
83 const std::vector<sycl::event>& dependencies);
// Element type: double.
88 const double*
const Ainv[],
91 double*
const rcopy[],
92 const double*
const phi_vgl_in[],
93 const int phi_vgl_stride,
94 double*
const dphi_out[],
95 double*
const d2phi_out[],
96 const int batch_count,
97 const std::vector<sycl::event>& dependencies);
// Element type: std::complex<float>.
100 const int rowchanged,
102 const std::complex<float>*
const Ainv[],
104 std::complex<float>*
const temp[],
105 std::complex<float>*
const rcopy[],
106 const std::complex<float>*
const phi_vgl_in[],
107 const int phi_vgl_stride,
108 std::complex<float>*
const dphi_out[],
109 std::complex<float>*
const d2phi_out[],
110 const int batch_count,
111 const std::vector<sycl::event>& dependencies);
// Element type: std::complex<double>.
114 const int rowchanged,
116 const std::complex<double>*
const Ainv[],
118 std::complex<double>*
const temp[],
119 std::complex<double>*
const rcopy[],
120 const std::complex<double>*
const phi_vgl_in[],
121 const int phi_vgl_stride,
122 std::complex<double>*
const dphi_out[],
123 std::complex<double>*
const d2phi_out[],
124 const int batch_count,
125 const std::vector<sycl::event>& dependencies);
// Fragment of calcGradients_batched. The function-name line, the braces and some
// statements are missing from this listing, so only the visible lines are described.
// Visible behaviour: for each batch entry iw and each idim in [0, DIM), products
// invRow[col] * dpsiM_row[col * DIM + idim] are accumulated in work-group local memory,
// combined by a halving loop, and the value in slot idim * COLBS is written to
// grads_now[iw * DIM + idim].
127 template<
typename T,
int DIM>
130 const T*
const Ainvrow[],
131 const T*
const dpsiMrow[],
133 const int batch_count,
134 const std::vector<sycl::event>& dependencies)
// Work-group size used by the launch below.
136 constexpr
int COLBS = 128;
// The command group registers the dependencies, creates the local scratch buffer and
// enqueues the kernel; the event returned by submit is returned to the caller.
138 return aq.submit([&](sycl::handler& cgh) {
139 cgh.depends_on(dependencies);
// Work-group local scratch of DIM * COLBS elements: one slot per (idim, work-item).
141 sycl::local_accessor<T, 1> sum((static_cast<size_t>(
DIM * COLBS)), cgh);
// One work-group of COLBS work-items per batch entry.
142 cgh.parallel_for(sycl::nd_range<1>{{
static_cast<size_t>(batch_count * COLBS)}, {
static_cast<size_t>(COLBS)}},
143 [=](sycl::nd_item<1> item) {
// The work-group index selects the batch entry.
144 const int iw = item.get_group(0);
145 const T* __restrict__ invRow = Ainvrow[iw];
146 const T* __restrict__ dpsiM_row = dpsiMrow[iw];
148 const int tid = item.get_local_id(0);
// Value-initialise this work-item's DIM accumulator slots.
149 for (
int idim = 0; idim <
DIM; idim++)
150 sum[idim * COLBS + tid] = T{};
// Ceiling division: number of COLBS-wide chunks needed to cover n columns.
152 const int num_col_blocks = (
n + COLBS - 1) / COLBS;
153 for (
int ib = 0; ib < num_col_blocks; ib++)
// NOTE(review): any col_id < n bounds check is not visible in this excerpt - confirm.
155 const int col_id = ib * COLBS + tid;
156 for (
int idim = 0; idim <
DIM; idim++)
// dpsiM_row stores DIM consecutive values per column.
158 sum[idim * COLBS + tid] += invRow[col_id] * dpsiM_row[col_id *
DIM + idim];
// Halving loop: iend takes the values COLBS/2, COLBS/4, ..., 1, with a local-memory
// barrier at the start of every step.
161 for (
int iend = COLBS / 2; iend > 0; iend /= 2)
163 item.barrier(sycl::access::fence_space::local_space);
164 for (
int idim = 0; idim <
DIM; idim++)
// Adds the slot iend positions further on into this work-item's slot.
// NOTE(review): original line 165 is missing here; presumably it limits this to
// tid < iend - confirm against the full file.
166 sum[idim * COLBS + tid] += sum[idim * COLBS + tid + iend];
// Writes slot idim * COLBS (the slot belonging to local id 0) out for each idim.
// NOTE(review): original line 169 is missing; presumably a single-work-item guard.
170 for (
int idim = 0; idim <
DIM; idim++)
171 grads_now[iw *
DIM + idim] = sum[idim * COLBS];
// Parameter-list fragments of four declarations that repeat the argument list of the
// kernel above (Ainvrow, dpsiMrow, grads_now, batch_count, dependencies), one per
// element type.
// NOTE(review): the lines naming the declared function and its template arguments are
// missing from this listing; presumably these are explicit instantiations of
// calcGradients_batched - confirm.
// Element type: float.
178 const float*
const Ainvrow[],
179 const float*
const dpsiMrow[],
180 float*
const grads_now,
181 const int batch_count,
182 const std::vector<sycl::event>& dependencies);
// Element type: double.
186 const double*
const Ainvrow[],
187 const double*
const dpsiMrow[],
188 double*
const grads_now,
189 const int batch_count,
190 const std::vector<sycl::event>& dependencies);
// Element type: std::complex<float>.
194 const std::complex<float>*
const Ainvrow[],
195 const std::complex<float>*
const dpsiMrow[],
196 std::complex<float>*
const grads_now,
197 const int batch_count,
198 const std::vector<sycl::event>& dependencies);
// Element type: std::complex<double>.
202 const std::complex<double>*
const Ainvrow[],
203 const std::complex<double>*
const dpsiMrow[],
204 std::complex<double>*
const grads_now,
205 const int batch_count,
206 const std::vector<sycl::event>& dependencies);
// Fragment of add_delay_list_save_sigma_VGL_batched. The head of the signature, the
// braces and the enclosing control flow are missing from this listing, so only the
// visible lines are described.
// Two groups of statements are visible. The first records rowchanged in the delay list,
// scales row delay_count of binv by ratio_inv[iw] and copies values out of phi_vgl_in.
// The second zero-fills phi_out, assigns row and column delay_count of binv, sets the
// diagonal element to one and stores -1 in the delay list.
// NOTE(review): n_accepted is a parameter but its use is not visible here; presumably
// it selects between the two groups per batch entry - confirm.
211 int*
const delay_list[],
212 const int rowchanged,
213 const int delay_count,
216 const T*
const ratio_inv,
217 const T*
const phi_vgl_in[],
218 const int phi_vgl_stride,
221 T*
const d2phi_out[],
223 const int n_accepted,
224 const int batch_count,
225 const std::vector<sycl::event>& dependencies)
// Work-group size used by the launch below.
227 constexpr
int COLBS = 64;
// One work-group of COLBS work-items per batch entry; the dependencies vector is
// handed to parallel_for as the events to wait on.
229 return aq.parallel_for(sycl::nd_range<1>{{
static_cast<size_t>(batch_count * COLBS)}, {
static_cast<size_t>(COLBS)}},
230 dependencies, [=](sycl::nd_item<1> item) {
231 const int tid = item.get_local_id(0);
232 const int iw = item.get_group(0);
237 int* __restrict__ delay_list_iw = delay_list[iw];
// Points at offset delay_count * binv_lda inside this entry's binv, i.e. row delay_count
// when rows are spaced binv_lda elements apart.
238 T* __restrict__ binvrow_iw = binv[iw] + delay_count * binv_lda;
239 const T* __restrict__ phi_in_iw = phi_vgl_in[iw];
240 T* __restrict__ phi_out_iw = phi_out[iw];
241 T* __restrict__ dphi_out_iw = dphi_out[iw];
242 T* __restrict__ d2phi_out_iw = d2phi_out[iw];
// Store rowchanged in slot delay_count of the delay list and ratio_inv[iw] in element
// delay_count of the binv row.
// NOTE(review): the lines just before these are missing; presumably a single-work-item
// guard - confirm.
246 delay_list_iw[delay_count] = rowchanged;
247 binvrow_iw[delay_count] = ratio_inv[iw];
// Multiply the first delay_count elements of the binv row by ratio_inv[iw], with the
// work-items striding over them in COLBS-wide chunks.
250 const int num_delay_count_col_blocks = (delay_count + COLBS - 1) / COLBS;
251 for (
int ib = 0; ib < num_delay_count_col_blocks; ib++)
253 const int col_id = ib * COLBS + tid;
254 if (col_id < delay_count)
255 binvrow_iw[col_id] *= ratio_inv[iw];
// Copy data for norb columns out of phi_vgl_in: offset 0 goes to phi_out, offsets of
// 1, 2 and 3 strides to the three dphi_out values per column, 4 strides to d2phi_out.
// NOTE(review): any col_id < norb bounds check is not visible in this excerpt - confirm.
258 const int num_col_blocks = (norb + COLBS - 1) / COLBS;
259 for (
int ib = 0; ib < num_col_blocks; ib++)
261 const int col_id = ib * COLBS + tid;
265 phi_out_iw[col_id] = phi_in_iw[col_id];
266 dphi_out_iw[col_id * 3] = phi_in_iw[col_id + phi_vgl_stride];
267 dphi_out_iw[col_id * 3 + 1] = phi_in_iw[col_id + phi_vgl_stride * 2];
268 dphi_out_iw[col_id * 3 + 2] = phi_in_iw[col_id + phi_vgl_stride * 3];
269 d2phi_out_iw[col_id] = phi_in_iw[col_id + phi_vgl_stride * 4];
// Second group of statements: phi_out for this entry is value-initialised over norb
// columns.
// NOTE(review): any col_id < norb bounds check is not visible in this excerpt - confirm.
276 T* __restrict__ Urow_iw = phi_out[iw];
277 const int num_blocks_norb = (norb + COLBS - 1) / COLBS;
278 for (
int ib = 0; ib < num_blocks_norb; ib++)
280 const int col_id = ib * COLBS + tid;
282 Urow_iw[col_id] = T{};
// Element col_id of row delay_count and element col_id of column delay_count of binv
// receive the same value, for col_id < delay_count.
// NOTE(review): the right-hand side of this chained assignment is on a line missing
// from this listing.
285 T* __restrict__ binv_iw = binv[iw];
286 const int num_blocks_delay_count = (delay_count + COLBS - 1) / COLBS;
287 for (
int ib = 0; ib < num_blocks_delay_count; ib++)
289 const int col_id = ib * COLBS + tid;
290 if (col_id < delay_count)
291 binv_iw[delay_count * binv_lda + col_id] = binv_iw[delay_count + binv_lda * col_id] =
// delay_list_iw is declared a second time, so this statement must sit in a different
// scope from the first declaration; the enclosing braces are not visible here.
295 int* __restrict__ delay_list_iw = delay_list[iw];
// Diagonal element (delay_count, delay_count) of binv becomes one and slot delay_count
// of the delay list becomes -1.
// NOTE(review): presumably guarded to a single work-item on the missing lines - confirm.
298 binv_iw[delay_count * binv_lda + delay_count] = T(1);
299 delay_list_iw[delay_count] = -1;
// Parameter-list fragments of four declarations that repeat the argument list of the
// kernel above (delay_list, rowchanged, delay_count, binv, ratio_inv, phi_vgl_in,
// phi_vgl_stride, phi_out, dphi_out, d2phi_out, n_accepted, batch_count, dependencies),
// one per element type.
// NOTE(review): the lines naming the declared function are missing from this listing;
// presumably these are explicit instantiations of
// add_delay_list_save_sigma_VGL_batched - confirm.
// Element type: float.
306 int*
const delay_list[],
307 const int rowchanged,
308 const int delay_count,
311 const float*
const ratio_inv,
312 const float*
const phi_vgl_in[],
313 const int phi_vgl_stride,
314 float*
const phi_out[],
315 float*
const dphi_out[],
316 float*
const d2phi_out[],
318 const int n_accepted,
319 const int batch_count,
320 const std::vector<sycl::event>& dependencies);
// Element type: double.
323 int*
const delay_list[],
324 const int rowchanged,
325 const int delay_count,
326 double*
const binv[],
328 const double*
const ratio_inv,
329 const double*
const phi_vgl_in[],
330 const int phi_vgl_stride,
331 double*
const phi_out[],
332 double*
const dphi_out[],
333 double*
const d2phi_out[],
335 const int n_accepted,
336 const int batch_count,
337 const std::vector<sycl::event>& dependencies);
// Element type: std::complex<float>.
340 int*
const delay_list[],
341 const int rowchanged,
342 const int delay_count,
343 std::complex<float>*
const binv[],
345 const std::complex<float>*
const ratio_inv,
346 const std::complex<float>*
const phi_vgl_in[],
347 const int phi_vgl_stride,
348 std::complex<float>*
const phi_out[],
349 std::complex<float>*
const dphi_out[],
350 std::complex<float>*
const d2phi_out[],
352 const int n_accepted,
353 const int batch_count,
354 const std::vector<sycl::event>& dependencies);
// Element type: std::complex<double>.
357 int*
const delay_list[],
358 const int rowchanged,
359 const int delay_count,
360 std::complex<double>*
const binv[],
362 const std::complex<double>*
const ratio_inv,
363 const std::complex<double>*
const phi_vgl_in[],
364 const int phi_vgl_stride,
365 std::complex<double>*
const phi_out[],
366 std::complex<double>*
const dphi_out[],
367 std::complex<double>*
const d2phi_out[],
369 const int n_accepted,
370 const int batch_count,
371 const std::vector<sycl::event>& dependencies);
// Fragment of applyW_batched. The head of the signature, the braces, the loop header
// and the end of the final statement are missing from this listing, so only the visible
// lines are described.
// Visible behaviour: for each batch entry and each col_id < delay_count, the row index
// stored in delay_list at col_id is looked up and element (row_id, col_id) of tempMat
// is updated with +=.
375 const int*
const delay_list[],
376 const int delay_count,
379 const int batch_count,
380 const std::vector<sycl::event>& dependencies)
// Work-group size used by the launch below.
382 constexpr
int COLBS = 32;
// One work-group of COLBS work-items per batch entry; the dependencies vector is
// handed to parallel_for as the events to wait on.
384 return aq.parallel_for(sycl::nd_range<1>{{
static_cast<size_t>(batch_count * COLBS)}, {
static_cast<size_t>(COLBS)}},
385 dependencies, [=](sycl::nd_item<1> item) {
// The work-group index selects the batch entry.
386 const int iw = item.get_group(0);
387 const int* __restrict__ delay_list_iw = delay_list[iw];
388 T* __restrict__ tempMat_iw = tempMat[iw];
391 const int tid = item.get_local_id(0);
// Ceiling division: number of COLBS-wide chunks needed to cover delay_count columns.
392 const int num_blocks = (delay_count + COLBS - 1) / COLBS;
// NOTE(review): the loop header that declares ib is on a line missing from this
// listing; presumably it iterates ib over [0, num_blocks) - confirm.
395 const int col_id = ib * COLBS + tid;
396 if (col_id < delay_count)
398 const int row_id = delay_list_iw[col_id];
// Rows of tempMat are spaced lda elements apart.
// NOTE(review): the right-hand side of the += and any check on row_id (original
// line 399 is missing) are not visible in this excerpt - confirm against the full file.
400 tempMat_iw[row_id *
lda + col_id] +=
// Parameter-list fragments of four declarations that repeat the argument list of the
// kernel above (delay_list, delay_count, tempMat, batch_count, dependencies), one per
// element type.
// NOTE(review): the lines naming the declared function are missing from this listing;
// presumably these are explicit instantiations of applyW_batched - confirm.
// Element type: float.
408 const int*
const delay_list[],
409 const int delay_count,
410 float*
const tempMat[],
412 const int batch_count,
413 const std::vector<sycl::event>& dependencies);
// Element type: double.
416 const int*
const delay_list[],
417 const int delay_count,
418 double*
const tempMat[],
420 const int batch_count,
421 const std::vector<sycl::event>& dependencies);
// Element type: std::complex<float>.
424 const int*
const delay_list[],
425 const int delay_count,
426 std::complex<float>*
const tempMat[],
428 const int batch_count,
429 const std::vector<sycl::event>& dependencies);
// Element type: std::complex<double>.
432 const int*
const delay_list[],
433 const int delay_count,
434 std::complex<double>*
const tempMat[],
436 const int batch_count,
437 const std::vector<sycl::event>& dependencies);
helper functions for EinsplineSetBuilder
sycl::event calcGradients_batched(sycl::queue &aq, const int n, const T *const Ainvrow[], const T *const dpsiMrow[], T *const grads_now, const int batch_count, const std::vector< sycl::event > &dependencies)
sycl::event add_delay_list_save_sigma_VGL_batched(sycl::queue &aq, int *const delay_list[], const int rowchanged, const int delay_count, T *const binv[], const int binv_lda, const T *const ratio_inv, const T *const phi_vgl_in[], const int phi_vgl_stride, T *const phi_out[], T *const dphi_out[], T *const d2phi_out[], const int norb, const int n_accepted, const int batch_count, const std::vector< sycl::event > &dependencies)
sycl::event applyW_batched(sycl::queue &aq, const int *const delay_list[], const int delay_count, T *const tempMat[], const int lda, const int batch_count, const std::vector< sycl::event > &dependencies)
sycl::event copyAinvRow_saveGL_batched(sycl::queue &aq, const int rowchanged, const int n, const T *const Ainv[], const int lda, T *const temp[], T *const rcopy[], const T *const phi_vgl_in[], const int phi_vgl_stride, T *const dphi_out[], T *const d2phi_out[], const int batch_count, const std::vector< sycl::event > &dependencies)