
Commit 8679731

Ninja91 authored and facebook-github-bot committed
Add FuseConsecutiveRescalesPass to fuse redundant RESCALE pairs (#17830)
Summary:
Add `FuseConsecutiveRescalesPass` to eliminate redundant INT32→INT8→INT32 RESCALE round-trips between chained arithmetic ops in the TOSA lowering pipeline.

`InsertRescaleInt32Pass` wraps each add/sub/mul with input RESCALEs (INT8→INT32) and output RESCALEs (INT32→INT8). When ops are chained (e.g., add→add or mul→add), the output RESCALE of op1 feeds directly into the input RESCALE of op2, creating a wasteful round-trip. Each unnecessary RESCALE decomposes into Add+Mul NPU instructions (~1,130 cycles each on Ethos-U55-128), and in quantized models RESCALE overhead accounts for 25-50% of total NPU cycles.

The pass detects consecutive RESCALE pairs (R1: INT32→INT8/INT16, R2: INT8/INT16→INT32) and handles two cases:

- **Identity** (composed scale ≈ 1.0, matching zero points): Removes both RESCALEs and directly wires R1's input to R2's users. This eliminates the entire round-trip. Bypassing the intermediate INT8/INT16 clamp can cause up to ~120 INT8 steps of output difference, handled via `qtol=1` in tests.
- **Non-identity**: Leaves the pair unchanged. Creating a single INT32→INT32 RESCALE would be semantically correct (and the TOSA ref model handles it), but Vela's NPU compiler produces all-zero outputs for INT32→INT32 RESCALE. Root cause: `EthosU55Constraints::SupportsRescale()` returns `false` for dtypes > 16 bits, causing `RewriteRescale()` to convert the RESCALE into a MUL with aggressive right-shift that zeros out values.

Multi-user R1 nodes (e.g., residual connections, branching) are handled by fusing each R1→R2 pair individually while preserving R1 for non-RESCALE users.

## Context

This pass runs unconditionally in the TOSA pipeline immediately after `InsertRescaleInt32Pass` (see `arm_pass_manager.py`). Identity pairs are the most common case between chained ops with similar quantization scales, so this optimization still eliminates the majority of redundant RESCALEs.
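The identity test described in the summary can be sketched in plain Python. This is only an illustration under stated assumptions: the `Rescale` dataclass and `is_identity_pair` helper are hypothetical names, not part of the ExecuTorch code, and the scale values are made up.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Rescale:
    """Hypothetical stand-in for one RESCALE node's per-tensor parameters."""

    scale: float
    input_zp: int
    output_zp: int
    output_dtype: str  # "int8", "int16" or "int32"


def is_identity_pair(r1: Rescale, r2: Rescale, tol: float = 1e-6) -> bool:
    """Return True if R1 (INT32 -> INT8/INT16) followed by R2 (-> INT32) is a no-op."""
    if r1.output_dtype not in ("int8", "int16") or r2.output_dtype != "int32":
        return False
    # The zero point R1 writes must be the one R2 reads.
    if r1.output_zp != r2.input_zp:
        return False
    composed_scale = r1.scale * r2.scale
    # Identity: composed scale ~1.0 and the outer zero points match.
    return abs(composed_scale - 1.0) < tol and r1.input_zp == r2.output_zp


# Reciprocal scales and matching zero points: the pair can be removed.
r1 = Rescale(scale=20.0, input_zp=0, output_zp=-3, output_dtype="int8")
r2 = Rescale(scale=0.05, input_zp=-3, output_zp=0, output_dtype="int32")
print(is_identity_pair(r1, r2))

# Composed scale of 0.8: the pair is left in place.
r3 = Rescale(scale=0.04, input_zp=-3, output_zp=0, output_dtype="int32")
print(is_identity_pair(r1, r3))
```

The `1e-6` tolerance mirrors the threshold used in the pass source in this commit; everything else here is illustrative.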
The stacked diff D95243636 adds a follow-on `EliminateRescaleBeforeMulPass` that absorbs residual INT32→INT32 RESCALEs before MUL ops.

## Vela INT32→INT32 RESCALE Limitation (Follow-up)

The Vela NPU compiler (Ethos-U55) cannot handle INT32→INT32 RESCALE:

- `EthosU55Constraints::SupportsRescale()` rejects types > 16 bits
- `GraphIrOptimiser::RewriteRescale()` decomposes rejected RESCALEs into MUL ops with explicit OFM scaling
- For INT32→INT32, the MUL's right-shift (typically 20-40 bits) zeros out the result
- `EliminateTosaRescale()` only handles Conv/MatMul patterns, not standalone RESCALEs
- Python-side `rewrite_rescale()` has no code path for INT32→INT32 with non-Conv predecessor

A follow-up Vela patch can fix `RewriteRescale()` to properly handle INT32→INT32 RESCALE (likely after the INT16 conversion step). Once Vela is fixed, this pass can be updated to also fuse non-identity pairs.

## Numerical Analysis

| Source | Magnitude | Mitigation |
| ------ | --------- | ---------- |
| **A**: Fixed-point decomposition non-associativity | ~1 INT8 step | Handled via `qtol=1` in tests |
| **B**: INT8 clamping bypass on identity removal | up to 120 INT8 steps | Bounded; handled via `qtol=1` in tests |

Reviewed By: 3l1

Differential Revision: D94483331
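To make source **B** in the table concrete, here is a small floating-point model of the round-trip. It is a sketch under assumptions: `rescale` and `round_trip` are hypothetical helpers, zero points are fixed at 0, and real TOSA RESCALE uses a fixed-point multiplier and shift rather than a float multiply.

```python
import math

INT8_RANGE = (-128, 127)
INT32_RANGE = (-(2**31), 2**31 - 1)


def rescale(value, scale, input_zp, output_zp, out_range):
    """Float model of RESCALE (an assumption, not TOSA's fixed-point math)."""
    qmin, qmax = out_range
    scaled = math.floor((value - input_zp) * scale + 0.5) + output_zp
    # Saturate to the output dtype's range.
    return max(qmin, min(qmax, scaled))


def round_trip(acc, r1_scale):
    """R1 (INT32 -> INT8) followed by R2 (INT8 -> INT32) with the reciprocal scale."""
    q = rescale(acc, r1_scale, 0, 0, INT8_RANGE)
    return rescale(q, 1.0 / r1_scale, 0, 0, INT32_RANGE)


for acc in (100, 400):
    unfused = round_trip(acc, 0.5)
    fused = acc  # identity fusion wires the INT32 value straight through
    print(acc, unfused, fused, abs(fused - unfused))
```

With these made-up numbers, an accumulator of 100 survives the round-trip unchanged, while 400 saturates at 127 in the INT8 intermediate and comes back as 254. Removing the pair skips that saturation, which is where the clamp-bypass difference comes from.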
1 parent 3604d3e commit 8679731

7 files changed: +1069 -19 lines changed


backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@
     QuantizeClampArgumentsPass,
 )
 from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass  # noqa
+from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass  # noqa
 from .fuse_constant_ops_pass import (  # noqa
     ComputeConstantOpsAOTPass,
     FuseConstantArgsPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 7 additions & 11 deletions
@@ -98,6 +98,7 @@
     DecorateFp32toInt32CastingPass,
     FoldAndAnnotateQParamsPass,
     FuseBatchNorm2dPass,
+    FuseConsecutiveRescalesPass,
     FuseConstantArgsPass,
     FuseDuplicateUsersPass,
     FuseEqualPlaceholdersPass,
@@ -183,8 +184,7 @@ def configure_skip_passes(
         override_config: ArmPassPipelineConfig | None = None,
     ) -> tuple[type, ...]:
         """Configures the pass manager to skip certain passes based on the
-        ArmPassPipelineConfig class found in the compile spec.
-        """
+        ArmPassPipelineConfig class found in the compile spec."""
         skip_set: set[type] = set()
 
         config = override_config or self.compile_spec.get_pass_pipeline_config()
@@ -213,9 +213,8 @@ def validate_constraints_mandatory(self):
         """Validates that necessary passes have run before transforming to
         backend.
 
-        Note that this differs from the original validate_constraints function,
-        which only checks the order of passes.
-
+        Note that this differs from the original validate_constraints
+        function, which only checks the order of passes.
         """
         passes_to_run = defaultdict(list)
 
@@ -245,7 +244,6 @@ def insert_passes_before(
         Args:
             target_pass_type: The pass class to insert before (e.g., InsertTableOpsPass)
             passes: List of pass instances to insert
-
         """
         self._pass_insertions.setdefault(
             target_pass_type, PassInsertions()
@@ -260,7 +258,6 @@ def insert_passes_after(
         Args:
             target_pass_type: The pass class to insert after
             passes: List of pass instances to insert
-
         """
         self._pass_insertions.setdefault(
             target_pass_type, PassInsertions()
@@ -273,7 +270,6 @@ def _apply_pass_insertions(self) -> None:
 
         Raises:
             ValueError: If any registered target pass type is not found in the pipeline.
-
         """
         if self._insertions_applied or not self._pass_insertions:
             return
@@ -317,14 +313,13 @@ def _apply_pass_insertions(self) -> None:
         self._insertions_applied = True
 
     def _configure_pass_insertions(self, exported_program: ExportedProgram) -> None:
-        """Hook for subclasses to configure pass insertions. Called at the START
-        of pipeline construction, before any passes are added.
+        """Hook for subclasses to configure pass insertions. Called at the
+        START of pipeline construction, before any passes are added.
 
         Subclasses should override this to call insert_passes_before/after.
 
         Args:
             exported_program: The exported program being transformed
-
         """
         pass
 
@@ -380,6 +375,7 @@ def _tosa_pipeline(
                 # Ticket: MLETORCH-1539
                 DecomposeLinearPass(),
                 InsertRescaleInt32Pass(),
+                FuseConsecutiveRescalesPass(),
                 InsertControlFlowRescalesPass(),
                 DecomposeQuantNodesPass(),
             ]
Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import cast, List, Set, Type
+
+import torch
+from executorch.backends.arm._passes.arm_pass import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+from torch.fx import GraphModule, Node
+from torch.fx.passes.infra.pass_base import PassResult
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class FuseConsecutiveRescalesPass(ArmPass):
+    """Fuse consecutive RESCALE(INT32->INT8/INT16) ->
+    RESCALE(INT8/INT16->INT32) pairs.
+
+    InsertRescaleInt32Pass wraps each quantized arithmetic and comparison
+    operator (add, sub, mul, abs, eq, ge, gt, le, lt, max, min, sum) with
+    input rescales (INT8/INT16->INT32) and an output rescale
+    (INT32->INT8/INT16). When two such ops are chained (e.g., add1 -> add2),
+    the output rescale of add1 feeds directly into an input rescale of add2,
+    creating a redundant INT32->INT8/INT16->INT32 round-trip that loses
+    precision.
+
+    This pass detects such pairs and handles two cases:
+
+    - **Identity** (composed scale ~1.0, matching zero points): Removes both
+      RESCALEs and directly wires R1's input to R2's users. This eliminates
+      the entire round-trip. Bypassing the intermediate INT8/INT16 clamp can
+      in theory cause up to ~120 INT8 steps of output difference when all
+      inputs are near the clamp boundary; in practice, observed differences
+      are 0-1 steps for typical distributions. Tests use qtol=1.
+
+    - **Non-identity**: Leaves the pair unchanged. The Vela NPU compiler
+      cannot correctly process INT32->INT32 RESCALE (produces all-zero NPU
+      outputs), so non-identity pairs retain their INT8/INT16 intermediate.
+
+    Handles multi-user R1 nodes: when R1 feeds both RESCALE and
+    non-RESCALE users, each R1->R2 RESCALE pair is fused individually
+    while preserving R1 for its non-RESCALE users.
+    """
+
+    _passes_required_after: Set[Type[ExportPass]] = set()
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        graph = graph_module.graph
+        modified = False
+        nodes_to_erase: List[Node] = []
+        rescale_before = sum(1 for n in graph.nodes if _is_rescale(n))
+        identity_pairs_fused = 0
+
+        for node in list(graph.nodes):
+            node = cast(Node, node)
+            if not _is_fuseable_r1(node):
+                continue
+
+            r1_input = node.args[0]
+            r1_input_zp = node.args[3]
+            r1_scale = float(node.args[2][0])  # type: ignore[arg-type]
+
+            node_fused = False
+            for user in list(node.users):
+                if _try_fuse_identity_pair(
+                    node,
+                    user,
+                    r1_input,
+                    r1_input_zp,
+                    r1_scale,
+                    nodes_to_erase,
+                ):
+                    node_fused = True
+                    identity_pairs_fused += 1
+
+            if node_fused:
+                nodes_to_erase.append(node)
+                modified = True
+
+        for node in nodes_to_erase:
+            if len(node.users) == 0:
+                graph.erase_node(node)
+
+        if modified:
+            rescale_after = sum(1 for n in graph.nodes if _is_rescale(n))
+            removed = rescale_before - rescale_after
+            logger.info(
+                "FuseConsecutiveRescalesPass: removed %d identity pairs "
+                "(%d RESCALEs: %d -> %d)",
+                identity_pairs_fused,
+                removed,
+                rescale_before,
+                rescale_after,
+            )
+            graph_module.recompile()
+            graph.lint()
+            # Note: we deliberately skip super().call() — retracing is
+            # unnecessary since this pass only rewires edges and removes
+            # nodes without introducing new operations.
+
+        return PassResult(graph_module, modified)
+
+
+def _is_rescale(node: Node) -> bool:
+    return (
+        node.op == "call_function"
+        and node.target == exir_ops.backend.tosa.RESCALE.default
+    )
+
+
+def _is_fuseable_r1(node: Node) -> bool:
+    """Check if node is an R1 candidate.
+
+    R1 is RESCALE(INT32 -> INT8/INT16) with per-tensor scale.
+    """
+    if not _is_rescale(node):
+        return False
+    if node.args[1] not in (torch.int8, torch.int16):
+        return False
+    if len(node.args[2]) != 1:  # type: ignore[arg-type]
+        return False
+    r1_input = node.args[0]
+    if isinstance(r1_input, Node) and "val" in r1_input.meta:
+        if r1_input.meta["val"].dtype != torch.int32:
+            return False
+    return True
+
+
+def _try_fuse_identity_pair(
+    r1: Node,
+    r2: Node,
+    r1_input: Node,
+    r1_input_zp: int,
+    r1_scale: float,
+    nodes_to_erase: List[Node],
+) -> bool:
+    """Try to fuse an R1->R2 identity pair.
+
+    Returns True if fused.
+    """
+    if not _is_rescale(r2):
+        return False
+    if r2.args[1] != torch.int32:
+        return False
+    if r1.args[4] != r2.args[3]:
+        return False
+    if len(r2.args[2]) != 1:  # type: ignore[arg-type]
+        return False
+
+    r2_scale = float(r2.args[2][0])  # type: ignore[arg-type, index]
+    composed_scale = r1_scale * r2_scale
+    r2_output_zp = r2.args[4]
+
+    if abs(composed_scale - 1.0) < 1e-6 and r1_input_zp == r2_output_zp:
+        # Identity case: remove both RESCALEs and directly wire
+        # R1's input (INT32) to R2's users. The composed scale
+        # is ~1.0 so the round-trip is a no-op modulo the INT8
+        # clamp. Bypassing the clamp can in theory cause up to
+        # ~120 INT8 steps of difference near clamp boundaries;
+        # observed differences are 0-1 steps. Tests use qtol=1.
+        r2.replace_all_uses_with(r1_input)
+        nodes_to_erase.append(r2)
+        return True
+
+    # Non-identity: leave the pair unchanged. Creating a
+    # single INT32->INT32 RESCALE with the composed scale would
+    # be semantically correct (and the TOSA ref model handles
+    # it), but the Vela NPU compiler produces all-zero outputs
+    # for INT32->INT32 RESCALE operations.
+    return False
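The multi-user handling in the pass above (rewire R2's users, then erase only nodes left without users) can be mimicked on a toy graph without torch. This is a rough sketch: `ToyNode`, `users_of`, and `fuse_identity_pairs` are hypothetical stand-ins for `torch.fx` nodes and the pass logic, and the set of identity pairs is supplied by hand instead of being computed from scales and zero points.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(eq=False)
class ToyNode:
    """Hypothetical stand-in for torch.fx.Node: a name, a kind, and its inputs."""

    name: str
    kind: str  # "rescale" or any other op name
    inputs: list[ToyNode] = field(default_factory=list)


def users_of(node: ToyNode, graph: list[ToyNode]) -> list[ToyNode]:
    # eq=False keeps identity comparison, so `in` matches the exact node object.
    return [n for n in graph if node in n.inputs]


def fuse_identity_pairs(graph: list[ToyNode], identity_pairs: set[tuple[str, str]]) -> None:
    """Bypass each listed (R1, R2) pair, then erase nodes that have no users left."""
    to_erase: list[ToyNode] = []
    for r1 in list(graph):
        if r1.kind != "rescale":
            continue
        fused = False
        for r2 in users_of(r1, graph):
            if r2.kind == "rescale" and (r1.name, r2.name) in identity_pairs:
                # Wire R1's INT32 input directly to every user of R2.
                for user in users_of(r2, graph):
                    user.inputs = [r1.inputs[0] if i is r2 else i for i in user.inputs]
                to_erase.append(r2)
                fused = True
        if fused:
            to_erase.append(r1)
    for node in to_erase:
        # R1 survives when a non-RESCALE user still depends on it.
        if not users_of(node, graph):
            graph.remove(node)


add1 = ToyNode("add1", "add")
r1 = ToyNode("r1", "rescale", [add1])
r2 = ToyNode("r2", "rescale", [r1])
add2 = ToyNode("add2", "add", [r2])
out = ToyNode("out", "output", [r1])  # residual branch still consuming INT8
graph = [add1, r1, r2, add2, out]

fuse_identity_pairs(graph, {("r1", "r2")})
print([n.name for n in graph], add2.inputs[0].name)
```

In this toy graph `r2` disappears and `add2` reads `add1` directly, while `r1` is kept because `out` still consumes it, matching the behaviour the class docstring describes for residual connections.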
