Skip to content

Commit 88d19e7

Browse files
committed
Fix bug in numba impl of cumsum
1 parent 59caa07 commit 88d19e7

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

aesara/link/numba/dispatch/extra_ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def numba_funcify_CumOp(op, node, **kwargs):
4343
raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
4444

4545
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
46+
reaxis_first_inv = tuple(np.argsort(reaxis_first))
4647

4748
if mode == "add":
4849

@@ -65,7 +66,7 @@ def cumop(x):
6566
for m in range(1, x.shape[axis]):
6667
res[m] = res[m - 1] + x_axis_first[m]
6768

68-
return res.transpose(reaxis_first)
69+
return res.transpose(reaxis_first_inv)
6970

7071
else:
7172
if ndim == 1:

tests/link/numba/test_extra_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ def test_BroadcastTo(x, shape):
8080
1,
8181
"add",
8282
),
83+
(
84+
set_test_value(
85+
at.dtensor3(), np.arange(30, dtype=config.floatX).reshape((2, 3, 5))
86+
),
87+
-1,
88+
"add",
89+
),
8390
(
8491
set_test_value(
8592
at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))

0 commit comments

Comments
 (0)