Skip to content

Commit f298599

Browse files
committed
Benchmark reductions in Numba
1 parent fc44784 commit f298599

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

tests/link/numba/test_elemwise.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,3 +529,27 @@ def test_elemwise_out_type():
529529
x_val = np.broadcast_to(np.zeros((3,)), (6, 3))
530530

531531
assert func(x_val).shape == (18,)
532+
533+
534+
@pytest.mark.parametrize("axis", [0, 2, (0, 2), None])
535+
@pytest.mark.parametrize("op", [Sum, Max, Any])
536+
def test_careduce_benchmark(benchmark, op, axis):
537+
rng = np.random.default_rng(123)
538+
N = 256
539+
if op == All:
540+
# Sparse tensor
541+
value = np.zeros((N, N, N), dtype="bool")
542+
true_arrays = np.random.choice(N, size=N // 2, replace=False)
543+
true_rows = np.random.choice(N, size=N // 2, replace=False)
544+
true_cols = np.random.choice(N, size=N // 2, replace=False)
545+
value[true_arrays, true_rows, true_cols] = True
546+
else:
547+
value = rng.normal(size=(N, N, N))
548+
549+
x = pytensor.shared(value, name="x")
550+
out = op(axis=axis)(x)
551+
552+
func = pytensor.function([], [out], mode="NUMBA")
553+
# JIT compile first
554+
func()
555+
benchmark(func)

0 commit comments

Comments
 (0)