Skip to content

Commit 1496cb7

Browse files
ChexDevChexDev
authored andcommitted
Sort devices explicitly by process index, then id (as opposed to IDs alone). IDs may be randomly generated, and are not guaranteed to be ordered based on their process index.
PiperOrigin-RevId: 650294623
1 parent 343d03a commit 1496cb7

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

chex/_src/asserts_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,12 +1200,14 @@ def test_assert_tree_is_on_device(self):
12001200
asserts.assert_tree_is_on_device(tpu_2_tree, device=tpu_2)
12011201

12021202
with self.assertRaisesRegex(
1203-
AssertionError, _get_err_regex(r"'a' resides on.*TpuDevice\(id=0")
1203+
AssertionError,
1204+
_get_err_regex(r"'a' resides on.*TpuDevice\(process_index=0, id=0"),
12041205
):
12051206
asserts.assert_tree_is_on_device(tpu_1_tree, device=tpu_2)
12061207

12071208
with self.assertRaisesRegex(
1208-
AssertionError, _get_err_regex(r"'a' resides on.*TpuDevice\(id=1")
1209+
AssertionError,
1210+
_get_err_regex(r"'a' resides on.*TpuDevice\(process_index=0, id=1"),
12091211
):
12101212
asserts.assert_tree_is_on_device(tpu_2_tree, device=tpu_1)
12111213

@@ -1735,7 +1737,7 @@ def test_assert_equal_pass(self, first, second):
17351737
asserts.assert_equal(first, second)
17361738

17371739
def test_assert_equal_pass_on_arrays(self):
1738-
# Not using named_parameters, becase JAX cannot be used before app.run().
1740+
# Not using named_parameters, because JAX cannot be used before app.run().
17391741
asserts.assert_equal(jnp.ones([]), np.ones([]))
17401742
asserts.assert_equal(
17411743
jnp.ones([], dtype=jnp.int32), np.ones([], dtype=np.float64))

0 commit comments

Comments
 (0)