Skip to content

Commit e4eeb98

Browse files
authored
Arm backend: normalize inplace arithmetic ops. (#18231)
Just turn them into normal out-of-place ops. This solves some inplace scalar xfails.
1 parent 75c85e7 commit e4eeb98

File tree

6 files changed

+45
-23
lines changed

6 files changed

+45
-23
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@
146146
RewriteHighRankSingletonPermutePass,
147147
)
148148
from .rewrite_index_put_pass import RewriteIndexPutPass # noqa
149+
from .rewrite_inplace_arithmetic_pass import RewriteInplaceArithmeticPass # noqa
149150
from .rewrite_le_lt_to_ge_gt_pass import RewriteLeLtToGeGtPass # noqa
150151
from .rewrite_matmul import RewriteMatmulPass # noqa
151152
from .rewrite_pad import RewritePadPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@
125125
RewriteConvPass,
126126
RewriteHighRankSingletonPermutePass,
127127
RewriteIndexPutPass,
128+
RewriteInplaceArithmeticPass,
128129
RewriteLeLtToGeGtPass,
129130
RewriteMatmulPass,
130131
RewritePadPass,
@@ -553,6 +554,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
553554
DecomposeFloorDividePass(tfa_pass=True),
554555
DecomposeDivTensorModePass(tfa_pass=True),
555556
DecomposeWhereScalarOtherPass(tfa_pass=True),
557+
RewriteInplaceArithmeticPass(tfa_pass=True),
556558
]
557559
)
558560

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
10+
from executorch.backends.arm._passes import ArmPass
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass
13+
14+
OP_MAP = {
15+
torch.ops.aten.add_.Tensor: torch.ops.aten.add.Tensor,
16+
torch.ops.aten.sub_.Tensor: torch.ops.aten.sub.Tensor,
17+
torch.ops.aten.mul_.Tensor: torch.ops.aten.mul.Tensor,
18+
torch.ops.aten.div_.Tensor: torch.ops.aten.div.Tensor,
19+
exir_ops.edge.aten.add_.Tensor: exir_ops.edge.aten.add.Tensor,
20+
exir_ops.edge.aten.sub_.Tensor: exir_ops.edge.aten.sub.Tensor,
21+
exir_ops.edge.aten.mul_.Tensor: exir_ops.edge.aten.mul.Tensor,
22+
exir_ops.edge.aten.div_.Tensor: exir_ops.edge.aten.div.Tensor,
23+
}
24+
25+
26+
class RewriteInplaceArithmeticPass(ArmPass):
27+
"""Rewrite inplace arithmetic ops into functional equivalents."""
28+
29+
_passes_required_after: Set[Type[ExportPass]] = set()
30+
31+
def call_operator(self, op, args, kwargs, meta):
32+
if not self.allowed_to_transform(meta):
33+
return super().call_operator(op, args, kwargs, meta)
34+
35+
target_op = OP_MAP.get(op)
36+
if target_op is None:
37+
return super().call_operator(op, args, kwargs, meta)
38+
39+
return super().call_operator(target_op, args, kwargs, meta, updated=True)

backends/arm/_passes/scalars_to_attribute_pass.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,10 @@ class ScalarsToAttributePass(ArmPass):
2727

2828
targeted_ops = [
2929
torch.ops.aten.add.Tensor,
30-
torch.ops.aten.add_.Tensor,
3130
torch.ops.aten.sub.Tensor,
32-
torch.ops.aten.sub_.Tensor,
3331
torch.ops.aten.rsub.Scalar,
3432
torch.ops.aten.mul.Tensor,
35-
torch.ops.aten.mul_.Tensor,
3633
torch.ops.aten.div.Tensor,
37-
torch.ops.aten.div_.Tensor,
3834
]
3935

4036
def _convert_scalar_args(

backends/arm/quantizer/quantization_annotator.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -703,13 +703,10 @@ def any_or_hardtanh_min_zero(n: Node):
703703
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
704704
elif node.target in (
705705
torch.ops.aten.add.Tensor,
706-
torch.ops.aten.add_.Tensor,
707706
torch.ops.aten.sub.Tensor,
708-
torch.ops.aten.sub_.Tensor,
709707
torch.ops.aten.mm.default,
710708
torch.ops.aten.bmm.default,
711709
torch.ops.aten.mul.Tensor,
712-
torch.ops.aten.mul_.Tensor,
713710
):
714711
quant_properties.quant_inputs = [
715712
_QuantProperty(0, input_act_qspec),

backends/arm/test/ops/test_scalars.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,6 @@ def forward(self, x):
168168
"float_r4_st": "MLETORCH-408: Arithmetic ops can't handle scalars first",
169169
}
170170

171-
int_inplace_xfails = {
172-
"int_r1_ts": "MLETORCH-1708: Numerical error in TFA/quantization",
173-
"int_r4_ts": "MLETORCH-1708: Numerical error in TFA/quantization",
174-
"float_r1_ts": "MLETORCH-1708: Numerical error in TFA/quantization",
175-
"float_r4_ts": "MLETORCH-1708: Numerical error in TFA/quantization",
176-
}
177-
178171

179172
# ADD FP ------------------------------------------------------
180173
@common.parametrize("test_data", tensor_scalar_tests, xfails=xfails)
@@ -215,9 +208,7 @@ def test_add_tensor_tosa_INT_scalar(test_data):
215208
pipeline.run()
216209

217210

218-
@common.parametrize(
219-
"test_data", tensor_scalar_tests, xfails=int_inplace_xfails, strict=False
220-
)
211+
@common.parametrize("test_data", tensor_scalar_tests)
221212
def test_add_tensor_tosa_INT_inplace(test_data):
222213
"""Tests inplace add with one scalar input."""
223214
pipeline = TosaPipelineINT[input_t1](AddInplace(), test_data, aten_op=[])
@@ -285,9 +276,7 @@ def test_sub_tensor_tosa_INT_scalar(test_data):
285276
pipeline.run()
286277

287278

288-
@common.parametrize(
289-
"test_data", tensor_scalar_tests, xfails=int_inplace_xfails, strict=False
290-
)
279+
@common.parametrize("test_data", tensor_scalar_tests)
291280
def test_sub_tensor_tosa_INT_inplace(test_data):
292281
"""Tests inplace sub with one scalar input."""
293282
pipeline = TosaPipelineINT[input_t1](SubInplace(), test_data, aten_op=[])
@@ -344,9 +333,7 @@ def test_mul_tensor_tosa_INT_scalar(test_data):
344333
pipeline.run()
345334

346335

347-
@common.parametrize(
348-
"test_data", tensor_scalar_tests, xfails=int_inplace_xfails, strict=False
349-
)
336+
@common.parametrize("test_data", tensor_scalar_tests)
350337
def test_mul_tensor_tosa_INT_inplace(test_data):
351338
"""Tests inplace mul with one scalar input."""
352339
pipeline = TosaPipelineINT[input_t1](MulInplace(), test_data, aten_op=[])

0 commit comments

Comments
 (0)