Commit 1db1717

committed
improvements & fixes
1 parent 8e208c9 commit 1db1717

1 file changed: +28 -19 lines changed

src/array_api_stubs/_draft/array_object.py

Lines changed: 28 additions & 19 deletions
@@ -300,9 +300,9 @@ def __dlpack__(
 self: array
 array instance.
 stream: Optional[Union[int, Any]]
-for CUDA and ROCm, a Python integer representing a pointer to a stream, on devices that support streams. ``stream`` is provided by the consumer to the producer to instruct the producer to ensure that operations can safely be performed on the array (e.g., by inserting a dependency between streams via "wait for event"). The pointer must be a positive integer or ``-1``. If ``stream`` is ``-1``, the value may be used by the consumer to signal "producer must not perform any synchronization". The ownership of the stream stays with the consumer. On CPU and other device types without streams, only ``None`` is accepted.
+for CUDA and ROCm, a Python integer representing a pointer to a stream, on devices that support streams. ``stream`` is provided by the consumer to the producer to instruct the producer to ensure that operations can safely be performed on the array (e.g., by inserting a dependency between streams via "wait for event"). The pointer must be an integer larger than or equal to ``-1`` (see below for allowed values on each platform). If ``stream`` is ``-1``, the value may be used by the consumer to signal "producer must not perform any synchronization". The ownership of the stream stays with the consumer. On CPU and other device types without streams, only ``None`` is accepted.
 
-For other device types which do have a stream, queue or similar synchronization mechanism, the most appropriate type to use for ``stream`` is not yet determined. E.g., for SYCL one may want to use an object containing an in-order ``cl::sycl::queue``. This is allowed when libraries agree on such a convention, and may be standardized in a future version of this API standard.
+For other device types which do have a stream, queue, or similar synchronization/ordering mechanism, the most appropriate type to use for ``stream`` is not yet determined. E.g., for SYCL one may want to use an object containing an in-order ``cl::sycl::queue``. This is allowed when libraries agree on such a convention, and may be standardized in a future version of this API standard.
 
 .. note::
 Support for a ``stream`` value other than ``None`` is optional and implementation-dependent.
@@ -329,12 +329,12 @@ def __dlpack__(
 they use the legacy default stream, specifying ``1`` (CUDA) or ``0``
 (ROCm) is preferred. ``None`` is a safe default for developers who do
 not want to think about stream handling at all, potentially at the
-cost of more synchronization than necessary.
+cost of more synchronizations than necessary.
 max_version: Optional[tuple[int, int]]
-The maximum DLPack version that the consumer (i.e., the caller of
-``__dlpack__``) supports, in the form ``(major, minor)``.
-This method may return that maximum version (recommended if it does
-support that), or a different version.
+The maximum DLPack version that the *consumer* (i.e., the caller of
+``__dlpack__``) supports, in the form of a 2-tuple ``(major, minor)``.
+This method may return a capsule of version ``max_version`` (recommended
+if it does support that), or of a different version.
 
 Returns
 -------
@@ -351,14 +351,17 @@ def __dlpack__(
 
 Notes
 -----
-Major DLPack versions represent ABI breaks, minor versions represent
-ABI-compatible additions (e.g., new enum values for new data types or
-device types).
+The DLPack version scheme is SemVer, where the major DLPack versions
+represent ABI breaks, and minor versions represent ABI-compatible additions
+(e.g., new enum values for new data types or device types).
 
 The ``max_version`` keyword was introduced in v2023.12, and goes
 together with the ``DLManagedTensorVersioned`` struct added in DLPack
-1.0. This keyword may not be used by consumers for some time after
-introduction. It is recommended to use this logic in the implementation
+1.0. This keyword may not be used by consumers until a later time after
+introduction, because producers may implement the support at a different
+point in time.
+
+It is recommended for the producer to use this logic in the implementation
 of ``__dlpack__``:
 
 .. code:: python
@@ -368,7 +371,10 @@ def __dlpack__(
         # Note: from March 2025 onwards (but ideally as late as
         # possible), it's okay to raise BufferError here
     else:
-        # We get to produce `DLManagedTensorVersioned` now
+        # We get to produce `DLManagedTensorVersioned` now. Note that
+        # our_own_dlpack_version is the max version that the *producer*
+        # supports and fills in the `DLManagedTensorVersioned::version`
+        # field
         if max_version >= our_own_dlpack_version:
             # Consumer understands us, just return a Capsule with our max version
         elif max_version[0] == our_own_dlpack_version[0]:
@@ -377,18 +383,21 @@ def __dlpack__(
         else:
             # if we're at a higher major version internally, did we
             # keep an implementation of the older major version around?
-            # If so, use that. Else, just return our max
-            # version and let the consumer deal with it.
+            # For example, if the producer is on DLPack 1.x and the consumer
+            # is 0.y, can the producer still export a capsule containing
+            # DLManagedTensor and not DLManagedTensorVersioned?
+            # If so, use that. Else, the producer should raise a BufferError
+            # here to tell users that the consumer's max_version is too
+            # old to allow the data exchange to happen.
 
-And this logic for the producer (i.e., in ``from_dlpack``):
+And this logic for the consumer in ``from_dlpack``:
 
 .. code:: python
 
     try:
         x.__dlpack__(max_version=(1, 0))
-        # if it succeeds, store info about capsule name being "dltensor_versioned",
-        # and needing to set the capsule name to "used_dltensor_versioned"
-        # when we're done
+        # if it succeeds, store info from the capsule named "dltensor_versioned",
+        # and need to set the name to "used_dltensor_versioned" when we're done
     except TypeError:
         x.__dlpack__()
 

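Reviewer note: the version negotiation described in the updated docstring can be exercised with a small pure-Python sketch. This is only an illustration under stated assumptions, not the array API or DLPack implementation. The names `choose_export`, `request_capsule`, `OldProducer`, `NewProducer`, and the `keeps_legacy_impl` flag are hypothetical, and plain strings stand in for real PyCapsule objects. It also assumes that a producer which kept an older implementation kept only the 0.x one.

```python
from typing import Optional, Tuple

Version = Tuple[int, int]

# Hypothetical labels standing in for the struct held by a real capsule.
LEGACY = "DLManagedTensor"
VERSIONED = "DLManagedTensorVersioned"


def choose_export(
    max_version: Optional[Version],
    our_own_dlpack_version: Version,
    keeps_legacy_impl: bool = True,
) -> str:
    """Producer side: decide which struct to export, or raise BufferError."""
    if max_version is None:
        # Consumer did not pass max_version: use the DLPack 0.x path if kept.
        if keeps_legacy_impl:
            return LEGACY
        raise BufferError("unversioned DLPack export is no longer supported")
    if max_version >= our_own_dlpack_version:
        # Consumer understands our max version.
        return VERSIONED
    if max_version[0] == our_own_dlpack_version[0]:
        # Same major version, so the ABI is compatible.
        return VERSIONED
    # Consumer is on an older major version (assumption: only 0.x is kept).
    if max_version[0] == 0 and keeps_legacy_impl:
        return LEGACY
    raise BufferError("consumer's max_version is too old for data exchange")


def request_capsule(x):
    """Consumer side: try the versioned protocol, fall back on TypeError."""
    try:
        return x.__dlpack__(max_version=(1, 0)), True
    except TypeError:
        # Producer predates the max_version keyword.
        return x.__dlpack__(), False


class OldProducer:
    """Hypothetical producer that does not know about max_version."""

    def __dlpack__(self, stream=None):
        return "legacy-capsule"


class NewProducer:
    """Hypothetical producer on DLPack 1.0 that accepts max_version."""

    def __dlpack__(self, *, stream=None, max_version=None):
        return choose_export(max_version, (1, 0))


if __name__ == "__main__":
    print(request_capsule(OldProducer()))
    print(request_capsule(NewProducer()))
```

The second element of the tuple returned by `request_capsule` records whether the versioned path succeeded, which is the information a consumer would need in order to pick the matching capsule name later.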