diff --git a/sycl/source/detail/device_impl.cpp b/sycl/source/detail/device_impl.cpp index f2f2673562f82..7d279786157ca 100644 --- a/sycl/source/detail/device_impl.cpp +++ b/sycl/source/detail/device_impl.cpp @@ -21,7 +21,8 @@ namespace detail { /// Constructs a SYCL device instance using the provided /// UR device instance. -device_impl::device_impl(ur_device_handle_t Device, platform_impl &Platform) +device_impl::device_impl(ur_device_handle_t Device, platform_impl &Platform, + device_impl::private_tag) : MDevice(Device), MPlatform(Platform.shared_from_this()), MDeviceHostBaseTime(std::make_pair(0, 0)) { const AdapterPtr &Adapter = Platform.getAdapter(); diff --git a/sycl/source/detail/device_impl.hpp b/sycl/source/detail/device_impl.hpp index 2b678fe475f31..48957e935107e 100644 --- a/sycl/source/detail/device_impl.hpp +++ b/sycl/source/detail/device_impl.hpp @@ -33,10 +33,19 @@ class platform_impl; // TODO: Make code thread-safe class device_impl { + struct private_tag { + explicit private_tag() = default; + }; + friend class platform_impl; + public: /// Constructs a SYCL device instance using the provided /// UR device instance. - explicit device_impl(ur_device_handle_t Device, platform_impl &Platform); + // + // Must be called through `platform_impl::getOrMakeDeviceImpl` only. + // `private_tag` ensures that is true. + explicit device_impl(ur_device_handle_t Device, platform_impl &Platform, + private_tag); ~device_impl(); diff --git a/sycl/source/detail/platform_impl.cpp b/sycl/source/detail/platform_impl.cpp index 8fefa49480134..3bf34a90492dc 100644 --- a/sycl/source/detail/platform_impl.cpp +++ b/sycl/source/detail/platform_impl.cpp @@ -304,7 +304,8 @@ platform_impl::getOrMakeDeviceImpl(ur_device_handle_t UrDevice) { return Result; // Otherwise make the impl - Result = std::make_shared(UrDevice, *this); + Result = std::make_shared(UrDevice, *this, + device_impl::private_tag{}); MDeviceCache.emplace_back(Result); return Result; diff --git a/sycl/unittests/program_manager/SubDevices.cpp b/sycl/unittests/program_manager/SubDevices.cpp index b65eb6e12cb8f..39163a15c8f91 100644 --- a/sycl/unittests/program_manager/SubDevices.cpp +++ b/sycl/unittests/program_manager/SubDevices.cpp @@ -106,10 +106,8 @@ TEST(SubDevices, DISABLED_BuildProgramForSubdevices) { rootDevice = sycl::detail::getSyclObjImpl(device)->getHandleRef(); // Initialize sub-devices sycl::detail::platform_impl &PltImpl = *sycl::detail::getSyclObjImpl(Plt); - auto subDev1 = - std::make_shared(urSubDev1, PltImpl); - auto subDev2 = - std::make_shared(urSubDev2, PltImpl); + auto subDev1 = PltImpl.getOrMakeDeviceImpl(urSubDev1); + auto subDev2 = PltImpl.getOrMakeDeviceImpl(urSubDev2); sycl::context Ctx{ {device, sycl::detail::createSyclObjFromImpl(subDev1), sycl::detail::createSyclObjFromImpl(subDev2)}}; diff --git a/sycl/unittests/queue/DeviceCheck.cpp b/sycl/unittests/queue/DeviceCheck.cpp index 09e8be76e064c..49ff4fd64f79e 100644 --- a/sycl/unittests/queue/DeviceCheck.cpp +++ b/sycl/unittests/queue/DeviceCheck.cpp @@ -62,6 +62,18 @@ ur_result_t redefinedDevicePartitionAfter(void *pParams) { **params.ppNumDevicesRet = *params.pNumDevices; return UR_RESULT_SUCCESS; } +ur_result_t redefinedPlatformGet(void *pParams) { + auto params = reinterpret_cast(pParams); + if (*params->ppNumPlatforms) + **params->ppNumPlatforms = 2; + + if (*params->pphPlatforms && *params->pNumEntries > 0) { + (*params->pphPlatforms)[0] = reinterpret_cast(1); + (*params->pphPlatforms)[1] = reinterpret_cast(2); + } + + return UR_RESULT_SUCCESS; +} // Check that the device is verified to be either a member of the context or a // descendant of its member. @@ -71,6 +83,8 @@ TEST(QueueDeviceCheck, CheckDeviceRestriction) { detail::SYCLConfig::reset); sycl::unittest::UrMock<> Mock; + mock::getCallbacks().set_replace_callback("urPlatformGet", + &redefinedPlatformGet); sycl::platform Plt = sycl::platform(); UrPlatform = detail::getSyclObjImpl(Plt)->getHandleRef(); @@ -116,12 +130,15 @@ TEST(QueueDeviceCheck, CheckDeviceRestriction) { // Device is neither of the two. { ParentDevice = nullptr; - device Device = detail::createSyclObjFromImpl( - std::make_shared( - reinterpret_cast(0x01), - *detail::getSyclObjImpl(Plt))); + + auto Plts = sycl::platform::get_platforms(); + EXPECT_TRUE(Plts.size() == 2); + sycl::platform OtherPlt = Plts[1]; + + device Device = OtherPlt.get_devices()[0]; queue Q{Device}; - EXPECT_NE(Q.get_context(), DefaultCtx); + auto Ctx = Q.get_context(); + EXPECT_NE(Ctx, DefaultCtx); try { queue Q2{DefaultCtx, Device}; EXPECT_TRUE(false);