
Commit 786b141

Remove strict TensorType.broadcastable usage from local_elemwise_alloc
1 parent fd50f36 commit 786b141

2 files changed, +56 -91 lines changed

aesara/tensor/basic_opt.py

Lines changed: 51 additions & 89 deletions
@@ -68,7 +68,13 @@
 )
 from aesara.tensor.elemwise import DimShuffle, Elemwise
 from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
-from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, broadcast_shape
+from aesara.tensor.extra_ops import (
+    BroadcastTo,
+    Repeat,
+    Unique,
+    broadcast_shape,
+    broadcast_to,
+)
 from aesara.tensor.math import all as at_all
 from aesara.tensor.math import eq
 from aesara.tensor.shape import (
@@ -1491,26 +1497,11 @@ def local_elemwise_alloc(fgraph, node):
     introduces them as a canonicalization of `Alloc`'s with leading
     broadcastable dimensions.
     """
-    if not isinstance(node.op, Elemwise):
-        return False
-
     # Rewrite is only applicable when there are at least two inputs
     if len(node.inputs) == 1:
-        return None
+        return False
 
     if len(node.outputs) > 1:
-        # Ensure all outputs have the same broadcast pattern
-        # This is a supposition that I'm not sure is always true.
-        assert all(
-            o.type.broadcastable == node.outputs[0].type.broadcastable
-            for o in node.outputs[1:]
-        )
-
-    # The broadcast pattern of the output must match the broadcast
-    # pattern of at least one of the inputs.
-    if not any(
-        i.type.broadcastable == node.outputs[0].type.broadcastable for i in node.inputs
-    ):
         return False
 
     def dimshuffled_alloc(i):
@@ -1523,103 +1514,74 @@ def dimshuffled_alloc(i):
     # At least one input must have an owner that is either a `Alloc` or a
     # `DimShuffle` with an owner that is a `Alloc` -- otherwise there is
     # nothing to optimize.
-    if not any(
-        i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
-        for i in node.inputs
-    ):
+    alloc_idxs = [
+        idx
+        for idx, i in enumerate(node.inputs)
+        if i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
+    ]
+    if len(alloc_idxs) == 0:
         return False
 
     # Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a
     # baseline for the dimensions.
-    assert_op_idx = None
+    ref_var_idx = None
     for idx, i in enumerate(node.inputs):
         if i.type.broadcastable == node.outputs[0].type.broadcastable:
             # Prefer an input that is not a `Alloc` nor a `DimShuffle` of a
             # `Alloc` so that all `Alloc`s can be optimized.
-            if not (
-                i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
-            ):
-                assert_op_idx = idx
+            if idx not in alloc_idxs:
+                ref_var_idx = idx
                 break
 
     # If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
-    if assert_op_idx is None:
+    if ref_var_idx is None:
         for idx, i in enumerate(node.inputs):
-            if (i.type.broadcastable == node.outputs[0].type.broadcastable) and (
-                i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
-            ):
-                assert_op_idx = idx
+            # XXX: This broadcastable comparison doesn't work
+            if (
+                i.type.broadcastable == node.outputs[0].type.broadcastable
+            ) and idx in alloc_idxs:
+                ref_var_idx = idx
                 break
 
-    assert_op_in = node.inputs[assert_op_idx]
-    cmp_op = assert_op_in
-    new_i = []
-    same_shape = fgraph.shape_feature.same_shape
-    for i in node.inputs:
+    if not hasattr(fgraph, "shape_feature"):
+        return False
+
+    input_shapes = [
+        tuple(fgraph.shape_feature.get_shape(i, j) for j in range(i.type.ndim))
+        for i in node.inputs
+    ]
+    bcasted_shape = broadcast_shape(
+        *input_shapes,
+        arrays_are_shapes=True,
+    )
+
+    new_inputs = list(node.inputs)
+    for idx in alloc_idxs:
+        i = node.inputs[idx]
+
         # Remove `Alloc`
-        if i.owner and isinstance(i.owner.op, Alloc):
-            assert i.type.ndim == cmp_op.ndim
-            if config.experimental__local_alloc_elemwise_assert:
-                get_shape = fgraph.shape_feature.get_shape
-                cond = []
-                for idx in range(i.type.ndim):
-                    if not i.type.broadcastable[idx] and not same_shape(
-                        i, cmp_op, idx, idx
-                    ):
-                        i_shp = get_shape(i, idx)
-                        cmp_shp = get_shape(cmp_op, idx)
-                        cond.append(eq(i_shp, cmp_shp))
-                if cond:
-                    assert_op_in = assert_op(assert_op_in, *cond)
-            alloc_input = i.owner.inputs[0]
-            if alloc_input.ndim != i.ndim:
-                # The `Alloc` can add dimensions to the value.
-                # We replace those cases with a `DimShuffle` here.
-                nb_dim_to_add = i.ndim - alloc_input.ndim
-                alloc_input = alloc_input.dimshuffle(
-                    ["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
-                )
-            copy_stack_trace(i, alloc_input)
-            new_i.append(alloc_input)
+        if isinstance(i.owner.op, Alloc):
+            new_alloc = broadcast_to(i.owner.inputs[0], bcasted_shape)
 
+        # TODO FIXME: This shouldn't be handled here.
+        # `DimShuffle`s should be lifted through `Alloc`s
+        # by other, more general rewrites.
         # Remove `Alloc` in `DimShuffle`
-        elif i.owner and dimshuffled_alloc(i):
-            assert i.type.ndim == cmp_op.type.ndim
-            if config.experimental__local_alloc_elemwise_assert:
-                assert_cond = [
-                    eq(i.shape[idx], cmp_op.shape[idx])
-                    for idx in range(i.type.ndim)
-                    if not i.type.broadcastable[idx]
-                    and not same_shape(i, cmp_op, idx, idx)
-                ]
-                if assert_cond:
-                    assert_op_in = assert_op(assert_op_in, *assert_cond)
-            alloc_input = i.owner.inputs[0].owner.inputs[0]
-            if alloc_input.ndim != i.owner.inputs[0].ndim:
-                # The `Alloc` can add dimensions to the value.
-                # We replace those cases with a `DimShuffle` here.
-                # We let later optimizations merge the nested `DimShuffle`s
-                nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
-                alloc_input = alloc_input.dimshuffle(
-                    ["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
-                )
-
+        elif isinstance(i.owner.op, DimShuffle):
+            new_alloc = i.owner.inputs[0].owner.inputs[0]
             # We need to keep the old `DimShuffle`. It could swap axes or
             # add dimensions anywhere.
-            r_i = i.owner.op(alloc_input)
-            copy_stack_trace(i, r_i)
-            new_i.append(r_i)
+            new_alloc = broadcast_to(i.owner.op(new_alloc), bcasted_shape)
 
-        else:
-            new_i.append(i)
-    new_i[assert_op_idx] = assert_op_in
+        copy_stack_trace(i, new_alloc)
+        new_inputs[idx] = new_alloc
 
     # If this assert is triggered, it means we are recreating an equivalent graph
     # which would result in a cyclical merge optimization.
-    if all(new is old for new, old in zip(new_i, node.inputs)):
+    if all(new is old for new, old in zip(new_inputs, node.inputs)):
         return
 
-    ret = node.op(*new_i, return_list=True)
+    ret = node.op(*new_inputs, return_list=True)
     copy_stack_trace(node.outputs, ret)
     return ret
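
In short, the rewrite now derives a single broadcasted output shape from the static shapes of all Elemwise inputs and re-expresses each `Alloc` input with `broadcast_to`, instead of comparing `TensorType.broadcastable` patterns and emitting `Assert` guards. A minimal sketch of that composition outside the rewriter, using hypothetical variables `x` and `v` that are not taken from the commit:

import aesara.tensor as at
from aesara.tensor.extra_ops import broadcast_shape, broadcast_to

x = at.matrix("x")               # an ordinary Elemwise input
v = at.scalar("v")
alloc_like = at.alloc(v, 3, 2)   # the kind of input the rewrite removes

# Shape tuples for every input, analogous to the rewrite's use of
# `fgraph.shape_feature.get_shape(i, j)`
input_shapes = [
    tuple(x.shape[j] for j in range(x.ndim)),
    tuple(alloc_like.shape[j] for j in range(alloc_like.ndim)),
]
bcasted_shape = broadcast_shape(*input_shapes, arrays_are_shapes=True)

# The `Alloc` input is replaced by broadcasting its value to that shape
replacement = broadcast_to(v, bcasted_shape)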

tests/tensor/test_basic_opt.py

Lines changed: 5 additions & 2 deletions
@@ -3572,6 +3572,9 @@ def test_Shape_i_canonicalize():
 @pytest.mark.parametrize(
     "expr, x_shape, y_shape",
     [
+        (lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 2), (3, 2)),
+        (lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 1), (1, 1)),
+        (lambda x, y: at.mul(x, at.alloc(y, 2, 3)), (1, 3), (2, 3)),
         pytest.param(
             lambda x, y: at.mul(y, at.alloc(1, x)),
             (),
@@ -3592,8 +3595,8 @@ def test_Shape_i_canonicalize():
     ],
 )
 def test_local_elemwise_alloc(expr, x_shape, y_shape):
-    x = at.tensor("int64", (False,) * len(x_shape))
-    y = at.tensor("int64", (False,) * len(y_shape))
+    x = at.tensor("int64", (False,) * len(x_shape), name="x")
+    y = at.tensor("int64", (False,) * len(y_shape), name="y")
     z = expr(x, y)
 
     z_opt = aesara.function(
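
The new parametrizations exercise graphs of the form at.mul(at.alloc(1, *y.shape), x). A rough, self-contained sketch of how such a graph is built and compiled so that the rewrite can run; the shapes and names here are illustrative, not taken from the commit:

import aesara
import aesara.tensor as at

x = at.tensor("int64", (False, False), name="x")
y = at.tensor("int64", (False, False), name="y")

# An `Alloc` input whose removal previously depended on matching broadcastable patterns
z = at.mul(at.alloc(1, *y.shape), x)

# Graph rewrites, including `local_elemwise_alloc`, are applied during compilation
fn = aesara.function([x, y], z)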
