We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent fc44784 commit f298599Copy full SHA for f298599
tests/link/numba/test_elemwise.py
@@ -529,3 +529,27 @@ def test_elemwise_out_type():
529
x_val = np.broadcast_to(np.zeros((3,)), (6, 3))
530
531
assert func(x_val).shape == (18,)
532
+
533
534
+@pytest.mark.parametrize("axis", [0, 2, (0, 2), None])
535
+@pytest.mark.parametrize("op", [Sum, Max, Any])
536
+def test_careduce_benchmark(benchmark, op, axis):
537
+ rng = np.random.default_rng(123)
538
+ N = 256
539
+ if op == All:
540
+ # Sparse tensor
541
+ value = np.zeros((N, N, N), dtype="bool")
542
+ true_arrays = np.random.choice(N, size=N // 2, replace=False)
543
+ true_rows = np.random.choice(N, size=N // 2, replace=False)
544
+ true_cols = np.random.choice(N, size=N // 2, replace=False)
545
+ value[true_arrays, true_rows, true_cols] = True
546
+ else:
547
+ value = rng.normal(size=(N, N, N))
548
549
+ x = pytensor.shared(value, name="x")
550
+ out = op(axis=axis)(x)
551
552
+ func = pytensor.function([], [out], mode="NUMBA")
553
+ # JIT compile first
554
+ func()
555
+ benchmark(func)
0 commit comments