Skip to content

Commit 43195de

Browse files
committed
Guard tests that fail in libtpu==0.0.37, but succeed at nightly libtpu==0.0.38.dev20260315.
Preparing for JAX 0.9.2 release. - 0.0.37: https://github.com/jax-ml/jax/actions/runs/23158234447 - 0.0.38.dev20260315: https://github.com/jax-ml/jax/actions/runs/23162037857 PiperOrigin-RevId: 885189173
1 parent 6ca5383 commit 43195de

File tree

4 files changed

+18
-0
lines changed

4 files changed

+18
-0
lines changed

tests/pallas/ops_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,8 @@ def test_sign(self, dtype, value):
962962
self.skipTest("f16 load not supported on TPU")
963963
if dtype in (jnp.int16, jnp.uint16) and jtu.get_tpu_version() < 6:
964964
self.skipTest("requires TPU v6+")
965+
if dtype == jnp.bfloat16 and not jtu.is_cloud_tpu_at_least(2026, 3, 15):
966+
self.skipTest("Requires libtpu >= 2026.3.15")
965967

966968
@functools.partial(
967969
self.pallas_call,
@@ -1174,6 +1176,12 @@ def test_elementwise_array(self, fn, dtype):
11741176
and not jtu.is_cloud_tpu_at_least(2026, 3, 1)
11751177
):
11761178
self.skipTest("requires a newer libTPU")
1179+
if (
1180+
fn == jnp.sign
1181+
and dtype == "bfloat16"
1182+
and not jtu.is_cloud_tpu_at_least(2026, 3, 15)
1183+
):
1184+
self.skipTest("Requires libtpu >= 2026.3.15")
11771185
# TODO(b/370578663): implement these lowerings on TPU
11781186
if fn in (
11791187
jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, jnp.atanh,

tests/pallas/tpu_ops_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,8 @@ def kernel(x_ref, y_ref):
956956
jnp.bfloat16,
957957
)
958958
def test_sigmoid(self, dtype):
959+
if not jtu.is_cloud_tpu_at_least(2026, 3, 15):
960+
self.skipTest("requires a newer libTPU")
959961

960962
shape = (32, 128)
961963
x = jax.random.normal(jax.random.key(42), shape, dtype=dtype)

tests/pallas/tpu_pallas_state_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,8 @@ def kernel(x_ref, out_ref, tmp_ref):
353353
core_type=[*pltpu.CoreType], use_tc_tiling_on_sc=[True, False]
354354
)
355355
def test_capture_scalar(self, core_type, use_tc_tiling_on_sc):
356+
if not jtu.is_cloud_tpu_at_least(2026, 3, 15):
357+
self.skipTest("Requires libtpu >= 2026.3.15")
356358
match core_type:
357359
case pltpu.CoreType.TC:
358360
mesh = pltpu.create_tensorcore_mesh("x", num_cores=1)

tests/pjit_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,8 @@ def testWithCustomPRNGKey(self):
830830
pjit(lambda x: x, in_shardings=None, out_shardings=None)(key)
831831

832832
def test_lower_with_wrapper_error(self):
833+
if not jtu.is_cloud_tpu_at_least(2026, 3, 15):
834+
self.skipTest("Requires libtpu >= 2026.3.15")
833835
@jax.jit
834836
def f(x):
835837
return x
@@ -10285,6 +10287,8 @@ def f(x):
1028510287
)
1028610288
@jtu.with_explicit_mesh((2, 2), ('x', 'y'))
1028710289
def test_reduce_sum_unreduced_inp_multi_mesh(self, axes, out_s, eq_out_s, mesh):
10290+
if not jtu.is_cloud_tpu_at_least(2026, 3, 15):
10291+
self.skipTest("Requires libtpu >= 2026.3.15")
1028810292

1028910293
inp1 = jax.device_put(np.arange(16).reshape(8, 2), P('x', 'y'))
1029010294
inp2 = jax.device_put(np.arange(8).reshape(2, 4), P('y', None))
@@ -10301,6 +10305,8 @@ def f(x):
1030110305

1030210306
@jtu.with_explicit_mesh((2,), 'x')
1030310307
def test_split_reduced_concat_unreduced(self, mesh):
10308+
if not jtu.is_cloud_tpu_at_least(2026, 3, 15):
10309+
self.skipTest("Requires libtpu >= 2026.3.15")
1030410310

1030510311
x = jax.device_put(np.arange(8.).reshape(2, 4), P('x'))
1030610312
w = jax.device_put(np.arange(64.).reshape(4, 16), P(reduced={'x'}))

0 commit comments

Comments
 (0)