Skip to content

Commit eb2431d

Browse files
committed
Add rewrite to lift SpecifyShape through Elemwise Operations
1 parent 7775b9e commit eb2431d

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

pytensor/tensor/rewriting/shape.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,52 @@ def local_Shape_of_SpecifyShape(fgraph, node):
10261026
return [stack(shape).astype(np.int64)]
10271027

10281028

1029+
@register_canonicalize
1030+
@register_specialize
1031+
@node_rewriter([SpecifyShape])
1032+
def local_specify_shape_lift(fgraph, node):
1033+
"""Lift SpecifyShape of Elemwise towards the inputs."""
1034+
inp, *shape = node.inputs
1035+
if inp.owner and isinstance(inp.owner.op, Elemwise):
1036+
if len(inp.owner.outputs) != 1:
1037+
return None
1038+
1039+
elem_inps = inp.owner.inputs
1040+
if len(elem_inps) == 1:
1041+
new_elem_inps = [specify_shape(elem_inps[0], shape)]
1042+
else:
1043+
# Rewrite does not support case where specify_shape provides new broadcastable information,
1044+
# As that may require a specify_shape for each input
1045+
out_broadcastable = node.outputs[0].type.broadcastable
1046+
if out_broadcastable != inp.type.broadcastable:
1047+
return None
1048+
1049+
# All non-broadcastable dimensions of inputs must match the non-broadcastbale specify_shape dims
1050+
# We look for a sufficient input to assign all the specify_shape dims
1051+
# We could consider distributing the SpecifyShape across multiple inputs, when none is sufficient
1052+
1053+
nonbcast_dims = {
1054+
i
1055+
for i, (dim, bcast) in enumerate(zip(shape, out_broadcastable))
1056+
if (not bcast and not NoneConst.equals(dim))
1057+
}
1058+
new_elem_inps = elem_inps.copy()
1059+
for i, elem_inp in enumerate(elem_inps):
1060+
if all(
1061+
bcast_dim is False
1062+
for dim, bcast_dim in enumerate(elem_inp.type.broadcastable)
1063+
if dim in nonbcast_dims
1064+
):
1065+
new_elem_inps[i] = specify_shape(elem_inp, shape)
1066+
break
1067+
else: # no-break, no sufficient candidate found
1068+
return None
1069+
1070+
new_out = inp.owner.op.make_node(*new_elem_inps).outputs
1071+
copy_stack_trace(node.outputs, new_out)
1072+
return new_out
1073+
1074+
10291075
@register_useless
10301076
@register_canonicalize
10311077
@node_rewriter([Shape_i])

tests/tensor/rewriting/test_shape.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,14 @@ def test_local_Shape_of_SpecifyShape_partial(s1):
491491
assert not any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes)
492492

493493

494+
def test_local_specify_shape_lift():
495+
x = vector("x")
496+
out = specify_shape([1.0] + x, shape=(5,))
497+
498+
new_out = rewrite_graph(out)
499+
assert equal_computations([new_out], [[1.0] + specify_shape(x, shape=(5,))])
500+
501+
494502
def test_local_Shape_i_ground():
495503
x = tensor(dtype=np.float64, shape=(None, 2))
496504
s = Shape_i(1)(x)

0 commit comments

Comments
 (0)