Skip to content

Commit 0083b7a

Browse files
committed
support copy in from_dlpack
1 parent 1db1717 commit 0083b7a

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

src/array_api_stubs/_draft/array_object.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,9 @@ def __complex__(self: array, /) -> complex:
290290
def __dlpack__(
291291
self: array, /, *,
292292
stream: Optional[Union[int, Any]] = None,
293-
max_version: Optional[tuple[int, int]] = None
293+
max_version: Optional[tuple[int, int]] = None,
294+
dl_device: Optional[Tuple[Enum, int]] = None,
295+
copy: Optional[bool] = False
294296
) -> PyCapsule:
295297
"""
296298
Exports the array for consumption by :func:`~array_api.from_dlpack` as a DLPack capsule.
@@ -335,6 +337,17 @@ def __dlpack__(
335337
``__dlpack__``) supports, in the form of a 2-tuple ``(major, minor)``.
336338
This method may return a capsule of version ``max_version`` (recommended
337339
if it does support that), or of a different version.
340+
dl_device: Optional[Tuple[Enum, int]]
341+
The DLPack device type. Default is ``None``, meaning the exported capsule
342+
should be on the same device as ``self`` is. When specified, the format
343+
must follow that of the return value of :meth:`array.__dlpack_device__`.
344+
If the device type cannot be handled by the producer, this function must
345+
raise `BufferError`.
346+
copy: Optional[bool]
347+
Whether or not a copy should be made. Default is ``False`` to enable
348+
zero-copy data exchange. However, a user can request a copy to be made
349+
by the producer (through the consumer's :func:`~array_api.from_dlpack`)
350+
to move data across the library (and/or device) boundary.
338351
339352
Returns
340353
-------
@@ -390,7 +403,7 @@ def __dlpack__(
390403
# here to tell users that the consumer's max_version is too
391404
# old to allow the data exchange to happen.
392405
393-
And this logic for the consumer in ``from_dlpack``:
406+
And this logic for the consumer in :func:`~array_api.from_dlpack`:
394407
395408
.. code:: python
396409
@@ -405,7 +418,7 @@ def __dlpack__(
405418
Added BufferError.
406419
407420
.. versionchanged:: 2023.12
408-
Added the ``max_version`` keyword.
421+
Added the ``max_version``, ``dl_device``, and ``copy`` keywords.
409422
"""
410423

411424
def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
@@ -432,6 +445,8 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
432445
METAL = 8
433446
VPI = 9
434447
ROCM = 10
448+
CUDA_MANAGED = 13
449+
ONE_API = 14
435450
"""
436451

437452
def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:

src/array_api_stubs/_draft/creation_functions.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020

2121
from ._types import (
22+
Any,
2223
List,
2324
NestedSequence,
2425
Optional,
@@ -214,19 +215,36 @@ def eye(
214215
"""
215216

216217

217-
def from_dlpack(x: object, /) -> array:
218+
def from_dlpack(
219+
x: object, /, *,
220+
device: Optional[device] = None,
221+
copy: Optional[bool] = False,
222+
) -> Union[array, Any]:
218223
"""
219224
Returns a new array containing the data from another (array) object with a ``__dlpack__`` method.
220225
221226
Parameters
222227
----------
223228
x: object
224229
input (array) object.
230+
device: Optional[device]
231+
device on which to place the created array. If ``device`` is ``None`` and ``x`` supports DLPack, the output array device must be inferred from ``x``. Default: ``None``.
232+
233+
The v2023.12 standard only mandates that a compliant library must offer a way for ``from_dlpack`` to create an array on CPU (using
234+
the library-chosen way to represent the CPU device - ``kDLCPU`` in DLPack - e.g. a ``"CPU"`` string or a ``Device("CPU")`` object).
235+
If the compliant library does not support the CPU device and needs to outsource to another (compliant) array library, it may do so
236+
with a clear user documentation and/or run-time warning. If a copy must be made to enable this, and ``copy`` is set to ``False``,
237+
the function must raise ``ValueError``.
238+
239+
Other kinds of devices will be considered for standardization in a future version.
240+
copy: Optional[bool]
241+
boolean indicating whether or not to copy the input. If ``True``, the function must always copy. If ``False``, the function must never copy and must raise a ``BufferError`` in case a copy would be necessary (e.g. the producer disallows views). Default: ``False``.
225242
226243
Returns
227244
-------
228-
out: array
229-
an array containing the data in `x`.
245+
out: Union[array, Any]
246+
an array containing the data in ``x``. In the case that the compliant library does not support the given ``device`` out of box
247+
and must oursource to another (compliant) library, the output will be that library's compliant array object.
230248
231249
.. admonition:: Note
232250
:class: note
@@ -238,9 +256,9 @@ def from_dlpack(x: object, /) -> array:
238256
BufferError
239257
The ``__dlpack__`` and ``__dlpack_device__`` methods on the input array
240258
may raise ``BufferError`` when the data cannot be exported as DLPack
241-
(e.g., incompatible dtype or strides). It may also raise other errors
259+
(e.g., incompatible dtype, strides, or device). It may also raise other errors
242260
when export fails for other reasons (e.g., not enough memory available
243-
to materialize the data). ``from_dlpack`` must propagate such
261+
to materialize the data, a copy must made, etc). ``from_dlpack`` must propagate such
244262
exceptions.
245263
AttributeError
246264
If the ``__dlpack__`` and ``__dlpack_device__`` methods are not present
@@ -251,6 +269,9 @@ def from_dlpack(x: object, /) -> array:
251269
-----
252270
See :meth:`array.__dlpack__` for implementation suggestions for `from_dlpack` in
253271
order to handle DLPack versioning correctly.
272+
273+
.. versionchanged:: 2023.12
274+
Added device and copy support.
254275
"""
255276

256277

0 commit comments

Comments
 (0)