Skip to content

Commit e8bd0d7

Browse files
committed
Fix bug in JAX implementation of Second
1 parent 5e6b356 commit e8bd0d7

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

pytensor/link/jax/dispatch/scalar.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ def composite(*args):
181181
@jax_funcify.register(Second)
182182
def jax_funcify_Second(op, **kwargs):
183183
def second(x, y):
184-
return jnp.broadcast_to(y, x.shape)
184+
_, y = jnp.broadcast_arrays(x, y)
185+
return y
185186

186187
return second
187188

tests/link/jax/test_scalar.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626

2727
jax = pytest.importorskip("jax")
28+
from pytensor.link.jax.dispatch import jax_funcify
2829

2930

3031
def test_second():
@@ -40,6 +41,25 @@ def test_second():
4041
fgraph = FunctionGraph([a1, b], [out])
4142
compare_jax_and_py(fgraph, [np.zeros([5], dtype=config.floatX), 5.0])
4243

44+
a2 = matrix("a2", shape=(1, None), dtype="float64")
45+
b2 = matrix("b2", shape=(None, 1), dtype="int32")
46+
out = at.second(a2, b2)
47+
fgraph = FunctionGraph([a2, b2], [out])
48+
compare_jax_and_py(
49+
fgraph, [np.zeros((1, 3), dtype="float64"), np.ones((5, 1), dtype="int32")]
50+
)
51+
52+
53+
def test_second_constant_scalar():
54+
b = scalar("b", dtype="int")
55+
out = at.second(0.0, b)
56+
fgraph = FunctionGraph([b], [out])
57+
# Test dispatch directly as useless second is removed during compilation
58+
fn = jax_funcify(fgraph)
59+
[res] = fn(1)
60+
assert res == 1
61+
assert res.dtype == out.dtype
62+
4363

4464
def test_identity():
4565
a = scalar("a")

0 commit comments

Comments
 (0)