Skip to content

Commit b143934

Browse files
WIP
1 parent 9903099 commit b143934

21 files changed

+92
-79
lines changed

sycl/include/sycl/interop_handle.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,12 @@ class interop_handle {
205205
friend class detail::DispatchHostTask;
206206
using ReqToMem = std::pair<detail::AccessorImplHost *, ur_mem_handle_t>;
207207

208+
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
209+
// Clean this up (no shared pointers). Not doing it right now because I expect
210+
// there will be several iterations of simplifications possible and it would
211+
// be hard to track which of them made their way into a minor public release
212+
// and which didn't. Let's just clean it up once during ABI breaking window.
213+
#endif
208214
interop_handle(std::vector<ReqToMem> MemObjs,
209215
const std::shared_ptr<detail::queue_impl> &Queue,
210216
const std::shared_ptr<detail::device_impl> &Device,

sycl/source/detail/async_alloc.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ void *async_malloc(sycl::handler &h, sycl::usm::alloc kind, size_t size) {
6767
sycl::make_error_code(sycl::errc::feature_not_supported),
6868
"Only device backed asynchronous allocations are supported!");
6969

70-
auto &Adapter = h.getContextImplPtr()->getAdapter();
70+
auto &Adapter = detail::getSyclObjImpl(h)->get_context().getAdapter();
7171

7272
// Get CG event dependencies for this allocation.
7373
const auto &DepEvents = h.impl->CGData.MEvents;
@@ -117,7 +117,7 @@ __SYCL_EXPORT void *async_malloc(const sycl::queue &q, sycl::usm::alloc kind,
117117
__SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
118118
const memory_pool &pool) {
119119

120-
auto &Adapter = h.getContextImplPtr()->getAdapter();
120+
auto &Adapter = detail::getSyclObjImpl(h)->get_context().getAdapter();
121121
auto &memPoolImpl = sycl::detail::getSyclObjImpl(pool);
122122

123123
// Get CG event dependencies for this allocation.

sycl/source/detail/backend_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ inline namespace _V1 {
1515
namespace detail {
1616

1717
template <class T> backend getImplBackend(const T &Impl) {
18-
return Impl->getContextImplPtr()->getBackend();
18+
return Impl->getContextImpl().getBackend();
1919
}
2020

2121
} // namespace detail

sycl/source/detail/device_image_impl.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -570,13 +570,13 @@ class device_image_impl {
570570

571571
ur_native_handle_t getNative() const {
572572
assert(MProgram);
573-
const auto &ContextImplPtr = detail::getSyclObjImpl(MContext);
574-
const AdapterPtr &Adapter = ContextImplPtr->getAdapter();
573+
context_impl &ContextImpl = *detail::getSyclObjImpl(MContext);
574+
const AdapterPtr &Adapter = ContextImpl.getAdapter();
575575

576576
ur_native_handle_t NativeProgram = 0;
577577
Adapter->call<UrApiKind::urProgramGetNativeHandle>(MProgram,
578578
&NativeProgram);
579-
if (ContextImplPtr->getBackend() == backend::opencl)
579+
if (ContextImpl.getBackend() == backend::opencl)
580580
__SYCL_OCL_CALL(clRetainProgram, ur::cast<cl_program>(NativeProgram));
581581

582582
return NativeProgram;

sycl/source/detail/event_impl.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,10 @@ void event_impl::setHandle(const ur_event_handle_t &UREvent) {
140140
MEvent.store(UREvent);
141141
}
142142

143-
const ContextImplPtr &event_impl::getContextImpl() {
143+
context_impl &event_impl::getContextImpl() {
144144
initContextIfNeeded();
145-
return MContext;
145+
assert(MContext && "Trying to get context from a host event!");
146+
return *MContext;
146147
}
147148

148149
const AdapterPtr &event_impl::getAdapter() {
@@ -152,9 +153,17 @@ const AdapterPtr &event_impl::getAdapter() {
152153

153154
void event_impl::setStateIncomplete() { MState = HES_NotComplete; }
154155

155-
void event_impl::setContextImpl(const ContextImplPtr &Context) {
156+
void event_impl::setContextImpl(std::shared_ptr<context_impl> &&Context) {
156157
MIsHostEvent = Context == nullptr;
157-
MContext = Context;
158+
MContext = std::move(Context);
159+
}
160+
void event_impl::setContextImpl(context_impl &Context) {
161+
MIsHostEvent = false;
162+
MContext = Context.shared_from_this();
163+
}
164+
void event_impl::setContextImpl(context_impl *Context) {
165+
MIsHostEvent = Context == nullptr;
166+
MContext = Context ? Context->shared_from_this() : nullptr;
158167
}
159168

160169
event_impl::event_impl(ur_event_handle_t Event, const context &SyclContext,
@@ -178,7 +187,7 @@ event_impl::event_impl(ur_event_handle_t Event, const context &SyclContext,
178187
event_impl::event_impl(queue_impl &Queue, private_tag)
179188
: MQueue{Queue.weak_from_this()},
180189
MIsProfilingEnabled{Queue.MIsProfilingEnabled} {
181-
this->setContextImpl(Queue.getContextImplPtr());
190+
this->setContextImpl(Queue.getContextImpl());
182191
MState.store(HES_Complete);
183192
}
184193

sycl/source/detail/event_impl.hpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,21 +174,19 @@ class event_impl : public std::enable_shared_from_this<event_impl> {
174174
void setHandle(const ur_event_handle_t &UREvent);
175175

176176
/// Returns context that is associated with this event.
177-
///
178-
/// \return a shared pointer to a valid context_impl.
179-
const ContextImplPtr &getContextImpl();
177+
context_impl &getContextImpl();
180178

181179
/// \return the Adapter associated with the context of this event.
182180
/// Should be called when this is not a Host Event.
183181
const AdapterPtr &getAdapter();
184182

185183
/// Associate event with the context.
186184
///
187-
/// Provided UrContext inside ContextImplPtr must be associated
185+
/// Provided UrContext inside Context must be associated
188186
/// with the UrEvent object stored in this class
189-
///
190-
/// @param Context is a shared pointer to an instance of valid context_impl.
191-
void setContextImpl(const ContextImplPtr &Context);
187+
void setContextImpl(std::shared_ptr<context_impl> &&Context);
188+
void setContextImpl(context_impl &Context);
189+
void setContextImpl(context_impl *Context);
192190

193191
/// Clear the event state
194192
void setStateIncomplete();

sycl/source/detail/graph_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,7 @@ exec_graph_impl::enqueue(sycl::detail::queue_impl &Queue,
10371037

10381038
auto CreateNewEvent([&]() {
10391039
auto NewEvent = sycl::detail::event_impl::create_device_event(Queue);
1040-
NewEvent->setContextImpl(Queue.getContextImplPtr());
1040+
NewEvent->setContextImpl(Queue.getContextImpl());
10411041
NewEvent->setStateIncomplete();
10421042
return NewEvent;
10431043
});

sycl/source/detail/handler_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ class handler_impl {
199199
template <typename Self = handler_impl> context_impl &get_context() {
200200
Self *self = this;
201201
if (auto *Queue = self->get_queue_or_null())
202-
return *Queue->getContextImplPtr();
202+
return Queue->getContextImpl();
203203
else
204204
return *self->get_graph().getContextImplPtr();
205205
}

sycl/source/detail/helpers.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
namespace sycl {
2626
inline namespace _V1 {
27-
using ContextImplPtr = std::shared_ptr<sycl::detail::context_impl>;
2827
namespace detail {
2928
void waitEvents(std::vector<sycl::event> DepEvents) {
3029
for (auto SyclEvent : DepEvents) {
@@ -59,10 +58,10 @@ retrieveKernelBinary(queue_impl &Queue, KernelNameStrRefT KernelName,
5958
if (DeviceImage == DeviceImages.end()) {
6059
return {nullptr, nullptr};
6160
}
62-
auto ContextImpl = Queue.getContextImplPtr();
61+
context_impl &ContextImpl = Queue.getContextImpl();
6362
ur_program_handle_t Program =
6463
detail::ProgramManager::getInstance().createURProgram(
65-
**DeviceImage, *ContextImpl, {createSyclObjFromImpl<device>(Dev)});
64+
**DeviceImage, ContextImpl, {createSyclObjFromImpl<device>(Dev)});
6665
return {*DeviceImage, Program};
6766
}
6867

@@ -80,11 +79,11 @@ retrieveKernelBinary(queue_impl &Queue, KernelNameStrRefT KernelName,
8079
DeviceImage = SyclKernelImpl->getDeviceImage()->get_bin_image_ref();
8180
Program = SyclKernelImpl->getDeviceImage()->get_ur_program_ref();
8281
} else {
83-
auto ContextImpl = Queue.getContextImplPtr();
82+
context_impl &ContextImpl = Queue.getContextImpl();
8483
DeviceImage = &detail::ProgramManager::getInstance().getDeviceImage(
85-
KernelName, *ContextImpl, &Dev);
84+
KernelName, ContextImpl, &Dev);
8685
Program = detail::ProgramManager::getInstance().createURProgram(
87-
*DeviceImage, *ContextImpl, {createSyclObjFromImpl<device>(Dev)});
86+
*DeviceImage, ContextImpl, {createSyclObjFromImpl<device>(Dev)});
8887
}
8988
return {DeviceImage, Program};
9089
}

sycl/source/detail/kernel_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ class kernel_impl {
232232
bool isInterop() const { return MIsInterop; }
233233

234234
ur_program_handle_t getProgramRef() const { return MProgram; }
235-
ContextImplPtr getContextImplPtr() const { return MContext; }
235+
context_impl &getContextImpl() const { return *MContext; }
236236

237237
std::mutex &getNoncacheableEnqueueMutex() const {
238238
return MNoncacheableEnqueueMutex;

sycl/source/detail/queue_impl.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ template <> device queue_impl::get_info<info::queue::device>() const {
8080
template <>
8181
typename info::platform::version::return_type
8282
queue_impl::get_backend_info<info::platform::version>() const {
83-
if (getContextImplPtr()->getBackend() != backend::opencl) {
83+
if (getContextImpl().getBackend() != backend::opencl) {
8484
throw sycl::exception(errc::backend_mismatch,
8585
"the info::platform::version info descriptor can "
8686
"only be queried with an OpenCL backend");
@@ -93,7 +93,7 @@ queue_impl::get_backend_info<info::platform::version>() const {
9393
template <>
9494
typename info::device::version::return_type
9595
queue_impl::get_backend_info<info::device::version>() const {
96-
if (getContextImplPtr()->getBackend() != backend::opencl) {
96+
if (getContextImpl().getBackend() != backend::opencl) {
9797
throw sycl::exception(errc::backend_mismatch,
9898
"the info::device::version info descriptor can only "
9999
"be queried with an OpenCL backend");
@@ -106,7 +106,7 @@ queue_impl::get_backend_info<info::device::version>() const {
106106
template <>
107107
typename info::device::backend_version::return_type
108108
queue_impl::get_backend_info<info::device::backend_version>() const {
109-
if (getContextImplPtr()->getBackend() != backend::ext_oneapi_level_zero) {
109+
if (getContextImpl().getBackend() != backend::ext_oneapi_level_zero) {
110110
throw sycl::exception(errc::backend_mismatch,
111111
"the info::device::backend_version info descriptor "
112112
"can only be queried with a Level Zero backend");
@@ -121,7 +121,7 @@ queue_impl::get_backend_info<info::device::backend_version>() const {
121121
static event prepareSYCLEventAssociatedWithQueue(
122122
const std::shared_ptr<detail::queue_impl> &QueueImpl) {
123123
auto EventImpl = detail::event_impl::create_device_event(*QueueImpl);
124-
EventImpl->setContextImpl(detail::getSyclObjImpl(QueueImpl->get_context()));
124+
EventImpl->setContextImpl(QueueImpl->getContextImpl());
125125
EventImpl->setStateIncomplete();
126126
return detail::createSyclObjFromImpl<event>(EventImpl);
127127
}
@@ -731,7 +731,7 @@ ur_native_handle_t queue_impl::getNative(int32_t &NativeHandleDesc) const {
731731

732732
Adapter->call<UrApiKind::urQueueGetNativeHandle>(MQueue, &UrNativeDesc,
733733
&Handle);
734-
if (getContextImplPtr()->getBackend() == backend::opencl)
734+
if (getContextImpl().getBackend() == backend::opencl)
735735
__SYCL_OCL_CALL(clRetainCommandQueue, ur::cast<cl_command_queue>(Handle));
736736

737737
return Handle;

sycl/source/detail/queue_impl.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
293293

294294
const AdapterPtr &getAdapter() const { return MContext->getAdapter(); }
295295

296+
// TODO: stop using it in existing code. New code must NOT use this!
296297
const ContextImplPtr &getContextImplPtr() const { return MContext; }
297298

298299
context_impl &getContextImpl() const { return *MContext; }
@@ -651,7 +652,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
651652
void revisitUnenqueuedCommandsState(const EventImplPtr &CompletedHostTask);
652653

653654
static ContextImplPtr getContext(const QueueImplPtr &Queue) {
654-
return Queue ? Queue->getContextImplPtr() : nullptr;
655+
return Queue ? Queue->getContextImpl().shared_from_this() : nullptr;
655656
}
656657

657658
// Must be called under MMutex protection
@@ -977,7 +978,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
977978
mutable std::mutex MMutex;
978979

979980
device_impl &MDevice;
980-
const ContextImplPtr MContext;
981+
const std::shared_ptr<context_impl> MContext;
981982

982983
/// These events are tracked, but not owned, by the queue.
983984
std::vector<std::weak_ptr<event_impl>> MEventsWeak;

sycl/source/detail/reduction.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ __SYCL_EXPORT void
208208
addCounterInit(handler &CGH, std::shared_ptr<sycl::detail::queue_impl> &Queue,
209209
std::shared_ptr<int> &Counter) {
210210
auto EventImpl = detail::event_impl::create_device_event(*Queue);
211-
EventImpl->setContextImpl(detail::getSyclObjImpl(Queue->get_context()));
211+
EventImpl->setContextImpl(Queue->getContextImpl());
212212
EventImpl->setStateIncomplete();
213213
ur_event_handle_t UREvent = nullptr;
214214
MemoryManager::fill_usm(Counter.get(), *Queue, sizeof(int), {0}, {},

sycl/source/detail/sampler_impl.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,14 @@ sampler_impl::~sampler_impl() {
9595
}
9696

9797
ur_sampler_handle_t
98-
sampler_impl::getOrCreateSampler(const ContextImplPtr &ContextImpl) {
98+
sampler_impl::getOrCreateSampler(context_impl &ContextImpl) {
99+
// Just for the `MContextToSampler` lookups. Probably the type of it should be
100+
// changed.
101+
std::shared_ptr<context_impl> ContextImplPtr = ContextImpl.shared_from_this();
102+
99103
{
100104
std::lock_guard<std::mutex> Lock(MMutex);
101-
auto It = MContextToSampler.find(ContextImpl);
105+
auto It = MContextToSampler.find(ContextImplPtr);
102106
if (It != MContextToSampler.end())
103107
return It->second;
104108
}
@@ -135,18 +139,18 @@ sampler_impl::getOrCreateSampler(const ContextImplPtr &ContextImpl) {
135139

136140
ur_result_t errcode_ret = UR_RESULT_SUCCESS;
137141
ur_sampler_handle_t resultSampler = nullptr;
138-
const AdapterPtr &Adapter = ContextImpl->getAdapter();
142+
const AdapterPtr &Adapter = ContextImpl.getAdapter();
139143

140144
errcode_ret = Adapter->call_nocheck<UrApiKind::urSamplerCreate>(
141-
ContextImpl->getHandleRef(), &desc, &resultSampler);
145+
ContextImpl.getHandleRef(), &desc, &resultSampler);
142146

143147
if (errcode_ret == UR_RESULT_ERROR_UNSUPPORTED_FEATURE)
144148
throw sycl::exception(sycl::errc::feature_not_supported,
145149
"Images are not supported by this device.");
146150

147151
Adapter->checkUrResult(errcode_ret);
148152
std::lock_guard<std::mutex> Lock(MMutex);
149-
MContextToSampler[ContextImpl] = resultSampler;
153+
MContextToSampler[ContextImplPtr] = resultSampler;
150154

151155
return resultSampler;
152156
}

sycl/source/detail/sampler_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class sampler_impl {
4646

4747
coordinate_normalization_mode get_coordinate_normalization_mode() const;
4848

49-
ur_sampler_handle_t getOrCreateSampler(const ContextImplPtr &ContextImpl);
49+
ur_sampler_handle_t getOrCreateSampler(context_impl &ContextImpl);
5050

5151
~sampler_impl();
5252

0 commit comments

Comments
 (0)