Skip to content

Add support for device kwarg in astype, and add matching utility func #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ Remember to align the itemized text with the first line of an item within a list
implementation of cross-process collective operations used by the CPU backend.
Choices available are `'none'`(default), `'gloo'` and `'mpi'` (requires jaxlib 0.4.26).
If set to `'none'`, cross-process collective operations are disabled.
* New Features
* {func}`jax.numpy.astype` supports new `device` keyword argument.

* Changes
* {func}`jax.pure_callback`, {func}`jax.experimental.io_callback`
Expand Down
19 changes: 3 additions & 16 deletions jax/_src/dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from jax._src.lib import xla_extension_version
from jax._src.typing import Array, DLDeviceType
from jax._src.sharding import Sharding
from jax._src.numpy.util import _place_array

DLPACK_VERSION = (0, 8)
MIN_DLPACK_VERSION = (0, 5)
Expand Down Expand Up @@ -152,19 +153,6 @@ def to_dlpack(x: Array, stream: int | Any | None = None,
f"version ({max_version}) was requested."
)

def _place_array(_arr, device, dlpack_device, copy):
if device and dlpack_device != device:
if copy is not None and not copy:
raise ValueError(
f"Specified {device=} which requires a copy since the source device "
f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
"copy=None to perform the requested operation."
)
else:
return device_put(_arr, device)
if copy:
return jnp.array(_arr, copy=True)
return _arr

def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,
copy: bool | None = None):
Expand Down Expand Up @@ -198,8 +186,7 @@ def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,

_arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend)) # type: ignore
dlpack_device, = _arr.devices()
return _place_array(_arr, device, dlpack_device, copy)
return _place_array(_arr, device, copy)

def _from_dlpack(external_array, device: xla_client.Device | None = None,
copy: bool | None = None):
Expand Down Expand Up @@ -230,7 +217,7 @@ def _from_dlpack(external_array, device: xla_client.Device | None = None,

_arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, dlpack_device, stream))
return _place_array(_arr, device, dlpack_device, copy)
return _place_array(_arr, device, copy)

def from_dlpack(external_array,
device: xla_client.Device | Sharding | None = None,
Expand Down
14 changes: 5 additions & 9 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2636,18 +2636,14 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
# to issue our warning.
with warnings.catch_warnings():
warnings.simplefilter("ignore", ComplexWarning)
return _place_array(
return util._place_array(
lax.convert_element_type(x_arr, dtype),
device=device, copy=copy,
device=device,
# We translate between array API semantics of copy in _place_array, and
# the NumPy semantics of copy in astype.
copy=True if copy else None,
)

def _place_array(x, device=None, copy=None):
# TODO(micky774): Implement in future PRs as we formalize device placement
# semantics
if copy:
return _array_copy(x)
return x


@util.implements(np.asarray, lax_description=_ARRAY_DOC)
def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
Expand Down
44 changes: 42 additions & 2 deletions jax/_src/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@
from functools import partial
import re
import textwrap
from typing import Any, Callable, NamedTuple, TypeVar

from typing import Any, Callable, NamedTuple, TypeVar, Set
import warnings

from jax.sharding import Sharding

from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.lib import xla_client as xc
from jax._src.util import safe_zip, safe_map
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape

Expand Down Expand Up @@ -117,6 +119,44 @@ def _parse_extra_params(extra_params: str) -> dict[str, str]:
return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters}


def _get_device_set(x: ArrayLike | xc.Device | Sharding | None) -> Set[xc.Device] | None:
if x is None or isinstance(x, core.Tracer):
return None
elif isinstance(x, Sharding):
return x.device_set
elif isinstance(x, xc.Device):
return {x}
elif hasattr(x, "devices"):
return x.devices()
else:
raise ValueError(f"Attempted to get a device set from an unsupported type: {type(x)}")


def _place_array(x: Array, device: xc.Device | Sharding | None = None, copy=None) -> Array:
"""Helper utility for copying an array, or placing it on a device or sharding.

Note that `device_put` is regarded as a no-op under JIT compilation, so we
ensure array API device semantics compliance only under eager execution,
favoring JIT compilation performance over correctness in this case.
"""

devices = _get_device_set(device)
src_devices = _get_device_set(x)
if devices is not None and src_devices != devices:
if copy is not None and not copy:
raise ValueError(
f"Specified {device=} which requires a copy since the source devices "
f"are {src_devices}, however copy=False. Set copy=True or "
"copy=None to perform the requested operation."
)
out = api.device_put(x, device)
else:
out = x
if copy:
return lax._array_copy(out)
return out


def implements(
original_fun: Callable[..., Any] | None,
update_doc: bool = True,
Expand Down
15 changes: 11 additions & 4 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3873,14 +3873,21 @@ def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'):
@jtu.sample_product(
change_dtype=[True, False],
copy=[True, False],
change_device=[True, False],
)
def testAstypeCopy(self, change_dtype, copy):
@jtu.run_on_devices("gpu")
def testAstypeCopy(self, change_dtype, copy, change_device):
dtype = 'float32' if change_dtype else 'int32'
expect_copy = change_dtype or copy
device = jax.devices("cpu")[0] if change_device else None
expect_copy = change_dtype or copy or change_device
x = jnp.arange(5, dtype='int32')
y = x.astype(dtype, copy=copy)

y = x.astype(dtype, copy=copy, device=device)
self.assertEqual(y.dtype, dtype)

placed_devices = y.devices()
expeceted_devices = {device} if change_device else x.devices()
self.assertEqual(placed_devices, expeceted_devices)

y.delete()
self.assertNotEqual(x.is_deleted(), expect_copy)

Expand Down