Skip to content

Commit f1f7398

Browse files
bdemirb (Baris Demir)
authored
Arm backend: Fix the issue on conv->relu->permute->reshape(5D) (#18136)
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell Signed-off-by: Baris Demir <baris.demir@arm.com> Co-authored-by: Baris Demir <baris.demir@arm.com>
1 parent 88dc743 commit f1f7398

File tree

7 files changed

+580
-10
lines changed

7 files changed

+580
-10
lines changed

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@
138138
RewriteBoolToFp32CastViaInt8Pass,
139139
)
140140
from .rewrite_conv_pass import RewriteConvPass # noqa
141+
from .rewrite_high_rank_singleton_permute_pass import ( # noqa
142+
RewriteHighRankSingletonPermutePass,
143+
)
141144
from .rewrite_index_put_pass import RewriteIndexPutPass # noqa
142145
from .rewrite_le_lt_to_ge_gt_pass import RewriteLeLtToGeGtPass # noqa
143146
from .rewrite_matmul import RewriteMatmulPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@
121121
RewriteBoolBitwiseToLogicalPass,
122122
RewriteBoolToFp32CastViaInt8Pass,
123123
RewriteConvPass,
124+
RewriteHighRankSingletonPermutePass,
124125
RewriteIndexPutPass,
125126
RewriteLeLtToGeGtPass,
126127
RewriteMatmulPass,
@@ -366,6 +367,7 @@ def _tosa_pipeline(
366367
CastToInt32Pass(),
367368
BroadcastArgsPass(),
368369
ConvertPermuteSingletonToViewPass(),
370+
RewriteHighRankSingletonPermutePass(),
369371
FuseViewCopyTransformPass(),
370372
DecomposeConvWithInt16ActivationPass(),
371373
DecomposeSumPass(),
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Sequence, Set, Type
7+
8+
from executorch.backends.arm._passes import ArmPass
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass
11+
12+
13+
class RewriteHighRankSingletonPermutePass(ArmPass):
    """Rewrite high-rank permute via a lower-rank permute when singleton dims
    allow it.

    For rank>4 tensors, some backends are fragile around direct high-rank
    TRANSPOSE. When singleton dimensions are present, we can rewrite:

    permute(rank>4) -> view(remove singleton dims) -> permute(reduced rank) ->
    view(restore rank)

    This keeps semantics unchanged while reducing the permute rank.
    """

    # No other pass is required to run after this one.
    _passes_required_after: Set[Type[ExportPass]] = set()

    # Both the functional and the copy variant of permute are rewritten.
    _PERMUTE_OPS = (
        exir_ops.edge.aten.permute.default,
        exir_ops.edge.aten.permute_copy.default,
    )

    @staticmethod
    def _extract_permutation(permutation_arg: object) -> tuple[int, ...] | None:
        """Return the permutation as a tuple of ints, or None if the argument
        is not a list/tuple of ints (in which case the rewrite is skipped)."""
        if not isinstance(permutation_arg, (list, tuple)):
            return None
        if not all(isinstance(dim, int) for dim in permutation_arg):
            return None
        return tuple(permutation_arg)

    @staticmethod
    def _normalize_permutation(
        permutation: Sequence[int], rank: int
    ) -> tuple[int, ...]:
        """Map possibly-negative dims into the range [0, rank)."""
        return tuple(dim % rank for dim in permutation)

    def call_operator(self, op, args, kwargs, meta):
        """Intercept permute ops; rewrite rank>4 permutes over tensors with
        singleton dims into view -> (reduced-rank) permute -> view.

        Falls through to the default behavior whenever any precondition is
        not met, leaving the graph unchanged.
        """
        # Only rewrite the permute ops listed above.
        if op not in self._PERMUTE_OPS:
            return super().call_operator(op, args, kwargs, meta)
        # Need both the input tensor and the permutation argument.
        if len(args) < 2:
            return super().call_operator(op, args, kwargs, meta)
        # The input proxy must expose a fake tensor (.data) to read its shape.
        if not hasattr(args[0], "data"):
            return super().call_operator(op, args, kwargs, meta)
        # The output shape is taken from the node's fake-tensor metadata.
        if "val" not in meta or not hasattr(meta["val"], "shape"):
            return super().call_operator(op, args, kwargs, meta)

        permutation = self._extract_permutation(args[1])
        if permutation is None:
            return super().call_operator(op, args, kwargs, meta)

        input_shape = list(args[0].data.shape)
        output_shape = list(meta["val"].shape)
        rank = len(input_shape)
        # Only rank>4 permutes are rewritten; permute preserves rank, so a
        # mismatching output rank means metadata we do not understand — skip.
        if rank <= 4 or len(output_shape) != rank:
            return super().call_operator(op, args, kwargs, meta)

        normalized_permutation = self._normalize_permutation(permutation, rank)
        # Axes of size 1 can be dropped without changing element order.
        singleton_axes = [axis for axis, dim in enumerate(input_shape) if dim == 1]
        if not singleton_axes:
            return super().call_operator(op, args, kwargs, meta)

        non_singleton_axes = [
            axis for axis in range(rank) if axis not in singleton_axes
        ]
        reduced_rank = len(non_singleton_axes)
        # The whole point is to get down to rank<=4; otherwise bail out.
        if reduced_rank > 4:
            return super().call_operator(op, args, kwargs, meta)

        # Map each surviving original axis to its index in the reduced tensor.
        axis_to_reduced_axis = {
            axis: idx for idx, axis in enumerate(non_singleton_axes)
        }
        # Project the permutation onto the non-singleton axes only.
        reduced_permutation = tuple(
            axis_to_reduced_axis[axis]
            for axis in normalized_permutation
            if axis in axis_to_reduced_axis
        )
        expected_axes = tuple(range(reduced_rank))
        # Sanity check: the projection must itself be a valid permutation of
        # the reduced axes; if not, the original permutation was malformed.
        if tuple(sorted(reduced_permutation)) != expected_axes:
            return super().call_operator(op, args, kwargs, meta)

        # view: drop the singleton dims.
        reduced_input_shape = [input_shape[axis] for axis in non_singleton_axes]
        reduced_input = super().call_operator(
            exir_ops.edge.aten.view_copy.default,
            (args[0], reduced_input_shape),
            {},
            meta,
        )
        # Identity permutation after reduction — no permute node needed.
        if reduced_permutation == expected_axes:
            reduced_output = reduced_input
        else:
            reduced_output = super().call_operator(
                op,
                (reduced_input, reduced_permutation),
                kwargs,
                meta,
            )
        # view: restore the original (permuted) high-rank shape.
        return super().call_operator(
            exir_ops.edge.aten.view_copy.default,
            (reduced_output, output_shape),
            {},
            meta,
        )

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,23 @@ def _channels_last_inverse_order(rank: int, spatial_rank: int) -> tuple[int, ...
113113
inverse[axis] = idx
114114
return tuple(inverse)
115115

116+
def _infer_dim_order_for_node(
    self, node: torch.fx.Node, node_data: torch.Tensor, spatial_rank: int
) -> tuple[int, ...]:
    """Return the TOSA dim order for *node* based on its role and rank.

    Graph inputs/outputs keep their externally-declared dim order;
    transpose-conv weights get OHWI; other rank>=4 tensors get the
    channels-last order for the given spatial rank; lower ranks keep the
    identity order.
    """
    rank = node_data.dim()

    # Inputs and outputs preserve their externally-declared dim order.
    if _is_input(node, self.exported_program) or node.op == "output":
        return node_data.dim_order()

    # Conv transpose weights are serialized in OHWI layout.
    if rank == 4 and _is_transpose_conv2d_weight(node):
        return (1, 2, 3, 0)

    # Internal rank>=4 tensors are laid out channels-last ((N)NHWC).
    if rank >= 4:
        return self._channels_last_order(rank, spatial_rank)
    # Rank < 4: identity order, no layout transformation needed.
    return tuple(range(rank))
132+
116133
def _initial_spatial_rank(self, node: torch.fx.Node) -> int:
117134
"""Infer the initial spatial rank based on the current rank, input node
118135
spatial ranks and node target. A spatial dimension includes Height,
@@ -459,15 +476,7 @@ def call(self, graph_module: torch.fx.GraphModule):
459476
continue
460477
node_data = get_first_fake_tensor(node).data
461478
spatial_rank = node.meta["tosa_spatial_rank"]
462-
if _is_input(node, self.exported_program) or node.op == "output":
463-
dim_order = node_data.dim_order()
464-
else:
465-
if node_data.dim() == 4 and _is_transpose_conv2d_weight(node):
466-
dim_order = (1, 2, 3, 0)
467-
elif node_data.dim() >= 4:
468-
dim_order = self._channels_last_order(node_data.dim(), spatial_rank)
469-
else:
470-
dim_order = tuple(range(node_data.dim())) # type: ignore[assignment]
479+
dim_order = self._infer_dim_order_for_node(node, node_data, spatial_rank)
471480
node.meta["tosa_dim_order"] = dim_order
472481

473482
# Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format.

0 commit comments

Comments (0)