Skip to content

[SYCL] Align sycl_ext_oneapi_address_cast impl with the spec #15402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ implementation supports.
namespace sycl::ext::oneapi::experimental {

// Shorthands for address space names
constexpr inline address_space global_space = sycl::access::address_space::global_space;
constexpr inline address_space local_space = sycl::access::address_space::local_space;
constexpr inline address_space private_space = sycl::access::address_space::private_space;
constexpr inline address_space generic_space = sycl::access::address_space::generic_space;
constexpr inline access::address_space global_space = access::address_space::global_space;
constexpr inline access::address_space local_space = access::address_space::local_space;
constexpr inline access::address_space private_space = access::address_space::private_space;
constexpr inline access::address_space generic_space = access::address_space::generic_space;

template <access::address_space Space,
typename ElementType>
Expand Down
49 changes: 37 additions & 12 deletions sycl/include/sycl/ext/oneapi/experimental/address_cast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,49 +16,74 @@ inline namespace _V1 {
namespace ext {
namespace oneapi {
namespace experimental {
// Shorthands for address space names
constexpr inline access::address_space global_space = access::address_space::global_space;
constexpr inline access::address_space local_space = access::address_space::local_space;
constexpr inline access::address_space private_space = access::address_space::private_space;
constexpr inline access::address_space generic_space = access::address_space::generic_space;

template <access::address_space Space, access::decorated DecorateAddress,
typename ElementType>
multi_ptr<ElementType, Space, DecorateAddress>
template <access::address_space Space, typename ElementType>
multi_ptr<ElementType, Space, access::decorated::no>
static_address_cast(ElementType *Ptr) {
using ret_ty = multi_ptr<ElementType, Space, access::decorated::no>;
#ifdef __SYCL_DEVICE_ONLY__
// TODO: Remove this restriction.
static_assert(std::is_same_v<ElementType, remove_decoration_t<ElementType>>,
"The extension expect undecorated raw pointers only!");
if constexpr (Space == access::address_space::generic_space) {
if constexpr (Space == generic_space) {
// Undecorated raw pointer is in generic AS already, no extra casts needed.
// Note for future, for `OpPtrCastToGeneric`, `Pointer` must point to one of
// `Storage Classes` that doesn't include `Generic`, so this will have to
// remain a special case even if the restriction above is lifted.
return multi_ptr<ElementType, Space, DecorateAddress>(Ptr);
return ret_ty(Ptr);
} else {
auto CastPtr = sycl::detail::spirv::GenericCastToPtr<Space>(Ptr);
return multi_ptr<ElementType, Space, DecorateAddress>(CastPtr);
return ret_ty(CastPtr);
}
#else
return multi_ptr<ElementType, Space, DecorateAddress>(Ptr);
return ret_ty(Ptr);
#endif
}

template <access::address_space Space, access::decorated DecorateAddress,
typename ElementType>
multi_ptr<ElementType, Space, DecorateAddress>
multi_ptr<ElementType, Space, DecorateAddress> static_address_cast(
multi_ptr<ElementType, generic_space, DecorateAddress> Ptr) {
if constexpr (Space == generic_space)
return Ptr;
else
return {static_address_cast<Space>(Ptr.get_raw())};
}

template <access::address_space Space, typename ElementType>
multi_ptr<ElementType, Space, access::decorated::no>
dynamic_address_cast(ElementType *Ptr) {
using ret_ty = multi_ptr<ElementType, Space, access::decorated::no>;
#ifdef __SYCL_DEVICE_ONLY__
// TODO: Remove this restriction.
static_assert(std::is_same_v<ElementType, remove_decoration_t<ElementType>>,
"The extension expect undecorated raw pointers only!");
if constexpr (Space == access::address_space::generic_space) {
return multi_ptr<ElementType, Space, DecorateAddress>(Ptr);
if constexpr (Space == generic_space) {
return ret_ty(Ptr);
} else {
auto CastPtr = sycl::detail::spirv::GenericCastToPtrExplicit<Space>(Ptr);
return multi_ptr<ElementType, Space, DecorateAddress>(CastPtr);
return ret_ty(CastPtr);
}
#else
return multi_ptr<ElementType, Space, DecorateAddress>(Ptr);
return ret_ty(Ptr);
#endif
}

template <access::address_space Space, access::decorated DecorateAddress,
typename ElementType>
multi_ptr<ElementType, Space, DecorateAddress> dynamic_address_cast(
multi_ptr<ElementType, generic_space, DecorateAddress> Ptr) {
if constexpr (Space == generic_space)
return Ptr;
else
return {dynamic_address_cast<Space>(Ptr.get_raw())};
}

} // namespace experimental
} // namespace oneapi
} // namespace ext
Expand Down
36 changes: 18 additions & 18 deletions sycl/test-e2e/AddressCast/dynamic_address_cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ int main() {
{
auto GlobalPointer =
sycl::ext::oneapi::experimental::dynamic_address_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::no>(RawGlobalPointer);
sycl::access::address_space::global_space>(
RawGlobalPointer);
auto LocalPointer =
sycl::ext::oneapi::experimental::dynamic_address_cast<
sycl::access::address_space::local_space,
sycl::access::decorated::no>(RawGlobalPointer);
sycl::access::address_space::local_space>(
RawGlobalPointer);
auto PrivatePointer =
sycl::ext::oneapi::experimental::dynamic_address_cast<
sycl::access::address_space::private_space,
sycl::access::decorated::no>(RawGlobalPointer);
sycl::access::address_space::private_space>(
RawGlobalPointer);
Success &= reinterpret_cast<size_t>(RawGlobalPointer) ==
reinterpret_cast<size_t>(GlobalPointer.get_raw());
Success &= LocalPointer.get_raw() == nullptr;
Expand All @@ -62,16 +62,16 @@ int main() {
{
auto GlobalPointer =
sycl::ext::oneapi::experimental::dynamic_address_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::no>(RawLocalPointer);
sycl::access::address_space::global_space>(
RawLocalPointer);
auto LocalPointer =
sycl::ext::oneapi::experimental::dynamic_address_cast<
sycl::access::address_space::local_space,
sycl::access::decorated::no>(RawLocalPointer);
sycl::access::address_space::local_space>(
RawLocalPointer);
auto PrivatePointer =
sycl::ext::oneapi::experimental::dynamic_address_cast<
sycl::access::address_space::private_space,
sycl::access::decorated::no>(RawLocalPointer);
sycl::access::address_space::private_space>(
RawLocalPointer);
Success &= GlobalPointer.get_raw() == nullptr;
Success &= reinterpret_cast<size_t>(RawLocalPointer) ==
reinterpret_cast<size_t>(LocalPointer.get_raw());
Expand All @@ -83,16 +83,16 @@ int main() {
{
auto GlobalPointer =
sycl::ext::oneapi::experimental::dynamic_address_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::no>(RawPrivatePointer);
sycl::access::address_space::global_space>(
RawPrivatePointer);
auto LocalPointer =
sycl::ext::oneapi::experimental::dynamic_address_cast<
sycl::access::address_space::local_space,
sycl::access::decorated::no>(RawPrivatePointer);
sycl::access::address_space::local_space>(
RawPrivatePointer);
auto PrivatePointer =
sycl::ext::oneapi::experimental::dynamic_address_cast<
sycl::access::address_space::private_space,
sycl::access::decorated::no>(RawPrivatePointer);
sycl::access::address_space::private_space>(
RawPrivatePointer);
Success &= GlobalPointer.get_raw() == nullptr;
Success &= LocalPointer.get_raw() == nullptr;
Success &= reinterpret_cast<size_t>(RawPrivatePointer) ==
Expand Down
12 changes: 6 additions & 6 deletions sycl/test-e2e/AddressCast/static_address_cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,25 @@ int main() {
int *RawGlobalPointer = &GlobalAccessor[Index];
auto GlobalPointer =
sycl::ext::oneapi::experimental::static_address_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::no>(RawGlobalPointer);
sycl::access::address_space::global_space>(
RawGlobalPointer);
Success &= reinterpret_cast<size_t>(RawGlobalPointer) ==
reinterpret_cast<size_t>(GlobalPointer.get_raw());

int *RawLocalPointer = &LocalAccessor[0];
auto LocalPointer =
sycl::ext::oneapi::experimental::static_address_cast<
sycl::access::address_space::local_space,
sycl::access::decorated::no>(RawLocalPointer);
sycl::access::address_space::local_space>(
RawLocalPointer);
Success &= reinterpret_cast<size_t>(RawLocalPointer) ==
reinterpret_cast<size_t>(LocalPointer.get_raw());

int PrivateVariable = 0;
int *RawPrivatePointer = &PrivateVariable;
auto PrivatePointer =
sycl::ext::oneapi::experimental::static_address_cast<
sycl::access::address_space::private_space,
sycl::access::decorated::no>(RawPrivatePointer);
sycl::access::address_space::private_space>(
RawPrivatePointer);
Success &= reinterpret_cast<size_t>(RawPrivatePointer) ==
reinterpret_cast<size_t>(PrivatePointer.get_raw());

Expand Down
Loading
Loading