diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py
index 79109ad9b7..2759422bf6 100644
--- a/pytensor/link/numba/dispatch/elemwise.py
+++ b/pytensor/link/numba/dispatch/elemwise.py
@@ -30,20 +30,19 @@
     OR,
     XOR,
     Add,
-    Composite,
     IntDiv,
     Mul,
     ScalarMaximum,
     ScalarMinimum,
     Sub,
     TrueDiv,
+    get_scalar_type,
     scalar_maximum,
 )
 from pytensor.scalar.basic import add as add_as
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
-from pytensor.tensor.type import scalar
 
 
 @singledispatch
@@ -348,13 +347,8 @@ def axis_apply_fn(x):
 
 @numba_funcify.register(Elemwise)
 def numba_funcify_Elemwise(op, node, **kwargs):
-    # Creating a new scalar node is more involved and unnecessary
-    # if the scalar_op is composite, as the fgraph already contains
-    # all the necessary information.
-    scalar_node = None
-    if not isinstance(op.scalar_op, Composite):
-        scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
-        scalar_node = op.scalar_op.make_node(*scalar_inputs)
+    scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs]
+    scalar_node = op.scalar_op.make_node(*scalar_inputs)
 
     scalar_op_fn = numba_funcify(
         op.scalar_op,
diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py
index 0086b15a80..9a3e96c858 100644
--- a/tests/link/numba/test_basic.py
+++ b/tests/link/numba/test_basic.py
@@ -267,11 +267,11 @@ def assert_fn(x, y):
                 x, y
             )
 
-    if isinstance(fgraph, tuple):
-        fn_inputs, fn_outputs = fgraph
-    else:
+    if isinstance(fgraph, FunctionGraph):
         fn_inputs = fgraph.inputs
         fn_outputs = fgraph.outputs
+    else:
+        fn_inputs, fn_outputs = fgraph
 
     fn_inputs = [i for i in fn_inputs if not isinstance(i, SharedVariable)]
 
diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py
index 72150b01ae..862ea1a2e2 100644
--- a/tests/link/numba/test_elemwise.py
+++ b/tests/link/numba/test_elemwise.py
@@ -15,7 +15,8 @@
 from pytensor.gradient import grad
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
-from pytensor.tensor.elemwise import CAReduce, DimShuffle
+from pytensor.scalar import float64
+from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from tests.link.numba.test_basic import (
@@ -691,3 +692,17 @@ def test_numba_careduce_benchmark(axis, c_contiguous, benchmark):
     return careduce_benchmark_tester(
         axis, c_contiguous, mode="NUMBA", benchmark=benchmark
     )
+
+
+def test_scalar_loop():
+    a = float64("a")
+    scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
+
+    x = pt.tensor("x", shape=(3,))
+    elemwise_loop = Elemwise(scalar_loop)(3, x)
+
+    with pytest.warns(UserWarning, match="object mode"):
+        compare_numba_and_py(
+            ([x], [elemwise_loop]),
+            (np.array([1, 2, 3], dtype="float64"),),
+        )
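
Note (illustrative, not part of the patch): the refactor works because get_scalar_type(dtype) returns the ScalarType for a given dtype, and calling that type instantiates a scalar Variable. A scalar Apply node can then be built uniformly for any scalar_op, Composite included, which is why the old isinstance special case can be dropped. A minimal sketch under those assumptions, with variable names made up for illustration:

    from pytensor.scalar import get_scalar_type
    from pytensor.scalar.basic import add as scalar_add

    # get_scalar_type(dtype) -> ScalarType; calling the type creates a
    # scalar Variable of that dtype.
    x_s = get_scalar_type(dtype="float64")()
    y_s = get_scalar_type(dtype="float64")()

    # Build a scalar Apply node the same way numba_funcify_Elemwise now
    # does, with no Composite-specific branch needed.
    scalar_node = scalar_add.make_node(x_s, y_s)
    print(scalar_node.outputs[0].type.dtype)  # float64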