Skip to content

Commit cc9d664

Browse files
committed
Implement shape inference for boolean advanced indexing
1 parent 8c113f1 commit cc9d664

File tree

2 files changed

+105
-22
lines changed

2 files changed

+105
-22
lines changed

pytensor/tensor/subtensor.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,11 @@
2020
from pytensor.printing import Printer, pprint, set_precedence
2121
from pytensor.scalar.basic import ScalarConstant
2222
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
23-
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value
23+
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero
2424
from pytensor.tensor.elemwise import DimShuffle
25-
from pytensor.tensor.exceptions import (
26-
AdvancedIndexingError,
27-
NotScalarConstantError,
28-
ShapeError,
29-
)
25+
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
3026
from pytensor.tensor.math import clip
31-
from pytensor.tensor.shape import Reshape, specify_broadcastable
27+
from pytensor.tensor.shape import Reshape, shape_i, specify_broadcastable
3228
from pytensor.tensor.type import (
3329
TensorType,
3430
bscalar,
@@ -2585,26 +2581,47 @@ def R_op(self, inputs, eval_points):
25852581
return self.make_node(eval_points[0], *inputs[1:]).outputs
25862582

25872583
def infer_shape(self, fgraph, node, ishapes):
2588-
indices = node.inputs[1:]
2589-
index_shapes = list(ishapes[1:])
2590-
for i, idx in enumerate(indices):
2591-
if (
2584+
def is_bool_index(idx):
2585+
return (
25922586
isinstance(idx, (np.bool_, bool))
25932587
or getattr(idx, "dtype", None) == "bool"
2594-
):
2595-
raise ShapeError(
2596-
"Shape inference for boolean indices is not implemented"
2588+
)
2589+
2590+
indices = node.inputs[1:]
2591+
index_shapes = []
2592+
for idx, ishape in zip(indices, ishapes[1:]):
2593+
# Mixed bool indexes are converted to nonzero entries
2594+
if is_bool_index(idx):
2595+
index_shapes.extend(
2596+
(shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx)
25972597
)
25982598
# The `ishapes` entries for `SliceType`s will be None, and
25992599
# we need to give `indexed_result_shape` the actual slices.
2600-
if isinstance(getattr(idx, "type", None), SliceType):
2601-
index_shapes[i] = idx
2600+
elif isinstance(getattr(idx, "type", None), SliceType):
2601+
index_shapes.append(idx)
2602+
else:
2603+
index_shapes.append(ishape)
26022604

2603-
res_shape = indexed_result_shape(
2604-
ishapes[0], index_shapes, indices_are_shapes=True
2605+
res_shape = list(
2606+
indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
26052607
)
2608+
2609+
adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
2610+
bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]
2611+
2612+
# Special logic when the only advanced index group is of bool type.
2613+
# We can replace the nonzeros by a sum of the whole bool variable.
2614+
if len(bool_indices) == 1 and len(adv_indices) == 1:
2615+
[bool_index] = bool_indices
2616+
# Find the output dim associated with the bool index group
2617+
# Because there are no more advanced index groups, there is exactly
2618+
# one output dim per index variable up to the bool group.
2619+
# Note: Scalar integer indexing counts as advanced indexing.
2620+
start_dim = indices.index(bool_index)
2621+
res_shape[start_dim] = bool_index.sum()
2622+
26062623
assert node.outputs[0].ndim == len(res_shape)
2607-
return [list(res_shape)]
2624+
return [res_shape]
26082625

26092626
def perform(self, node, inputs, out_):
26102627
(out,) = out_

tests/tensor/test_subtensor.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
tensor,
6464
tensor3,
6565
tensor4,
66+
tensor5,
6667
vector,
6768
)
6869
from pytensor.tensor.type_other import NoneConst, SliceConstant, make_slice, slicetype
@@ -2150,6 +2151,12 @@ def fun(x, y):
21502151

21512152

21522153
class TestInferShape(utt.InferShapeTester):
2154+
@staticmethod
2155+
def random_bool_mask(shape, rng=None):
2156+
if rng is None:
2157+
rng = np.random.default_rng()
2158+
return rng.binomial(n=1, p=0.5, size=shape).astype(bool)
2159+
21532160
def test_IncSubtensor(self):
21542161
admat = dmatrix()
21552162
bdmat = dmatrix()
@@ -2439,25 +2446,84 @@ def test_AdvancedSubtensor_bool(self):
24392446
n = dmatrix()
24402447
n_val = np.arange(6).reshape((2, 3))
24412448

2442-
# infer_shape is not implemented, but it should not crash
24432449
self._compile_and_check(
24442450
[n],
24452451
[n[n[:, 0] > 2, n[0, :] > 2]],
24462452
[n_val],
24472453
AdvancedSubtensor,
2448-
check_topo=False,
24492454
)
24502455
self._compile_and_check(
24512456
[n],
24522457
[n[n[:, 0] > 2]],
24532458
[n_val],
24542459
AdvancedSubtensor,
2455-
check_topo=False,
2460+
)
2461+
self._compile_and_check(
2462+
[n],
2463+
[n[:, np.array([True, False, True])]],
2464+
[n_val],
2465+
AdvancedSubtensor,
2466+
)
2467+
self._compile_and_check(
2468+
[n],
2469+
[n[np.array([False, False]), 1:]],
2470+
[n_val],
2471+
AdvancedSubtensor,
2472+
)
2473+
self._compile_and_check(
2474+
[n],
2475+
[n[np.array([True, True]), 0]],
2476+
[n_val],
2477+
AdvancedSubtensor,
2478+
)
2479+
self._compile_and_check(
2480+
[n],
2481+
[n[self.random_bool_mask(n_val.shape)]],
2482+
[n_val],
2483+
AdvancedSubtensor,
2484+
)
2485+
self._compile_and_check(
2486+
[n],
2487+
[n[None, self.random_bool_mask(n_val.shape), None]],
2488+
[n_val],
2489+
AdvancedSubtensor,
2490+
)
2491+
self._compile_and_check(
2492+
[n],
2493+
[n[slice(5, None), self.random_bool_mask(n_val.shape[1])]],
2494+
[n_val],
2495+
AdvancedSubtensor,
24562496
)
24572497

24582498
abs_res = n[~isinf(n)]
24592499
assert abs_res.type.shape == (None,)
24602500

2501+
def test_AdvancedSubtensor_bool_mixed(self):
2502+
n = tensor5("x", dtype="float64")
2503+
shape = (18, 3, 4, 5, 6)
2504+
n_val = np.arange(np.prod(shape)).reshape(shape)
2505+
self._compile_and_check(
2506+
[n],
2507+
# Consecutive advanced index
2508+
[n[1:, self.random_bool_mask((3, 4)), 0, 1:]],
2509+
[n_val],
2510+
AdvancedSubtensor,
2511+
)
2512+
self._compile_and_check(
2513+
[n],
2514+
# Non-consecutive advanced index
2515+
[n[1:, self.random_bool_mask((3, 4)), 1:, 0]],
2516+
[n_val],
2517+
AdvancedSubtensor,
2518+
)
2519+
self._compile_and_check(
2520+
[n],
2521+
# Non-consecutive advanced index
2522+
[n[1:, self.random_bool_mask((3,)), 1:, None, np.zeros((6, 1), dtype=int)]],
2523+
[n_val],
2524+
AdvancedSubtensor,
2525+
)
2526+
24612527

24622528
@config.change_flags(compute_test_value="raise")
24632529
def test_basic_shape():

0 commit comments

Comments
 (0)