@@ -1295,12 +1295,28 @@ def local_inplace_setsubtensor(fgraph, node):
1295
1295
1296
1296
@node_rewriter([AdvancedIncSubtensor1], inplace=True)
def local_inplace_AdvancedIncSubtensor1(fgraph, node):
    """Replace a non-inplace ``AdvancedIncSubtensor1`` with its inplace clone.

    Parameters
    ----------
    fgraph
        The graph being rewritten; queried via ``has_destroyers`` to find out
        whether the buffer input is already destroyed elsewhere.
    node
        The candidate node; its inputs are unpacked as ``(x, y, idx)``.

    Returns
    -------
    list or None
        A one-element list holding the output of the inplace op, or ``None``
        when the op is already inplace or the buffer cannot be reused.
    """
    op = node.op
    if op.inplace:
        # Already inplace: nothing to rewrite.
        return None

    buf, y, idx = node.inputs

    if fgraph.has_destroyers([buf]):
        # The buffer is already destroyed by something else, so we cannot act
        # on it inplace directly. If it is merely an Alloc of a constant
        # scalar zero, duplicating the Alloc and working inplace on the
        # duplicate is still the better option.
        buf_node = buf.owner
        if buf_node is None or not isinstance(buf_node.op, Alloc):
            return None  # Inplace isn't valid

        fill = buf_node.inputs[0]
        is_zero_fill = (
            all(fill.type.broadcastable)
            and isinstance(fill, Constant)
            and fill.unique_value == 0
        )
        if not is_zero_fill:
            return None  # Inplace isn't valid

        # Fresh copy of the Alloc node; its output becomes our private buffer.
        buf = buf_node.clone().outputs[0]

    inplace_out = op.clone_inplace()(buf, y, idx)
    copy_stack_trace(node.outputs, inplace_out)
    return [inplace_out]
1304
1320
1305
1321
1306
1322
compile .optdb .register (
0 commit comments