Skip to content

Switch to SPIRV APIs from internal built-in APIs #255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 29 commits into
base: sycl-develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
a6c8e53
spirv APIs
jiyang1011 Mar 12, 2025
73bef6e
mma spirv api
jiyang1011 Apr 7, 2025
6e12cb6
Merge branch 'sycl-develop' into jiyang/spirv_api
jiyang1011 Apr 14, 2025
626fd13
Merge branch 'sycl-develop' into jiyang/spirv_api
jiyang1011 Apr 22, 2025
cf6a41b
Merge branch 'sycl-develop' into jiyang/spirv_api
jiyang1011 Apr 29, 2025
d9f8303
remove -1 from OCL API
jiyang1011 Apr 29, 2025
c1cddb6
Merge branch 'sycl-develop' into jiyang/spirv_api
aacostadiaz May 6, 2025
5537fd7
rebase
aacostadiaz May 6, 2025
c89a875
Disable spirv functions for PVC
aacostadiaz May 6, 2025
5e26dd3
move spirv definitions
aacostadiaz May 6, 2025
8c67947
fix
aacostadiaz May 6, 2025
1af7011
Merge branch 'sycl-develop' into jiyang/spirv_api
aacostadiaz May 6, 2025
879eb35
Refactor
aacostadiaz May 8, 2025
9864ab2
Fix cmake
aacostadiaz May 8, 2025
39e549d
Re-enable test
aacostadiaz May 8, 2025
d6c9358
Fix mma builtin
aacostadiaz May 8, 2025
ec9d0a7
Fix copy builtin
aacostadiaz May 8, 2025
7144422
Revert minor changes
aacostadiaz May 9, 2025
3d30536
Merge branch 'sycl-develop' into jiyang/spirv_api
aacostadiaz May 12, 2025
4bbaaa6
Use builtin for prefetch
aacostadiaz May 12, 2025
304de17
Remove FP16 MMA with FP16 accumulator
aacostadiaz May 13, 2025
83b62f8
Update include/cute/arch/copy_xe_spirv.hpp
aacostadiaz May 15, 2025
96c7c80
Update CMakeLists.txt
aacostadiaz May 15, 2025
1dc13ea
Address comments
aacostadiaz May 15, 2025
eef75b9
Merge remote-tracking branch 'taozha2/jiyang/spirv_api' into jiyang/s…
aacostadiaz May 15, 2025
9308727
Merge branch 'sycl-develop' into jiyang/spirv_api
aacostadiaz May 15, 2025
0ee4d9f
Merge branch 'sycl-develop' into jiyang/spirv_api
aacostadiaz May 15, 2025
8774320
Use round
aacostadiaz May 20, 2025
f50200b
Revert changes in cute tests
aacostadiaz May 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ option(CUTLASS_SYCL_PROFILING_ENABLED "Use SYCL events to calculate device execu
option(CUTLASS_SYCL_RUNNING_CI "Enable this option when building in a CI environment.
It activates CI specific configurations, such as additional checks or selectively
disabling tests that cannot run in CI." OFF)
option(CUTLASS_SYCL_BUILTIN_ENABLE "Enable this option to use builtin functions instead of SPIR-V for Block Copy & MMA operations" OFF)

if (CUTLASS_ENABLE_SYCL)
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)

Expand All @@ -130,6 +132,10 @@ if (CUTLASS_ENABLE_SYCL)
add_compile_definitions(SYCLCOMPAT_PROFILING_ENABLED)
endif()

if (CUTLASS_SYCL_BUILTIN_ENABLE)
add_compile_definitions(CUTLASS_SYCL_BUILTIN_ENABLE)
endif()

include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/onemkl.cmake)
endif()
find_package(Doxygen QUIET)
Expand Down
7 changes: 5 additions & 2 deletions cmake/FindDPCPP.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,11 @@ endif()
if("${DPCPP_SYCL_TARGET}" STREQUAL "intel_gpu_pvc" OR
"${DPCPP_SYCL_TARGET}" STREQUAL "spir64" OR
"${DPCPP_SYCL_TARGET}" STREQUAL "intel_gpu_bmg_g21")
list(APPEND DPCPP_FLAGS "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier")

if ((CMAKE_CXX_COMPILER_ID MATCHES "IntelLLVM" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 2025.2) OR CUTLASS_SYCL_BUILTIN_ENABLE)
list(APPEND DPCPP_FLAGS "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier")
else()
list(APPEND DPCPP_FLAGS "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate")
endif()
if(DPCPP_DISABLE_ITT_FOR_CUTLASS)
list(APPEND DPCPP_FLAGS "-fno-sycl-instrument-device-code")
endif()
Expand Down
100 changes: 14 additions & 86 deletions include/cute/arch/copy_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,48 +29,22 @@
*
**************************************************************************************************/
#pragma once
#include <cute/arch/xe_copy_1B.hpp>
#include <cute/arch/xe_copy_2B.hpp>
#include <cute/arch/xe_copy_4B.hpp>
#include <cute/arch/xe_copy_8B.hpp>
#ifdef __SYCL_DEVICE_ONLY__
#define SYCL_DEVICE_BUILTIN(x) SYCL_EXTERNAL extern "C" x
#else
#define SYCL_DEVICE_BUILTIN(x) inline x { assert(false); }

#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_INTEL_TARGET)
#define CUTE_ARCH_COPY_XE_ENABLED
#endif

#if defined(CUTE_ARCH_COPY_XE_ENABLED) && ((defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER < 20250200)) || defined(CUTLASS_SYCL_BUILTIN_ENABLE))
#include <cute/arch/copy_xe_builtin.hpp>
#elif defined(CUTE_ARCH_COPY_XE_ENABLED)
#include <cute/arch/copy_xe_spirv.hpp>
#endif

// prefetch
SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_uchar(
const __attribute__((opencl_global)) uint8_t *base, int immElemOff,
enum CacheControl cacheOpt));
SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_ushort(
const __attribute__((opencl_global)) uint16_t *base, int immElemOff,
enum CacheControl cacheOpt));
SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_uint(
const __attribute__((opencl_global)) uint32_t *base, int immElemOff,
enum CacheControl cacheOpt));
SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_uint2(
const __attribute__((opencl_global)) uint32_t *base, int immElemOff,
enum CacheControl cacheOpt));
SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_uint4(
const __attribute__((opencl_global)) uint32_t *base, int immElemOff,
enum CacheControl cacheOpt));
SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_uint8(
const __attribute__((opencl_global)) uint32_t *base, int immElemOff,
enum CacheControl cacheOpt));
SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_ulong(
const __attribute__((opencl_global)) uint64_t *base, int immElemOff,
enum CacheControl cacheOpt));
SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_ulong2(
const __attribute__((opencl_global)) uint64_t *base, int immElemOff,
enum CacheControl cacheOpt));
SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_ulong4(
const __attribute__((opencl_global)) uint64_t *base, int immElemOff,
enum CacheControl cacheOpt));
SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_ulong8(
const __attribute__((opencl_global)) uint64_t *base, int immElemOff,
enum CacheControl cacheOpt));
#undef SYCL_DEVICE_BUILTIN
#include <cute/arch/copy_xe_U4.hpp>
#include <cute/arch/copy_xe_U8.hpp>
#include <cute/arch/copy_xe_U16.hpp>
#include <cute/arch/copy_xe_U32.hpp>
#include <cute/arch/copy_xe_U64.hpp>

#ifdef __SYCL_DEVICE_ONLY__
SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierWaitINTEL(int execution_scope, int memory_scope, int memory_semantics);
Expand Down Expand Up @@ -142,49 +116,6 @@ struct XE_1D_LDSM {
}
};

template <class S, class D = S>
struct PREFETCH {
using SRegisters = S[1];
using DRegisters = D[1];

template <class S_, class D_>
CUTE_HOST_DEVICE static void copy(const S_ &src, D_ &dst) {
#if defined(SYCL_INTEL_TARGET)
if constexpr(sizeof(D) == 1) {
__builtin_IB_lsc_prefetch_global_uchar(
(const __attribute__((opencl_global)) uint8_t *)(&*&src), 0, CacheControl::kL1C_L3C);
}
else if constexpr(sizeof(D) == 2) {
__builtin_IB_lsc_prefetch_global_ushort(
(const __attribute__((opencl_global)) uint16_t *)(&*&src), 0, CacheControl::kL1C_L3C);
}
else if constexpr(sizeof(D) == 4) {
__builtin_IB_lsc_prefetch_global_uint(
(const __attribute__((opencl_global)) uint32_t *)(&*&src), 0, CacheControl::kL1C_L3C);
}
else if constexpr(sizeof(D) == 8) {
__builtin_IB_lsc_prefetch_global_uint2(
(const __attribute__((opencl_global)) uint32_t *)(&*&src), 0, CacheControl::kL1C_L3C);
}
else if constexpr(sizeof(D) == 16) {
__builtin_IB_lsc_prefetch_global_uint4(
(const __attribute__((opencl_global)) uint32_t *)(&*&src), 0, CacheControl::kL1C_L3C);
}
else if constexpr(sizeof(D) == 32) {
__builtin_IB_lsc_prefetch_global_uint8(
(const __attribute__((opencl_global)) uint32_t *)(&*&src), 0, CacheControl::kL1C_L3C);
}
else if constexpr(sizeof(D) == 64) {
__builtin_IB_lsc_prefetch_global_ulong8(
(const __attribute__((opencl_global)) uint64_t *)(&*&src), 0, CacheControl::kL1C_L3C);
}
#else
CUTE_INVALID_CONTROL_PATH(
"Trying to use block prefetch on non-Xe hardware");
#endif
}
};

template <class S, class D = S>
struct XE_1D_LOAD_GLOBAL {
using SRegisters = S[1];
Expand Down Expand Up @@ -212,9 +143,6 @@ struct XE_1D_LOAD_GLOBAL {
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-Xe hardware");
#endif
}

using PREFETCH = PREFETCH<S, D>;

};

template<class S, class D = S>
Expand Down
Loading
Loading