
Commit 3492e53

[Bug Fix] Fix padding when running in NHWC

1 parent 65ebabb commit 3492e53

File tree

2 files changed (+63 −1)
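For context before the diffs: torch.nn.functional.pad takes its padding as (pre, post) pairs starting from the last dimension, so the reversed slices in the operator below recover per-dimension paddings in NCHW order. Once a node is tagged channels-last, those paddings must be permuted to NHWC before XNNPACK sees them. A minimal sketch of that arithmetic in plain PyTorch — the inline PERM_NCHW_TO_NHWC = [0, 2, 3, 1] is an assumption standing in for the constant the diff imports from executorch's utils:

import torch

# torch.nn.functional.pad takes (pre, post) pairs starting from the LAST
# dimension: for a 4D NCHW tensor, pad_6 = (w_pre, w_post, h_pre, h_post,
# c_pre, c_post) pads W, H and C, leaving N untouched.
x = torch.randn(1, 2, 3, 4)  # NCHW
pad_6 = (1, 2, 3, 4, 5, 6)
y = torch.nn.functional.pad(x, pad_6, mode="constant", value=3.1)
assert y.shape == (1, 13, 10, 7)  # N, C + 11, H + 7, W + 3

# Extending to all dims and reversing the pairs, as the operator does,
# yields per-dimension paddings in NCHW order:
all_paddings = list(pad_6) + [0, 0]  # N gets (0, 0)
pre_paddings = all_paddings[-2::-2]  # [0, 5, 3, 1] -> N, C, H, W
post_paddings = all_paddings[::-2]   # [0, 6, 4, 2]

# If the node runs in channels-last layout, the same paddings have to be
# reordered to N, H, W, C first. Assumed value of the imported constant:
PERM_NCHW_TO_NHWC = [0, 2, 3, 1]
print([pre_paddings[i] for i in PERM_NCHW_TO_NHWC])   # [0, 3, 1, 5]
print([post_paddings[i] for i in PERM_NCHW_TO_NHWC])  # [0, 4, 2, 6]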

backends/xnnpack/operators/op_static_constant_pad.py (+18 −1)

@@ -7,6 +7,10 @@
 from typing import cast, Dict, List

 import torch
+
+from backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
+    ChannelsLastTaggedReshapePass,
+)
 from executorch.backends.xnnpack.operators.node_visitor import (
     get_tensor_value,
     NodeVisitor,
@@ -17,7 +21,11 @@
     XNNStaticConstantPad,
     XNode,
 )
-from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node
+from executorch.backends.xnnpack.utils.utils import (
+    check_or_raise,
+    get_input_node,
+    PERM_NCHW_TO_NHWC,
+)


 @register_node_visitor
@@ -113,8 +121,17 @@ def define_node(
         # b)
         # tuple[0] = prepadding dim[-1]
         # tuple[1] = postpadding dim[-1]
+        is_channels_last = node.meta.get(
+            ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
+        )
         pre_paddings = all_paddings[-2::-2]  # even index elements in reverse order
         post_paddings = all_paddings[::-2]  # odd index elements in reverse order
+        if is_channels_last:
+            check_or_raise(len(pre_paddings) == 4, "Expecting prepaddings to be 4D")
+            check_or_raise(len(post_paddings) == 4, "Expecting postpaddings to be 4D")
+
+            pre_paddings = [pre_paddings[i] for i in PERM_NCHW_TO_NHWC]
+            post_paddings = [post_paddings[i] for i in PERM_NCHW_TO_NHWC]

         # the padding value, which defaults to 0.0
         padding_value = cast(float, node.args[2]) if len(node.args) > 2 else 0.0
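Why the permutation matters: if the tensor is physically NHWC but the per-dimension paddings stay in NCHW order, the wrong axes get padded. A hedged numeric sketch of the failure mode this commit fixes — shapes only, plain PyTorch, nothing here is executorch API:

import torch

x = torch.randn(1, 2, 3, 4)                  # logical NCHW
x_nhwc = x.permute(0, 2, 3, 1).contiguous()  # physical NHWC: (1, 3, 4, 2)

pre_nchw, post_nchw = [0, 5, 3, 1], [0, 6, 4, 2]  # N, C, H, W order

# Applying NCHW-ordered paddings to the NHWC buffer pads the wrong axes:
wrong = [d + a + b for d, a, b in zip(x_nhwc.shape, pre_nchw, post_nchw)]
print(wrong)  # [1, 14, 11, 5] -- not a layout permutation of [1, 13, 10, 7]

# Permuting the paddings first (what the fix does) gives the right shape:
perm = [0, 2, 3, 1]  # assumed value of PERM_NCHW_TO_NHWC
pre = [pre_nchw[i] for i in perm]    # [0, 3, 1, 5]
post = [post_nchw[i] for i in perm]  # [0, 4, 2, 6]
right = [d + a + b for d, a, b in zip(x_nhwc.shape, pre, post)]
print(right)  # [1, 10, 7, 13] == the NHWC view of NCHW [1, 13, 10, 7]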

backends/xnnpack/test/ops/test_static_constant_pad.py (+45 −0)

@@ -14,6 +14,30 @@ class TestStaticConstantPad(unittest.TestCase):
     def setUp(self):
         torch._dynamo.reset()

+    class NHWCStaticConstantPad(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=1)
+            self.conv2 = torch.nn.Conv2d(in_channels=13, out_channels=13, kernel_size=1)
+
+        def forward(self, x):
+            a = self.conv1(x)
+            pad_6 = (1, 2, 3, 4, 5, 6)
+            a = torch.nn.functional.pad(
+                input=a,
+                pad=pad_6,
+                mode="constant",
+                value=3.1,
+            )
+            # tensor shape = [1, 13, 10, 7]
+            a = self.conv2(a)
+
+            return a
+
+        def sample_inputs(self):
+            # NCHW
+            return (torch.randn(1, 2, 3, 4),)
+
     class StaticConstantPadFunctional(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -205,3 +229,24 @@ def test_qs8_static_constant_pad_2d(self):
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    def test_fp32_static_constant_pad_nhwc(self):
+        conv = self.NHWCStaticConstantPad()
+        inputs = conv.sample_inputs()
+        (
+            Tester(conv, inputs)
+            .export()
+            .check_count({"torch.ops.aten.pad.default": 1})
+            .dump_artifact()
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(
+                [
+                    "executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default",
+                    "executorch_exir_dialects_edge__ops_aten_convolution_default",
+                ]
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
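As a quick eager-mode sanity check of the [1, 13, 10, 7] shape comment in the test module above, the model can be run standalone outside the Tester harness — a sketch re-declaring the module from the diff:

import torch

# Re-declaring the test module for a standalone run outside the Tester.
class NHWCStaticConstantPad(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=1)
        self.conv2 = torch.nn.Conv2d(in_channels=13, out_channels=13, kernel_size=1)

    def forward(self, x):
        a = self.conv1(x)
        a = torch.nn.functional.pad(a, (1, 2, 3, 4, 5, 6), mode="constant", value=3.1)
        return self.conv2(a)

out = NHWCStaticConstantPad()(torch.randn(1, 2, 3, 4))
print(out.shape)  # torch.Size([1, 13, 10, 7])

run_method_and_compare_outputs then checks the delegated program against this eager reference.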
