diff --git a/pytensor/link/jax/dispatch/scalar.py b/pytensor/link/jax/dispatch/scalar.py index 5fb1106223..1790da3238 100644 --- a/pytensor/link/jax/dispatch/scalar.py +++ b/pytensor/link/jax/dispatch/scalar.py @@ -63,11 +63,18 @@ def clip(x, min, max): @jax_funcify.register(Composite) -def jax_funcify_Composite(op, vectorize=True, **kwargs): +def jax_funcify_Composite(op, node, vectorize=True, **kwargs): jax_impl = jax_funcify(op.fgraph) - def composite(*args): - return jax_impl(*args)[0] + if len(node.outputs) == 1: + + def composite(*args): + return jax_impl(*args)[0] + + else: + + def composite(*args): + return jax_impl(*args) return jnp.vectorize(composite) diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py index 396bb61076..2a94cc1b21 100644 --- a/tests/link/jax/test_scalar.py +++ b/tests/link/jax/test_scalar.py @@ -63,7 +63,7 @@ def test_identity(): ), ], ) -def test_jax_Composite(x, y, x_val, y_val): +def test_jax_Composite_singe_output(x, y, x_val, y_val): x_s = aes.float64("x") y_s = aes.float64("y") @@ -80,6 +80,16 @@ def test_jax_Composite(x, y, x_val, y_val): _ = compare_jax_and_py(out_fg, test_input_vals) +def test_jax_Composite_multi_output(): + x = vector("x") + + x_s = aes.float64("xs") + outs = Elemwise(Composite(inputs=[x_s], outputs=[x_s + 1, x_s - 1]))(x) + + fgraph = FunctionGraph([x], outs) + compare_jax_and_py(fgraph, [np.arange(10, dtype=config.floatX)]) + + def test_erf(): x = scalar("x") out = erf(x)