Skip to content

Commit 1c565e1

Browse files
mcremon-meta authored
and facebook-github-bot committed
Add NHWC version of max_pool2d (#18239)
Summary: As titled. Should perform better and also allow removing some permutes when convolutions are also moved to channel last. Reviewed By: hsharma35. Differential Revision: D96869747
1 parent 22feba7 commit 1c565e1

File tree

11 files changed

+398
-17
lines changed

11 files changed

+398
-17
lines changed

backends/cadence/aot/functions.yaml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,15 @@
309309
- arg_meta: null
310310
kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out
311311

312-
- func: cadence::quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
312+
- func: cadence::quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
313313
kernels:
314314
- arg_meta: null
315-
kernel_name: impl::generic::quantized_max_pool2d_out
315+
kernel_name: impl::generic::quantized_max_pool2d_nchw_out
316+
317+
- func: cadence::quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
318+
kernels:
319+
- arg_meta: null
320+
kernel_name: impl::generic::quantized_max_pool2d_nhwc_out
316321

317322
- func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
318323
kernels:

backends/cadence/aot/ops_registrations.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,16 @@ def register_fake(
214214
)
215215

216216
lib.define(
217-
"quantized_max_pool2d(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
217+
"quantized_max_pool2d_nchw(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
218218
)
219219
lib.define(
220-
"quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
220+
"quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
221+
)
222+
lib.define(
223+
"quantized_max_pool2d_nhwc(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
224+
)
225+
lib.define(
226+
"quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
221227
)
222228

223229
lib.define(
@@ -2277,8 +2283,8 @@ def quantized_relu_asym8u_asym8u_per_tensor_meta(
22772283
return input.new_empty(input.size(), dtype=input.dtype)
22782284

22792285

2280-
@register_fake("cadence::quantized_max_pool2d")
2281-
def quantized_max_pool2d_meta(
2286+
@register_fake("cadence::quantized_max_pool2d_nchw")
2287+
def quantized_max_pool2d_nchw_meta(
22822288
input: torch.Tensor,
22832289
kernel_size: list[int],
22842290
stride: list[int],
@@ -2318,6 +2324,47 @@ def quantized_max_pool2d_meta(
23182324
return input.new_empty([batch, channels, height_out, width_out], dtype=input.dtype)
23192325

23202326

2327+
@register_fake("cadence::quantized_max_pool2d_nhwc")
2328+
def quantized_max_pool2d_nhwc_meta(
2329+
input: torch.Tensor,
2330+
kernel_size: list[int],
2331+
stride: list[int],
2332+
padding: list[int],
2333+
dilation: list[int],
2334+
ceil_mode: bool,
2335+
) -> torch.Tensor:
2336+
assert (
2337+
len(kernel_size) == 2
2338+
), f"kernel_size must have 2 elements, got {len(kernel_size)}"
2339+
assert len(stride) == 2, f"stride must have 2 elements, got {len(stride)}"
2340+
assert len(padding) == 2, f"padding must have 2 elements, got {len(padding)}"
2341+
assert len(dilation) == 2, f"dilation must have 2 elements, got {len(dilation)}"
2342+
assert (
2343+
len(input.size()) == 4
2344+
), f"input must be 4D (N, H, W, C), got {len(input.size())}D"
2345+
2346+
batch = input.size(0)
2347+
height_in = input.size(1)
2348+
width_in = input.size(2)
2349+
channels = input.size(3)
2350+
2351+
height_out_raw = (
2352+
height_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
2353+
) / stride[0] + 1
2354+
width_out_raw = (
2355+
width_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
2356+
) / stride[1] + 1
2357+
2358+
if ceil_mode:
2359+
height_out = ceil(height_out_raw)
2360+
width_out = ceil(width_out_raw)
2361+
else:
2362+
height_out = int(height_out_raw)
2363+
width_out = int(width_out_raw)
2364+
2365+
return input.new_empty([batch, height_out, width_out, channels], dtype=input.dtype)
2366+
2367+
23212368
@register_fake("cadence::fully_connected")
23222369
def fully_connected_meta(
23232370
src: torch.Tensor,

backends/cadence/aot/quantizer/patterns.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def get_anchors(
459459
)
460460

461461
def replacement_op(self) -> OpOverload:
462-
return torch.ops.cadence.quantized_max_pool2d.default
462+
return torch.ops.cadence.quantized_max_pool2d_nchw.default
463463

464464

465465
class MaxPool2dWithoutIndicesPattern(QuantizationPattern):
@@ -498,10 +498,14 @@ def get_anchors(
498498
)
499499

500500
def replacement_op(self) -> OpOverload:
501-
return torch.ops.cadence.quantized_max_pool2d.default
501+
return torch.ops.cadence.quantized_max_pool2d_nchw.default
502502

503503

504+
# This is a base class for ReLU
505+
504506
# This is a base class for ReLU, since it can be used with two different aten ops
507+
508+
505509
class ReluBasePattern(QuantizationPattern):
506510
@abstractmethod
507511
def partition_types(self) -> List[OpOverload]:

backends/cadence/aot/ref_implementations.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,8 +1868,8 @@ def rms_norm(
18681868
return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)
18691869

18701870

1871-
@impl_tracked(m, "quantized_max_pool2d")
1872-
def quantized_max_pool2d(
1871+
@impl_tracked(m, "quantized_max_pool2d_nchw")
1872+
def quantized_max_pool2d_nchw(
18731873
input: torch.Tensor,
18741874
kernel_size: list[int],
18751875
stride: list[int],
@@ -1897,6 +1897,37 @@ def quantized_max_pool2d(
18971897
)
18981898

18991899

1900+
@impl_tracked(m, "quantized_max_pool2d_nhwc")
1901+
def quantized_max_pool2d_nhwc(
1902+
input: torch.Tensor,
1903+
kernel_size: list[int],
1904+
stride: list[int],
1905+
padding: list[int],
1906+
dilation: list[int],
1907+
ceil_mode: bool,
1908+
) -> torch.Tensor:
1909+
"""
1910+
Quantized max pooling in NHWC layout.
1911+
1912+
Converts NHWC→NCHW, performs max pooling, then converts back NCHW→NHWC.
1913+
"""
1914+
# Convert NHWC [N, H, W, C] to NCHW [N, C, H, W]
1915+
input_nchw = input.permute(0, 3, 1, 2).contiguous()
1916+
1917+
# Call the NCHW version
1918+
output_nchw = quantized_max_pool2d_nchw(
1919+
input_nchw,
1920+
kernel_size=kernel_size,
1921+
stride=stride,
1922+
padding=padding,
1923+
dilation=dilation,
1924+
ceil_mode=ceil_mode,
1925+
)
1926+
1927+
# Convert NCHW [N, C, H_out, W_out] back to NHWC [N, H_out, W_out, C]
1928+
return output_nchw.permute(0, 2, 3, 1).contiguous()
1929+
1930+
19001931
@impl_tracked(m, "where_Scalar")
19011932
def where_Scalar(
19021933
condition: torch.Tensor,

backends/cadence/aot/replace_ops.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,6 +1182,67 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
11821182
return True
11831183

11841184

1185+
@register_cadence_pass(CadencePassAttribute(opt_level=3))
1186+
class ReplaceMaxPool2dWithChannelLastMaxPool2dPass(RemoveOrReplacePassInterface):
1187+
"""
1188+
Replace NCHW max pooling with NHWC (channel-last) max pooling by adding
1189+
permute operations before and after the max pooling.
1190+
"""
1191+
1192+
@property
1193+
def targets(self) -> list[EdgeOpOverload]:
1194+
return [
1195+
exir_ops.edge.cadence.quantized_max_pool2d_nchw.default,
1196+
]
1197+
1198+
def _change_nchw_to_nhwc(
1199+
self, graph: torch.fx.Graph, node: torch.fx.Node
1200+
) -> torch.fx.Node:
1201+
"""Convert NCHW format to NHWC format."""
1202+
permute_node = graph.call_function(
1203+
exir_ops.edge.aten.permute_copy.default, (node, [0, 2, 3, 1]), {}
1204+
)
1205+
permute_node.meta = node.meta
1206+
return permute_node
1207+
1208+
def _change_nhwc_to_nchw(
1209+
self, graph: torch.fx.Graph, node: torch.fx.Node
1210+
) -> torch.fx.Node:
1211+
"""Convert NHWC format to NCHW format."""
1212+
permute_node = graph.call_function(
1213+
exir_ops.edge.aten.permute_copy.default, (node, [0, 3, 1, 2]), {}
1214+
)
1215+
permute_node.meta = node.meta
1216+
return permute_node
1217+
1218+
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
1219+
graph = node.graph
1220+
1221+
# Get input node
1222+
input_node = cast(torch.fx.Node, node.args[0])
1223+
1224+
with graph.inserting_before(node):
1225+
# Convert input from NCHW to NHWC
1226+
input_nhwc = self._change_nchw_to_nhwc(graph, input_node)
1227+
1228+
# Create the NHWC max pooling with the same args (kernel_size, stride, padding, dilation, ceil_mode)
1229+
new_args = (input_nhwc,) + tuple(node.args[1:])
1230+
1231+
new_pool = graph.call_function(
1232+
exir_ops.edge.cadence.quantized_max_pool2d_nhwc.default,
1233+
new_args,
1234+
node.kwargs,
1235+
)
1236+
new_pool.meta = node.meta
1237+
1238+
# Convert output back from NHWC to NCHW
1239+
nchw_output = self._change_nhwc_to_nchw(graph, new_pool)
1240+
1241+
# Replace all uses with the final output
1242+
node.replace_all_uses_with(nchw_output)
1243+
return True
1244+
1245+
11851246
@register_cadence_pass(CadencePassAttribute(opt_level=3))
11861247
class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface):
11871248
"""
@@ -2561,6 +2622,7 @@ class CadenceReplaceOpsInGraph:
25612622
ReplacePadWithCatPass,
25622623
ReplaceConstantPadNdWithSlicePass,
25632624
ReplaceConvWithChannelLastConvPass,
2625+
ReplaceMaxPool2dWithChannelLastMaxPool2dPass,
25642626
ReplaceTrivialConvWithLinear,
25652627
ReplaceConvWithIm2RowAndLinear,
25662628
ReplaceTransposedConvWithLinearPass,

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
ReplaceLinearWithFullyConnectedOpPass,
3737
ReplaceLogicalNotBooleanWhereWithWherePass,
3838
ReplaceMatmulWithTransposedMatmulPass,
39+
ReplaceMaxPool2dWithChannelLastMaxPool2dPass,
3940
ReplaceMMWithAddMMPass,
4041
ReplaceMulTensorWithMulAndFullOpsPass,
4142
ReplaceNopTransposeOrPermuteWithViewPass,
@@ -2586,6 +2587,59 @@ def test_cat_insert_transpose(self) -> None:
25862587
)
25872588

25882589

2590+
class TestReplaceMaxPool2dWithChannelLastMaxPool2dPass(unittest.TestCase):
2591+
def test_replace_max_pool2d_nchw_with_nhwc(self) -> None:
2592+
# Create a graph with a single quantized_max_pool2d_nchw node.
2593+
x = torch.randint(0, 100, (1, 3, 8, 8), dtype=torch.int8)
2594+
gm = single_op_builder(
2595+
placeholders=(x,),
2596+
op=exir_ops.edge.cadence.quantized_max_pool2d_nchw.default,
2597+
args=(x, [2, 2], [2, 2], [0, 0], [1, 1], False),
2598+
)
2599+
self.assertEqual(
2600+
count_node(gm, exir_ops.edge.cadence.quantized_max_pool2d_nchw.default), 1
2601+
)
2602+
self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
2603+
2604+
# Deepcopy before the pass
2605+
original = copy.deepcopy(gm)
2606+
2607+
# Apply replacement pass.
2608+
p = ReplaceMaxPool2dWithChannelLastMaxPool2dPass()
2609+
result = p.call(gm)
2610+
self.assertTrue(result.modified)
2611+
gm_after_replacement = result.graph_module
2612+
2613+
# Check that replacement was made.
2614+
self.assertEqual(
2615+
count_node(
2616+
gm_after_replacement,
2617+
exir_ops.edge.cadence.quantized_max_pool2d_nhwc.default,
2618+
),
2619+
1,
2620+
)
2621+
self.assertEqual(
2622+
count_node(
2623+
gm_after_replacement,
2624+
exir_ops.edge.cadence.quantized_max_pool2d_nchw.default,
2625+
),
2626+
0,
2627+
)
2628+
# Two permutes: one for input NCHW->NHWC, one for output NHWC->NCHW
2629+
self.assertEqual(
2630+
count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default),
2631+
2,
2632+
)
2633+
2634+
# Validate numerical accuracy
2635+
validate(
2636+
original,
2637+
gm_after_replacement,
2638+
(x,),
2639+
"ReplaceMaxPool2dWithChannelLastMaxPool2dPass",
2640+
)
2641+
2642+
25892643
class TestReplaceEmptyTensorsWithFullPass(unittest.TestCase):
25902644
def _get_slice_empty_gm(self) -> tuple[torch.fx.GraphModule, torch.Tensor]:
25912645
builder = GraphBuilder()

backends/cadence/generic/operators/op_quantized_max_pool2d.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ using ::executorch::runtime::KernelRuntimeContext;
2727
namespace {
2828

2929
template <typename T>
30-
void quantized_max_pool2d_impl(
30+
void quantized_max_pool2d_nchw_impl(
3131
const Tensor& input,
3232
IntArrayRef kernel_size,
3333
IntArrayRef stride,
@@ -98,7 +98,7 @@ void quantized_max_pool2d_impl(
9898

9999
} // namespace
100100

101-
Tensor& quantized_max_pool2d_out(
101+
Tensor& quantized_max_pool2d_nchw_out(
102102
ET_UNUSED KernelRuntimeContext& ctx,
103103
const Tensor& input,
104104
IntArrayRef kernel_size,
@@ -107,24 +107,24 @@ Tensor& quantized_max_pool2d_out(
107107
IntArrayRef dilation,
108108
bool ceil_mode,
109109
Tensor& output) {
110-
#define typed_quantized_max_pool2d(ctype, dtype) \
110+
#define typed_quantized_max_pool2d_nchw(ctype, dtype) \
111111
case ScalarType::dtype: { \
112-
quantized_max_pool2d_impl<ctype>( \
112+
quantized_max_pool2d_nchw_impl<ctype>( \
113113
input, kernel_size, stride, padding, dilation, ceil_mode, output); \
114114
break; \
115115
}
116116

117117
ScalarType dtype = input.scalar_type();
118118
// NOLINTBEGIN(clang-diagnostic-switch-enum)
119119
switch (dtype) {
120-
ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d)
120+
ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d_nchw)
121121
default:
122122
ET_DCHECK_MSG(
123123
false, "Unhandled dtype %s", torch::executor::toString(dtype));
124124
}
125125
// NOLINTEND(clang-diagnostic-switch-enum)
126126

127-
#undef typed_quantized_max_pool2d
127+
#undef typed_quantized_max_pool2d_nchw
128128
return output;
129129
}
130130

backends/cadence/generic/operators/op_quantized_max_pool2d.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace impl {
1515
namespace generic {
1616
namespace native {
1717

18-
::executorch::aten::Tensor& quantized_max_pool2d_out(
18+
::executorch::aten::Tensor& quantized_max_pool2d_nchw_out(
1919
::executorch::runtime::KernelRuntimeContext& ctx,
2020
const ::executorch::aten::Tensor& input,
2121
::executorch::aten::IntArrayRef kernel_size,

0 commit comments

Comments
 (0)