diff --git a/examples/cute/tutorial/sgemm_1_sycl.cpp b/examples/cute/tutorial/sgemm_1_sycl.cpp index ff4e2a2365..b9caf465b4 100644 --- a/examples/cute/tutorial/sgemm_1_sycl.cpp +++ b/examples/cute/tutorial/sgemm_1_sycl.cpp @@ -39,6 +39,26 @@ #include "cutlass/util/sycl_event_manager.hpp" #include "cutlass/util/GPU_Clock.hpp" +namespace syclcompat { + template + sycl::event launch(const sycl::nd_range &range, sycl::queue q, const F& f) { + return q.parallel_for(detail::transform_nd_range(range), [=](sycl::nd_item) { f(); }); + } + template + sycl::event launch(const sycl::nd_range &range, const F& f) { + return launch(range, get_default_queue(), f); + } + // Alternative launch through dim3 objects + template + sycl::event launch(const dim3 &grid, const dim3 &threads, sycl::queue q, const F& f) { + return launch(sycl::nd_range<3>{grid * threads, threads}, q, f); + } + template + sycl::event launch(const dim3 &grid, const dim3 &threads, const F& f) { + return launch(grid, threads, get_default_queue(), f); + } +} + template >(dimGrid, dimBlock, prob_shape, cta_tiler, - A, dA, sA, tA, - B, dB, sB, tB, - C, dC, sC, tC, - alpha, beta); + auto event = syclcompat::launch(dimGrid, dimBlock, [=] + { gemm_device(prob_shape, cta_tiler, + A, dA, sA, tA, + B, dB, sB, tB, + C, dC, sC, tC, + alpha, beta); }); EventManager::getInstance().addEvent(event); }