Skip to content

Commit db00a4c

Browse files
3l1facebook-github-bot
authored and committed
Eliminate redundant NCHW↔NHWC permute_copy and NHWC-safe view_copy transposes in ToTosaMemoryFormatPass
Summary: Two optimizations in ToTosaMemoryFormatPass to reduce TOSA TRANSPOSE nodes: 1. NHWC-safe reshape detection: When a 4D→4D view_copy has monotonic shape_indices on the raw shapes and preserves the last dimension (NHWC channel), skip inserting input/output transposes. The view_copy can operate directly on NHWC data. 2. Redundant permute_copy elimination: Model-level permute_copy ops whose permutation matches channels_last_order (NCHW→NHWC) or its inverse (NHWC→NCHW) are redundant with the tosa_dim_order annotation that already handles format conversion. Replace them with view_copy (identity reshape) to avoid generating TOSA TRANSPOSE nodes. Handles both 4D (rank>=4, sr>=2) and 3D (rank>=3, sr>=1) permutations. For the CC EMG model, this reduces Vela Transpose entries from 75→33 (-56%), Transpose op cycles from 33.4K→6.1K (-82%), and NPU operators from 367→329 (-38). Also removes the failed ReorderToNHWCPass which targeted permute_copy→ view_copy→permute_copy patterns that don't exist in the Edge IR graph. Reviewed By: digantdesai Differential Revision: D96432610
1 parent 3604d3e commit db00a4c

File tree

2 files changed

+325
-2
lines changed

2 files changed

+325
-2
lines changed

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 171 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,21 @@ def get_batch_prod_dim(shape, spatial_rank):
222222

223223
return (N_old != N_new) or (C_old != C_new)
224224

225+
@staticmethod
226+
def _is_nop_transpose(shape, perm) -> bool:
227+
"""Return ``True`` when a transpose only permutes size-1 dimensions.
228+
229+
A transpose is a NOP (no-operation) when the relative order of
230+
all non-size-1 dimensions is unchanged — permuting size-1 dims
231+
does not alter the physical byte layout.
232+
233+
Example: ``[14, 72, 1, 1]`` with perm ``(0, 1, 3, 2)`` → True
234+
(only the two trailing size-1 dims swap).
235+
"""
236+
old_order = [i for i, s in enumerate(shape) if s != 1]
237+
new_order = [i for i, s in zip(perm, [shape[p] for p in perm]) if s != 1]
238+
return old_order == new_order
239+
225240
@staticmethod
226241
def insert_input_transpose(node, input_node, graph_module):
227242
"""Ensure an input tensor is converted to channels-last ordering by
@@ -271,7 +286,7 @@ def insert_output_transpose(node, graph_module):
271286
# Guard: mem_format must be a true permutation for the current rank
272287
assert sorted(mem_format) == list(
273288
range(rank)
274-
), f"bad perm {mem_format} for rank {rank} in insert_input_transpose"
289+
), f"bad perm {mem_format} for rank {rank} in insert_output_transpose"
275290

276291
with graph_module.graph.inserting_after(node):
277292
permute_node = create_node(
@@ -296,6 +311,104 @@ def insert_output_transpose(node, graph_module):
296311
for user in users:
297312
user.replace_input_with(node, permute_node)
298313

314+
@staticmethod
315+
def _get_shape_indices(
316+
src_shape: list[int], tgt_shape: list[int]
317+
) -> list[list[int]] | None:
318+
"""Greedy dimension matching for reshape operations.
319+
320+
For each target dimension, greedily consumes contiguous source
321+
dimensions whose product equals the target size. Size-1 target
322+
dimensions that do not correspond to any source dimension produce
323+
empty index lists (inserted dims).
324+
325+
Returns ``None`` when no valid mapping exists.
326+
"""
327+
src_idx = 0
328+
result: list[list[int]] = []
329+
330+
for tgt_dim in tgt_shape:
331+
if tgt_dim <= 0:
332+
return None
333+
334+
indices: list[int] = []
335+
remaining = tgt_dim
336+
337+
while src_idx < len(src_shape) and remaining % src_shape[src_idx] == 0:
338+
indices.append(src_idx)
339+
remaining //= src_shape[src_idx]
340+
src_idx += 1
341+
if remaining == 1:
342+
break
343+
344+
if remaining != 1:
345+
return None
346+
347+
result.append(indices)
348+
349+
if src_idx != len(src_shape):
350+
return None
351+
352+
return result
353+
354+
@staticmethod
355+
def _is_monotonic(indices: list[list[int]]) -> bool:
356+
"""Return ``True`` when all non-empty index groups are strictly
357+
ordered — i.e. each group's indices follow the previous group's.
358+
"""
359+
last_max = -1
360+
for group in indices:
361+
if not group:
362+
continue
363+
if group[0] <= last_max:
364+
return False
365+
last_max = group[-1]
366+
return True
367+
368+
@staticmethod
def _is_nhwc_safe_reshape(
    input_shape, output_shape, input_sr, output_sr  # noqa: ARG004
) -> bool:
    """Detect whether a 4-D+ reshape can operate directly on NHWC data.

    By the time ``ToTosaMemoryFormatPass`` runs, 4-D tensor shapes in
    ``meta["val"]`` are already in NHWC physical order (the channel
    dimension sits at position ``rank - spatial_rank - 1``, not at
    position 1 as in NCHW). The shape indices are therefore checked on
    the **raw** input/output shapes — no extra permutation is needed.

    Returns ``True`` when:
        1. The reshape has monotonic shape_indices (each output dim maps
           to a contiguous, in-order group of input dims), AND
        2. The channel dimension is preserved alone (not merged with
           spatial dims).
    """
    # Only 4-D-or-higher reshapes can be NHWC-safe.
    if len(input_shape) < 4 or len(output_shape) < 4:
        return False

    groups = ToTosaMemoryFormatPass._get_shape_indices(
        list(input_shape), list(output_shape)
    )
    if groups is None:
        return False
    if not ToTosaMemoryFormatPass._is_monotonic(groups):
        return False

    # In the TOSA pipeline the physical memory order is NHWC, so the
    # channel is the last input axis. It must map alone to one output
    # dim — merging it with spatial dims would reorder channel data.
    channel_axis = len(input_shape) - 1
    containing = [grp for grp in groups if channel_axis in grp]
    if not containing:
        # Channel dim not consumed by any group — conservative reject.
        return False
    return len(containing[0]) == 1
411+
299412
@staticmethod
300413
def _insert_view_transpose(
301414
input_shape, output_shape, node, input_node, graph_module
@@ -317,6 +430,14 @@ def _insert_view_transpose(
317430
output_sr,
318431
)
319432

433+
# When the NHWC-space reshape has monotonic shape_indices the
434+
# view_copy can operate directly on NHWC data — no transposes
435+
# are needed.
436+
if channel_reshape and ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
437+
input_shape, output_shape, input_sr, output_sr
438+
):
439+
return
440+
320441
if (
321442
channel_reshape or nhwc_to_nchw
322443
) and ToTosaMemoryFormatPass.memory_format_differs(input_shape, input_sr):
@@ -345,10 +466,58 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
345466
- 1D/2D tensors
346467
347468
"""
348-
for node in graph_module.graph.nodes:
469+
for node in list(graph_module.graph.nodes):
349470
if node.op != "call_function":
350471
continue
351472

473+
# Eliminate model-level permute_copy ops that are redundant
474+
# with the tosa_dim_order annotation. When a permute_copy's
475+
# permutation matches the channels-last order (or its
476+
# inverse), the permute does the same NCHW<>NHWC conversion
477+
# that tosa_dim_order already handles -- keeping both would
478+
# double-convert. Replace with view_copy (identity reshape).
479+
if node.target in (
480+
exir_ops.edge.aten.permute_copy.default,
481+
exir_ops.edge.aten.permute.default,
482+
):
483+
perm = list(node.args[1])
484+
rank = len(perm)
485+
sr = node.meta.get("tosa_spatial_rank", 0)
486+
487+
if rank >= 3 and sr >= 1:
488+
cl_order = list(
489+
self._channels_last_order(rank, sr)
490+
)
491+
cl_inv = list(
492+
self._channels_last_inverse_order(rank, sr)
493+
)
494+
if perm == cl_order or perm == cl_inv:
495+
input_node = node.args[0]
496+
output_shape = list(node.meta["val"].shape)
497+
with graph_module.graph.inserting_before(node):
498+
# Create a CONST_SHAPE node for the shape arg,
499+
# matching what InsertConstShapesPass does for
500+
# normal view_copy nodes. This ensures
501+
# op_view.py sees inputs[1].name as expected.
502+
const_shape_node = graph_module.graph.call_function(
503+
exir_ops.backend.tosa.CONST_SHAPE.default,
504+
(output_shape,),
505+
)
506+
const_shape_node.meta["val"] = output_shape
507+
const_shape_node.meta["tosa_dim_order"] = node.meta.get(
508+
"tosa_dim_order", tuple(range(rank))
509+
)
510+
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
511+
const_shape_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE
512+
view_node = graph_module.graph.call_function(
513+
exir_ops.edge.aten.view_copy.default,
514+
(input_node, const_shape_node),
515+
)
516+
view_node.meta = dict(node.meta)
517+
node.replace_all_uses_with(view_node)
518+
graph_module.graph.erase_node(node)
519+
continue
520+
352521
# Transpose views
353522
elif node.target == exir_ops.edge.aten.view_copy.default:
354523
input_node = node.args[0]

backends/arm/test/passes/test_to_tosa_memory_format.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,76 @@ def get_inputs(self) -> input_t:
177177
return (torch.rand(4, 4, 4, 4),)
178178

179179

180+
class NHWCSafeSpatialMerge(torch.nn.Module):
    """Test-module with a 4D->4D reshape that merges spatial dims H*W while
    preserving the last-dim channel.

    For models with view_copy shapes [1,2,14,72]->[1,28,1,72] where C=2
    sits at NCHW position 1 and the last dim (72) is the NHWC channel that gets
    preserved. ``_is_nhwc_safe_reshape`` detects that shape_indices on the raw
    shapes are monotonic with the last dim alone, so no transposes are inserted
    around the view_copy.

    Setup: conv2d (forces NHWC, C=2) -> view_copy -> add (keeps in NHWC).
    """

    # Expected TOSA op counts checked by the parametrized pass test.
    ops_before_pass: Dict[str, int] = {}
    # Only the 2 I/O transposes for the conv, NO extra transposes from view_copy
    ops_after_pass: Dict[str, int] = {
        "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2
    }
    ops_not_after_pass: List[str] = []

    def __init__(self):
        super().__init__()
        # 1x1 conv with 2 channels; its lowering forces the NHWC path.
        self.conv = torch.nn.Conv2d(
            in_channels=2, out_channels=2, kernel_size=1, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)  # forces NHWC path; output [1, 2, 14, 72]
        x = x.view(1, 28, 1, 72)  # spatial merge: H*W=2*14->28, last dim 72 preserved
        return x + x  # keep result 4-D in NHWC

    def get_inputs(self) -> input_t:
        # Input sized to match the conv's 2 input channels.
        return (torch.randn(1, 2, 14, 72),)
213+
214+
215+
class NHWCUnsafeChannelChange(torch.nn.Module):
    """Test-module with a 4D->4D reshape that is NOT NHWC-safe because the
    target shape cannot be produced by a monotonic merge of NHWC input dims.
    The pass MUST still insert transposes around the view_copy.
    """

    # Expected TOSA op counts checked by the parametrized pass test.
    ops_before_pass: Dict[str, int] = {}
    # conv I/O transposes (2) + view_copy transposes (2) = 4
    ops_after_pass: Dict[str, int] = {
        "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4
    }
    ops_not_after_pass: List[str] = []

    def __init__(self):
        super().__init__()
        # 1x1 conv with 72 channels; its lowering forces the NHWC path.
        self.conv = torch.nn.Conv2d(
            in_channels=72, out_channels=72, kernel_size=1, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)  # output [1, 72, 2, 14]
        x = x.view(1, 14, 2, 72)  # not NHWC-safe (channels shuffled)
        return x + x

    def get_inputs(self) -> input_t:
        # Input sized to match the conv's 72 input channels.
        return (torch.randn(1, 72, 2, 14),)
241+
242+
180243
# Registry of fixture modules; the parametrized pass tests iterate over
# these ids, comparing each module's ops_before_pass / ops_after_pass.
modules: Dict[str, ModuleMetadata] = {
    "no_nhwc": NoNHWC(),
    "parallel_clusters": ParallelClusters(),
    "serial_clusters": SerialClusters(),
    "reshapes": Reshapes(),
    "nhwc_safe_spatial_merge": NHWCSafeSpatialMerge(),
    "nhwc_unsafe_channel_change": NHWCUnsafeChannelChange(),
}
186251

187252

@@ -209,3 +274,92 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> No
209274
module_nn = cast(torch.nn.Module, module)
210275
pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), [])
211276
pipeline.run()
277+
278+
279+
# --- Direct unit tests for NHWC-safe reshape helpers ---
280+
281+
282+
def test_get_shape_indices_spatial_merge():
    """[1,2,14,72] -> [1,28,1,72]: merge H*W, insert size-1 dim, preserve C."""
    result = ToTosaMemoryFormatPass._get_shape_indices(
        [1, 2, 14, 72], [1, 28, 1, 72]
    )
    expected = [[0], [1, 2], [], [3]]
    assert result == expected


def test_get_shape_indices_identity():
    """Same shape => each dim maps to itself."""
    result = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [2, 3, 4])
    assert result == [[0], [1], [2]]


def test_get_shape_indices_full_merge():
    """[2, 3, 4] -> [24]: merge all dims into one."""
    result = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [24])
    assert result == [[0, 1, 2]]


def test_get_shape_indices_incompatible():
    """Sizes that don't divide => None."""
    assert ToTosaMemoryFormatPass._get_shape_indices([2, 3, 5], [6, 4]) is None


def test_get_shape_indices_size_one_insert():
    """[6, 4] -> [6, 1, 4]: inserted size-1 dim in the middle."""
    result = ToTosaMemoryFormatPass._get_shape_indices([6, 4], [6, 1, 4])
    assert result is not None
    assert result == [[0], [], [1]]
313+
314+
315+
def test_is_monotonic_true():
    """Strictly ordered (possibly empty) groups are monotonic."""
    for groups in (
        [[0], [1, 2], [], [3]],
        [[0], [], [1], [2, 3]],
        [[], [0, 1, 2]],
    ):
        assert ToTosaMemoryFormatPass._is_monotonic(groups)


def test_is_monotonic_false():
    """Groups whose indices go backwards are not monotonic."""
    for groups in ([[1], [0]], [[0, 2], [1]]):
        assert not ToTosaMemoryFormatPass._is_monotonic(groups)
324+
325+
326+
def test_is_nhwc_safe_forward():
    """Shapes already NHWC by the time the pass runs.
    [1,2,14,72] -> [1,28,1,72], sr=2 -> NHWC-safe (spatial merge, C=72 preserved).
    """
    safe = ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
        [1, 2, 14, 72], [1, 28, 1, 72], input_sr=2, output_sr=2
    )
    assert safe


def test_is_nhwc_safe_non_4d():
    """Reshapes below rank 4 are never NHWC-safe."""
    safe = ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
        [6, 4], [24], input_sr=0, output_sr=0
    )
    assert not safe
340+
341+
342+
def test_is_nop_transpose_size1_swap():
    """[14, 72, 1, 1] with perm (0, 1, 3, 2) only swaps trailing size-1 dims."""
    shape, perm = [14, 72, 1, 1], (0, 1, 3, 2)
    assert ToTosaMemoryFormatPass._is_nop_transpose(shape, perm)


def test_is_nop_transpose_real_reorder():
    """[14, 72, 1, 1] with perm (1, 0, 2, 3) swaps non-size-1 dims."""
    shape, perm = [14, 72, 1, 1], (1, 0, 2, 3)
    assert not ToTosaMemoryFormatPass._is_nop_transpose(shape, perm)


def test_is_nop_transpose_all_size1():
    """[1, 1, 1, 1] with any perm is always a NOP."""
    shape, perm = [1, 1, 1, 1], (3, 2, 1, 0)
    assert ToTosaMemoryFormatPass._is_nop_transpose(shape, perm)


def test_is_nop_transpose_identity():
    """Identity permutation is always a NOP."""
    shape, perm = [2, 3, 4], (0, 1, 2)
    assert ToTosaMemoryFormatPass._is_nop_transpose(shape, perm)


def test_is_nop_transpose_nhwc_on_size1_spatial():
    """[1, 28, 1, 72] with channels_last (0,2,3,1): non-size-1 dims 28,72
    change relative order (28→pos3, 72→pos2) → NOT a NOP."""
    shape, perm = [1, 28, 1, 72], (0, 2, 3, 1)
    assert not ToTosaMemoryFormatPass._is_nop_transpose(shape, perm)

0 commit comments

Comments
 (0)