Skip to content

Commit ac11da6

Browse files
committed
Numba DimShuffle: special case for 0d input
This circumvents a bug when DimShuffle of a scalar shows up inside a Blockwise, as the outer indexing yields a float (as opposed to a numpy scalar) which has no `.shape` attribute.
1 parent 31304be commit ac11da6

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,16 @@ def squeeze_to_0d(x):
466466

467467
return squeeze_to_0d
468468

469+
elif op.input_ndim == 0:
470+
# DimShuffle can only be an expand_dims or a no_op
471+
# This branch uses asarray in case we get a scalar due to https://github.com/numba/numba/issues/10358
472+
new_shape = shape_template
473+
new_strides = strides_template
474+
475+
@numba_basic.numba_njit
476+
def dimshuffle(x):
477+
return as_strided(np.asarray(x), shape=new_shape, strides=new_strides)
478+
469479
else:
470480

471481
@numba_basic.numba_njit
@@ -490,7 +500,7 @@ def dimshuffle(x):
490500

491501
return as_strided(x, shape=new_shape, strides=new_strides)
492502

493-
cache_version = 1
503+
cache_version = 2
494504
return dimshuffle, cache_version
495505

496506

tests/link/numba/test_blockwise.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pytensor.tensor import lvector, tensor, tensor3
66
from pytensor.tensor.basic import Alloc, ARange, constant
77
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
8+
from pytensor.tensor.elemwise import DimShuffle
89
from pytensor.tensor.nlinalg import SVD, Det
910
from pytensor.tensor.slinalg import Cholesky, cholesky
1011
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
@@ -80,3 +81,12 @@ def test_blockwise_alloc():
8081
assert out.type.ndim == 3
8182

8283
compare_numba_and_py([val], [out], [np.arange(5)], eval_obj_mode=False)
84+
85+
86+
def test_blockwise_scalar_dimshuffle():
87+
x = lvector("x")
88+
blockwise_scalar_ds = Blockwise(
89+
DimShuffle(input_ndim=0, new_order=["x", "x"]), signature="()->(1,1)"
90+
)
91+
out = blockwise_scalar_ds(x)
92+
compare_numba_and_py([x], [out], [np.arange(9)], eval_obj_mode=False)

0 commit comments

Comments
 (0)