From d013b83083febbae431e8533d857e83cab5fac89 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Thu, 27 Mar 2025 16:18:13 -0700
Subject: [PATCH] [Bug Fix] Fix padding when running in NHWC

---
 .../operators/op_static_constant_pad.py      | 14 +++++-
 .../test/ops/test_static_constant_pad.py     | 45 +++++++++++++++++++
 2 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/backends/xnnpack/operators/op_static_constant_pad.py b/backends/xnnpack/operators/op_static_constant_pad.py
index 3381227a885..c14db8192a2 100644
--- a/backends/xnnpack/operators/op_static_constant_pad.py
+++ b/backends/xnnpack/operators/op_static_constant_pad.py
@@ -7,6 +7,7 @@
 from typing import cast, Dict, List
 
 import torch
+
 from executorch.backends.xnnpack.operators.node_visitor import (
     get_tensor_value,
     NodeVisitor,
@@ -17,7 +18,11 @@
     XNNStaticConstantPad,
     XNode,
 )
-from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node
+from executorch.backends.xnnpack.utils.utils import (
+    check_or_raise,
+    get_input_node,
+    PERM_NCHW_TO_NHWC,
+)
 
 
 @register_node_visitor
@@ -113,8 +118,15 @@ def define_node(
         #   b)
         #      tuple[0] = prepadding dim[-1]
         #      tuple[1] = postpadding dim[-1]
+        is_channels_last = node.meta.get("XNN_NHWC_NODE", False)
         pre_paddings = all_paddings[-2::-2]  # even index elements in reverse order
         post_paddings = all_paddings[::-2]  # odd index elements in reverse order
+        if is_channels_last:
+            check_or_raise(len(pre_paddings) == 4, "Expecting prepaddings to be 4D")
+            check_or_raise(len(post_paddings) == 4, "Expecting postpaddings to be 4D")
+
+            pre_paddings = [pre_paddings[i] for i in PERM_NCHW_TO_NHWC]
+            post_paddings = [post_paddings[i] for i in PERM_NCHW_TO_NHWC]
 
         # the padding value, which defaults to 0.0
         padding_value = cast(float, node.args[2]) if len(node.args) > 2 else 0.0
diff --git a/backends/xnnpack/test/ops/test_static_constant_pad.py b/backends/xnnpack/test/ops/test_static_constant_pad.py
index c5d103f596a..9613308f6a6 100644
--- a/backends/xnnpack/test/ops/test_static_constant_pad.py
+++ b/backends/xnnpack/test/ops/test_static_constant_pad.py
@@ -14,6 +14,30 @@ class TestStaticConstantPad(unittest.TestCase):
     def setUp(self):
         torch._dynamo.reset()
 
+    class NHWCStaticConstantPad(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=1)
+            self.conv2 = torch.nn.Conv2d(in_channels=13, out_channels=13, kernel_size=1)
+
+        def forward(self, x):
+            a = self.conv1(x)
+            pad_6 = (1, 2, 3, 4, 5, 6)
+            a = torch.nn.functional.pad(
+                input=a,
+                pad=pad_6,
+                mode="constant",
+                value=3.1,
+            )
+            # tensorshape = [1, 13, 10, 7]
+            a = self.conv2(a)
+
+            return a
+
+        def sample_inputs(self):
+            # NCHW
+            return (torch.randn(1, 2, 3, 4),)
+
     class StaticConstantPadFunctional(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -205,3 +229,24 @@ def test_qs8_static_constant_pad_2d(self):
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    def test_fp32_static_constant_pad_nhwc(self):
+        conv = self.NHWCStaticConstantPad()
+        inputs = conv.sample_inputs()
+        (
+            Tester(conv, inputs)
+            .export()
+            .check_count({"torch.ops.aten.pad.default": 1})
+            .dump_artifact()
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(
+                [
+                    "executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default",
+                    "executorch_exir_dialects_edge__ops_aten_convolution_default",
+                ]
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
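
--
Note (not part of the patch): torch.nn.functional.pad takes a flat tuple ordered
(last-dim pre, last-dim post, ..., first-dim pre, first-dim post), and dimensions it
leaves implicit are unpadded. After zero-extending that tuple to cover every
dimension, the two strided reversals above recover per-dimension pre-/post-paddings
in NCHW order; when the node has been tagged channels-last via XNN_NHWC_NODE, those
lists must additionally be permuted into NHWC order, since XNNPACK will see the
tensor's dimensions as N, H, W, C. A minimal standalone sketch of that reorder,
assuming PERM_NCHW_TO_NHWC = [0, 2, 3, 1] (the usual NCHW-to-NHWC dimension
permutation; the actual constant lives in backends/xnnpack/utils/utils.py):

    # Sketch only, not the ExecuTorch source: reproduces the reorder the
    # patch performs, assuming PERM_NCHW_TO_NHWC = [0, 2, 3, 1].
    PERM_NCHW_TO_NHWC = [0, 2, 3, 1]

    def per_dim_paddings(pad, rank, channels_last=False):
        # F.pad's flat tuple runs from the last dimension backwards, so
        # zero-extend it to 2 * rank entries, then reverse with stride 2.
        all_paddings = list(pad) + [0] * (2 * rank - len(pad))
        pre = all_paddings[-2::-2]  # even index elements in reverse order
        post = all_paddings[::-2]   # odd index elements in reverse order
        if channels_last:
            # XNNPACK sees NHWC dimensions, so reorder the NCHW paddings.
            pre = [pre[i] for i in PERM_NCHW_TO_NHWC]
            post = [post[i] for i in PERM_NCHW_TO_NHWC]
        return pre, post

    # pad_6 = (1, 2, 3, 4, 5, 6) from the test pads W by (1, 2), H by
    # (3, 4), and C by (5, 6): a (1, 2, 3, 4) input becomes (1, 13, 10, 7).
    pre, post = per_dim_paddings((1, 2, 3, 4, 5, 6), rank=4, channels_last=True)
    assert pre == [0, 3, 1, 5]   # N, H, W, C pre-paddings
    assert post == [0, 4, 2, 6]  # N, H, W, C post-paddings

Without the permutation, the channel paddings computed from the NCHW pad tuple land
on the height dimension once the surrounding convolutions force the delegated graph
into NHWC, which is the mismatch the new test_fp32_static_constant_pad_nhwc case
exercises.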