
Commit 7e7044a

Add support for device and copy kwargs in from_dlpack to match array API
1 parent: fed7efd

File tree

5 files changed: +137 −57 lines


jax/_src/dlpack.py

Lines changed: 112 additions & 48 deletions
@@ -18,13 +18,14 @@
 from typing import Any
 import warnings

+from jax._src.api import device_put
 from jax import numpy as jnp
 from jax._src import array
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension_version
 from jax._src.typing import Array
-
+from jax._src.sharding import Sharding

 # A set of dtypes that dlpack supports.
 # Note: Make sure to use a "type", not a dtype instance, when looking up this set
@@ -82,16 +83,112 @@ def to_dlpack(x: Array, take_ownership: bool = False,
       x.addressable_data(0), stream=stream
   ) # type: ignore

+def _place_array(_arr, device, dlpack_device, copy):
+  if device and dlpack_device != device:
+    if copy is not None and not copy:
+      raise ValueError(
+        f"Specified {device=} which requires a copy since the source device "
+        f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
+        "copy=None to perform the requested operation."
+      )
+    else:
+      return device_put(_arr, device)
+  if copy:
+    return jnp.array(_arr, copy=True)
+  return _arr
+
+def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,
+                        copy: bool | None = None):
+  preferred_platform = getattr(device, "platform", None)
+  if device and preferred_platform == "gpu":
+    preferred_platform = "cuda" if "cuda" in device.client.platform_version else "rocm"
+
+  cpu_backend = xla_bridge.get_backend("cpu")
+  gpu_backend = None
+
+  if preferred_platform in {"cuda", "rocm"}:
+    try:
+      gpu_backend = xla_bridge.get_backend(preferred_platform)
+    except RuntimeError:
+      raise TypeError(
+        f"A {str.upper(preferred_platform)} device was specified, however no "
+        f"{str.upper(preferred_platform)} backend was found."
+      )

-def from_dlpack(external_array):
+  if preferred_platform is None:
+    try:
+      gpu_backend = xla_bridge.get_backend("cuda")
+    except RuntimeError:
+      pass
+    # Try ROCm if CUDA backend not found
+    if gpu_backend is None:
+      try:
+        gpu_backend = xla_bridge.get_backend("rocm")
+      except RuntimeError:
+        pass
+
+  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, cpu_backend, gpu_backend))  # type: ignore
+
+  return _place_array(_arr, device, _arr.devices().pop(), copy)
+
+def _from_dlpack(external_array, device: xla_client.Device | None = None,
+                 copy: bool | None = None):
+  dl_device_type, device_id = external_array.__dlpack_device__()
+  try:
+    dl_device_platform = {
+        DLDeviceType.kDLCPU: "cpu",
+        DLDeviceType.kDLCUDA: "cuda",
+        DLDeviceType.kDLROCM: "rocm",
+    }[dl_device_type]
+  except TypeError:
+    # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
+    # TypeError.
+    raise TypeError(
+        "Array passed to from_dlpack is on unsupported device type "
+        f"(DLDeviceType: {dl_device_type}, array: {external_array}")
+
+  backend = xla_bridge.get_backend(dl_device_platform)
+  dlpack_device = backend.device_from_local_hardware_id(device_id)
+  try:
+    stream = dlpack_device.get_stream_for_external_ready_events()
+  except xla_client.XlaRuntimeError as err:  # type: ignore
+    if "UNIMPLEMENTED" in str(err):
+      stream = None
+    else:
+      raise
+  dlpack = external_array.__dlpack__(stream=stream)
+
+  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, dlpack_device, stream))
+
+  return _place_array(_arr, device, dlpack_device, copy)
+
+def from_dlpack(external_array,
+                device: xla_client.Device | Sharding | None = None,
+                copy: bool | None = None):
   """Returns a :class:`~jax.Array` representation of a DLPack tensor.

-  The returned :class:`~jax.Array` shares memory with ``external_array``.
+  The returned :class:`~jax.Array` shares memory with ``external_array`` if no
+  device transfer or copy was requested.

   Args:
-    external_array: an array object that has __dlpack__ and __dlpack_device__
+    external_array: An array object that has __dlpack__ and __dlpack_device__
       methods, or a DLPack tensor on either CPU or GPU (legacy API).

+    device: The (optional) :py:class:`Device`, representing the device on which
+      the returned array should be placed. If given, then the result is committed
+      to the device. If unspecified, the resulting array will be unpacked onto the
+      same device it originated from. Setting ``device`` to a device different from
+      the source of ``external_array`` will require a copy, meaning ``copy`` must be
+      set to either ``True`` or ``None``.
+
+    copy: An (optional) boolean, controlling whether or not a copy is performed.
+      If ``copy=True`` then a copy is always performed, even if unpacked onto the
+      same device. If ``copy=False`` then the copy is never performed and will raise
+      an error if necessary. When ``copy=None`` then a copy may be performed if
+      needed for a device transfer.
+
   Returns:
     A jax.Array
@@ -102,49 +199,16 @@ def from_dlpack(external_array):
     is later modified in-place, it may lead to undefined behavior when using
     the associated JAX array.
   """
+  if isinstance(device, Sharding):
+    device_set = device.device_set
+    if len(device_set) > 1:
+      raise ValueError(
+        "from_dlpack can only unpack a dlpack tensor onto a singular device, but "
+        f"a Sharding with {len(device_set)} devices was provided."
+      )
+    device = device_set.pop()
   if hasattr(external_array, "__dlpack__"):
-    dl_device_type, device_id = external_array.__dlpack_device__()
-    try:
-      device_platform = {
-          DLDeviceType.kDLCPU: "cpu",
-          DLDeviceType.kDLCUDA: "cuda",
-          DLDeviceType.kDLROCM: "rocm",
-      }[dl_device_type]
-    except TypeError:
-      # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
-      # TypeError.
-      raise TypeError(
-          "Array passed to from_dlpack is on unsupported device type "
-          f"(DLDeviceType: {dl_device_type}, array: {external_array}")
-
-    backend = xla_bridge.get_backend(device_platform)
-    device = backend.device_from_local_hardware_id(device_id)
-    try:
-      stream = device.get_stream_for_external_ready_events()
-    except xla_client.XlaRuntimeError as err:  # type: ignore
-      if "UNIMPLEMENTED" in str(err):
-        stream = None
-      else:
-        raise
-    dlpack = external_array.__dlpack__(stream=stream)
-
-    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, device, stream))
-  else:
-    # Legacy path
-    dlpack = external_array
-    cpu_backend = xla_bridge.get_backend("cpu")
-    try:
-      gpu_backend = xla_bridge.get_backend("cuda")
-    except RuntimeError:
-      gpu_backend = None
-
-    # Try ROCm if CUDA backend not found
-    if gpu_backend is None:
-      try:
-        gpu_backend = xla_bridge.get_backend("rocm")
-      except RuntimeError:
-        gpu_backend = None
+    return _from_dlpack(external_array, device, copy)

-    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, cpu_backend, gpu_backend))
+  # Legacy path
+  return _legacy_from_dlpack(external_array, device, copy)
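
For reference, a minimal usage sketch of the new keywords on the public entry point, assuming a NumPy source array and a host CPU device; the variable names below are illustrative and not part of this commit:

import jax
import jax.dlpack
import numpy as np

x_np = np.arange(4, dtype=np.float32)   # any object exposing __dlpack__ / __dlpack_device__
cpu = jax.devices("cpu")[0]

# Default: unpack onto the array's own device; memory is shared where possible.
x = jax.dlpack.from_dlpack(x_np)

# Force a copy even when no device transfer is needed.
x_copy = jax.dlpack.from_dlpack(x_np, copy=True)

# Commit the result to an explicit device; a cross-device placement needs
# copy=True or copy=None, otherwise _place_array raises a ValueError.
x_cpu = jax.dlpack.from_dlpack(x_np, device=cpu, copy=None)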

jax/_src/numpy/lax_numpy.py

Lines changed: 3 additions & 2 deletions
@@ -2442,9 +2442,10 @@ def fromiter(*args, **kwargs):
   is later modified in-place, it may lead to undefined behavior when using
   the associated JAX array.
   """)
-def from_dlpack(x: Any) -> Array:
+def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None,
+                copy: bool | None = None) -> Array:
   from jax.dlpack import from_dlpack  # pylint: disable=g-import-not-at-top
-  return from_dlpack(x)
+  return from_dlpack(x, device=device, copy=copy)

@util.implements(np.fromfunction)
def fromfunction(function: Callable[..., Array], shape: Any,

jax/experimental/array_api/_creation_functions.py

Lines changed: 6 additions & 3 deletions
@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
 import jax
 import jax.numpy as jnp
-
+from jax._src.lib import xla_client as xc
+from jax._src.sharding import Sharding

 def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
   return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device)
@@ -31,8 +34,8 @@ def empty_like(x, /, *, dtype=None, device=None):
 def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
   return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)

-def from_dlpack(x, /):
-  return jnp.from_dlpack(x)
+def from_dlpack(x, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None):
+  return jnp.from_dlpack(x, device=device, copy=copy)

def full(shape, fill_value, *, dtype=None, device=None):
  return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device)
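
As a sketch of how the new keywords flow through the experimental array API namespace, assuming `jax.experimental.array_api` is importable in this version of JAX; names are illustrative:

import numpy as np
import jax
from jax.experimental import array_api as xp

x_np = np.ones((2, 3), dtype=np.float32)
cpu = jax.devices("cpu")[0]

# The wrapper simply forwards device and copy to jnp.from_dlpack.
y = xp.from_dlpack(x_np, device=cpu, copy=True)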

jax/numpy/__init__.pyi

Lines changed: 2 additions & 1 deletion
@@ -353,7 +353,8 @@ def fmax(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def fmin(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def fmod(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def frexp(x: ArrayLike, /) -> tuple[Array, Array]: ...
-def from_dlpack(x: Any) -> Array: ...
+def from_dlpack(x: Any, /, *, device: _Device | None = None,
+                copy: builtins.bool | None = None) -> Array: ...
 def frombuffer(buffer: Union[bytes, Any], dtype: DTypeLike = ...,
                count: int = ..., offset: int = ...) -> Array: ...
 def fromfile(*args, **kwargs): ...

tests/array_interoperability_test.py

Lines changed: 14 additions & 3 deletions
@@ -174,12 +174,23 @@ def testTensorFlowToJaxInt64(self):
   @jtu.sample_product(
     shape=all_shapes,
     dtype=numpy_dtypes,
+    gpu=[False, True] if jtu.test_device_matches(["gpu"]) else [False],
+    copy=[False, True],
   )
-  def testNumpyToJax(self, shape, dtype):
+  def testNumpyToJax(self, shape, dtype, gpu, copy):
     rng = jtu.rand_default(self.rng())
     x_np = rng(shape, dtype)
-    x_jax = jnp.from_dlpack(x_np)
-    self.assertAllClose(x_np, x_jax)
+    platform = "gpu" if gpu else "cpu"
+    device = jax.devices(platform)[0]
+    _from_dlpack = lambda: jnp.from_dlpack(x_np, device=device, copy=copy)
+    if gpu and not copy:
+      self.assertRaisesRegex(
+        ValueError,
+        r"Specified .* which requires a copy",
+        _from_dlpack
+      )
+    else:
+      self.assertAllClose(x_np, _from_dlpack())

   @jtu.sample_product(
     shape=all_shapes,
