Skip to content

Commit c0045c3

Browse files
committed
Avoid copy of zeros in AdvancedIncSubtensor1
1 parent be43535 commit c0045c3

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

pytensor/tensor/rewriting/subtensor.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -1295,12 +1295,28 @@ def local_inplace_setsubtensor(fgraph, node):
12951295

12961296
@node_rewriter([AdvancedIncSubtensor1], inplace=True)
12971297
def local_inplace_AdvancedIncSubtensor1(fgraph, node):
1298-
if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
1299-
new_op = node.op.clone_inplace()
1300-
new_node = new_op(*node.inputs)
1301-
copy_stack_trace(node.outputs, new_node)
1302-
return [new_node]
1303-
return False
1298+
if node.op.inplace:
1299+
return
1300+
1301+
x, y, idx = node.inputs
1302+
if fgraph.has_destroyers([x]):
1303+
# In this case we can't operate inplace, but if x is just an alloc of zeros
1304+
# We're better off duplicating it and then acting on it inplace.
1305+
if (
1306+
x.owner is not None
1307+
and isinstance(x.owner.op, Alloc)
1308+
and all(x.owner.inputs[0].type.broadcastable)
1309+
and isinstance(x.owner.inputs[0], Constant)
1310+
and x.owner.inputs[0].unique_value == 0
1311+
):
1312+
x = x.owner.clone().outputs[0]
1313+
else:
1314+
return None # Inplace isn't valid
1315+
1316+
new_op = node.op.clone_inplace()
1317+
new_node = new_op(x, y, idx)
1318+
copy_stack_trace(node.outputs, new_node)
1319+
return [new_node]
13041320

13051321

13061322
compile.optdb.register(

0 commit comments

Comments
 (0)