Skip to content

Commit 8b796d1

Browse files
committed
[rocsolver] Use enqueue_native_command ext
This makes use of the enqueue_native_command DPC++ extension when it is available, which improves performance and integrates correctly with the DPC++ scheduler. Signed-off-by: JackAKirk <[email protected]>
1 parent b2324f1 commit 8b796d1

File tree

4 files changed

+67
-52
lines changed

4 files changed

+67
-52
lines changed

src/lapack/backends/rocsolver/rocsolver_batch.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -527,7 +527,7 @@ inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &qu
527527
for (int64_t i = 0; i < group_count; i++) {
528528
auto **a_ = reinterpret_cast<rocmDataType **>(a_dev);
529529
auto *info_ = reinterpret_cast<rocblas_int *>(info);
530-
ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo[i]),
530+
rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo[i]),
531531
(int)n[i], a_ + offset, (int)lda[i], info_ + offset,
532532
(int)group_sizes[i]);
533533
offset += group_sizes[i];
@@ -627,7 +627,7 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu
627627
for (int64_t i = 0; i < group_count; i++) {
628628
auto **a_ = reinterpret_cast<rocmDataType **>(a_dev);
629629
auto **b_ = reinterpret_cast<rocmDataType **>(b_dev);
630-
ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo[i]),
630+
rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo[i]),
631631
(int)n[i], (int)nrhs[i], a_ + offset, (int)lda[i],
632632
b_ + offset, (int)ldb[i], (int)group_sizes[i]);
633633
offset += group_sizes[i];

src/lapack/backends/rocsolver/rocsolver_helper.hpp

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -166,6 +166,17 @@ class hip_error : virtual public std::runtime_error {
166166
hipError_t hip_err; \
167167
HIP_ERROR_FUNC(hipStreamSynchronize, hip_err, currentStreamId);
168168

169+
// Dispatches a named rocSOLVER call through one of two error-checking macros,
// selected at preprocessing time.
//
// func_name : name of the routine, passed straight through to the macro.
// func      : the callable to invoke, passed straight through to the macro.
// err       : status variable handed to the macro. NOTE(review): it is taken
//             by value here, so presumably anything the macro stores in it is
//             not visible to the caller -- confirm against the macro body.
// handle    : rocSOLVER handle, passed straight through to the macro.
// args      : remaining arguments (taken by value) expanded into the macro call.
template <class Func, class... Types>
170+
inline void rocsolver_native_named_func(const char *func_name, Func func,
171+
rocsolver_status err,
172+
rocsolver_handle handle, Types... args){
173+
// When the enqueue_native_command extension macro is defined, use the plain
// error-checking macro.
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
174+
ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, args...)
175+
#else
176+
// Fallback when the extension is unavailable: use the _SYNC variant.
// Presumably this also synchronizes the HIP stream after the call -- confirm
// against the ROCSOLVER_ERROR_FUNC_T_SYNC definition.
ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, args...)
177+
#endif
178+
// NOTE(review): the ';' after the closing brace looks like a redundant empty
// declaration; it appears harmless but may trigger pedantic warnings.
};
179+
169180
inline rocblas_eform get_rocsolver_itype(std::int64_t itype) {
170181
switch (itype) {
171182
case 1: return rocblas_eform_ax;

0 commit comments

Comments
 (0)