Skip to content

Commit 5fbd2c9

Browse files
danielsuoChexDev
authored andcommitted
[pmap] In-line definitions of jax.device_put_sharded and jax.device_put_replicated.
Both `jax.device_put_sharded` and `jax.device_put_replicated` were deprecated in JAX v0.8.1 in November 2025. We in-line their definitions using public JAX APIs taking the `jax_pmap_shmap_merge=True` branch, which was made the default in JAX v0.8.0 in October 2025. Please see the below for more information: - JAX CHANGELOG: https://docs.jax.dev/en/latest/changelog.html - Migrating from `jax.pmap`: https://docs.jax.dev/en/latest/migrate_pmap.html PiperOrigin-RevId: 897706398
1 parent 665905f commit 5fbd2c9

3 files changed

Lines changed: 37 additions & 11 deletions

File tree

chex/_src/asserts_chexify_test.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,18 @@
3737
chexify_sync = functools.partial(asserts_chexify.chexify, async_check=False)
3838

3939

40+
def _device_put_replicated(x, devices):
41+
mesh = jax.sharding.Mesh(
42+
np.array(devices), axis_names=('_device_put_replicated',)
43+
)
44+
sharding = jax.sharding.NamedSharding(
45+
mesh, jax.sharding.PartitionSpec('_device_put_replicated')
46+
)
47+
return jax.tree_util.tree_map(
48+
lambda v: jax.device_put(np.stack([v] * len(devices)), sharding), x
49+
)
50+
51+
4052
def get_chexify_err_regex(name, msg):
4153
return re.escape(_ai.get_chexify_err_message(name, 'ANY')).replace(
4254
'ANY', f'.*{msg}.*'
@@ -401,8 +413,7 @@ def run_test_suite_with_log_abs_fn(self, make_log_fn, jax_transform, devices,
401413
'c': np.array([[5, -1] for _ in range(10)])
402414
}
403415
}
404-
(x_pos, x_with_neg) = jax.device_put_replicated((x_pos, x_with_neg),
405-
devices)
416+
x_pos, x_with_neg = _device_put_replicated((x_pos, x_with_neg), devices)
406417

407418
all_valid_args = ((x_pos, x_pos),)
408419
all_invalid_args = (

chex/_src/asserts_test.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ def emplace(arrays, dtype):
4444
return jnp.array(arrays, dtype=dtype)
4545

4646

47+
def _device_put_replicated(x, devices):
48+
mesh = jax.sharding.Mesh(
49+
np.array(devices), axis_names=('_device_put_replicated',)
50+
)
51+
sharding = jax.sharding.NamedSharding(
52+
mesh, jax.sharding.PartitionSpec('_device_put_replicated')
53+
)
54+
return jax.tree_util.tree_map(
55+
lambda v: jax.device_put(np.stack([v] * len(devices)), sharding), x
56+
)
57+
58+
4759
class AssertsSwitchTest(parameterized.TestCase):
4860
"""Tests for enable/disable_asserts."""
4961

@@ -1214,7 +1226,7 @@ def test_assert_tree_is_on_host(self):
12141226

12151227
# Check sharded Jax arrays on CPUs.
12161228
asserts.assert_tree_is_on_host(
1217-
{'a': jax.device_put_replicated(np.zeros(1), (cpu,))},
1229+
{'a': _device_put_replicated(np.zeros(1), (cpu,))},
12181230
allow_cpu_device=True,
12191231
allow_sharded_arrays=True,
12201232
)
@@ -1242,7 +1254,7 @@ def test_assert_tree_is_on_host(self):
12421254
AssertionError, _get_err_regex("'a' resides on.*CPU.*disallowed")
12431255
):
12441256
asserts.assert_tree_is_on_host(
1245-
{'a': jax.device_put_replicated(np.zeros(1), (cpu,))},
1257+
{'a': _device_put_replicated(np.zeros(1), (cpu,))},
12461258
allow_cpu_device=False,
12471259
)
12481260

@@ -1251,7 +1263,7 @@ def test_assert_tree_is_on_host(self):
12511263
AssertionError, _get_err_regex("'a' resides on.*CPU.*disallowed")
12521264
):
12531265
asserts.assert_tree_is_on_host(
1254-
{'a': jax.device_put_replicated(np.zeros(1), (cpu,))},
1266+
{'a': _device_put_replicated(np.zeros(1), (cpu,))},
12551267
allow_cpu_device=False,
12561268
allow_sharded_arrays=True,
12571269
)
@@ -1352,7 +1364,7 @@ def _format(*devs):
13521364
# a "sharded" array (device_set has length 1). The array is treated as a
13531365
# regular single-device array.
13541366
cpu = jax.local_devices(backend='cpu')[0]
1355-
cpu_tree = jax.device_put_replicated(np_tree, (cpu,))
1367+
cpu_tree = _device_put_replicated(np_tree, (cpu,))
13561368

13571369
# Single-device array is NOT considered sharded.
13581370
with self.assertRaisesRegex(
@@ -1366,10 +1378,10 @@ def _format(*devs):
13661378
if _num_devices_available('tpu') > 1:
13671379
tpu_1, tpu_2 = jax.devices('tpu')[:2]
13681380

1369-
tpu_1_tree = jax.device_put_replicated(np_tree, (tpu_1,))
1370-
tpu_2_tree = jax.device_put_replicated(np_tree, (tpu_2,))
1371-
tpu_1_2_tree = jax.device_put_replicated(np_tree, (tpu_1, tpu_2))
1372-
tpu_2_1_tree = jax.device_put_replicated(np_tree, (tpu_2, tpu_1))
1381+
tpu_1_tree = _device_put_replicated(np_tree, (tpu_1,))
1382+
tpu_2_tree = _device_put_replicated(np_tree, (tpu_2,))
1383+
tpu_1_2_tree = _device_put_replicated(np_tree, (tpu_1, tpu_2))
1384+
tpu_2_1_tree = _device_put_replicated(np_tree, (tpu_2, tpu_1))
13731385

13741386
asserts.assert_tree_is_sharded(tpu_1_2_tree, devices=(tpu_1, tpu_2))
13751387
asserts.assert_tree_is_sharded(tpu_2_1_tree, devices=(tpu_2, tpu_1))

chex/_src/variants.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import jax
2929
from jax import tree_util
3030
import jax.numpy as jnp
31+
import numpy as np
3132
import toolz
3233

3334
FLAGS = flags.FLAGS
@@ -550,7 +551,9 @@ def bcast_fn(x):
550551
x = jnp.asarray(x)
551552
x = jnp.broadcast_to(x, (n_devices_,) + x.shape)
552553
if not isinstance(x, jax.core.Tracer):
553-
return jax.device_put_sharded(list(x), devices_)
554+
mesh = jax.sharding.Mesh(np.array(devices_), ("_device_put_sharded",))
555+
sharding = jax.NamedSharding(mesh, jax.P("_device_put_sharded"))
556+
return jax.device_put(jnp.stack(list(x)), sharding)
554557
return x
555558

556559
if broadcast_args_to_devices:

0 commit comments

Comments
 (0)