diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 528ceadaf19..1c4dd3e06f3 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -309,6 +309,16 @@ - arg_meta: null kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out +- func: cadence::quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::generic::quantized_max_pool2d_nchw_out + +- func: cadence::quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::generic::quantized_max_pool2d_nhwc_out + - func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index cbc179e05d2..060702becec 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -6,7 +6,7 @@ # pyre-strict -from math import prod +from math import ceil, prod from typing import Callable, Optional, Tuple import torch @@ -213,6 +213,19 @@ def register_fake( "quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor (a!)" ) +lib.define( + "quantized_max_pool2d_nchw(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor" +) +lib.define( + "quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) 
def _quantized_max_pool2d_out_dim(
    size_in: int,
    kernel: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool,
) -> int:
    """
    Return the pooled extent of a single spatial dimension.

    The computation is done purely with integer arithmetic (no float division)
    and follows the output-size rule documented for ``torch.nn.MaxPool2d``,
    including the ceil_mode correction: a window that would start past the
    input plus its left padding is dropped.
    """
    span = size_in + 2 * padding - dilation * (kernel - 1) - 1
    if not ceil_mode:
        return span // stride + 1

    # Integer ceil-division of span by stride, then add the first window.
    size_out = (span + stride - 1) // stride + 1
    # In ceil mode the last window must still start inside the input or the
    # left padding; otherwise it would cover right padding only.
    if (size_out - 1) * stride >= size_in + padding:
        size_out -= 1
    return size_out


def _check_quantized_max_pool2d_args(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    layout: str,
) -> None:
    """Assert that the pooling parameters are 2-element lists and the input is 4D."""
    assert (
        len(kernel_size) == 2
    ), f"kernel_size must have 2 elements, got {len(kernel_size)}"
    assert len(stride) == 2, f"stride must have 2 elements, got {len(stride)}"
    assert len(padding) == 2, f"padding must have 2 elements, got {len(padding)}"
    assert len(dilation) == 2, f"dilation must have 2 elements, got {len(dilation)}"
    assert (
        len(input.size()) == 4
    ), f"input must be 4D ({layout}), got {len(input.size())}D"


@register_fake("cadence::quantized_max_pool2d_nchw")
def quantized_max_pool2d_nchw_meta(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """
    Fake (shape-only) kernel for ``cadence::quantized_max_pool2d_nchw``.

    Args:
        input: 4D tensor read as (N, C, H, W).
        kernel_size, stride, padding, dilation: 2-element [height, width] lists.
        ceil_mode: round the output extent up instead of down.

    Returns:
        An empty tensor of shape (N, C, H_out, W_out) with the input's dtype.
    """
    _check_quantized_max_pool2d_args(
        input, kernel_size, stride, padding, dilation, "N, C, H, W"
    )

    batch = input.size(0)
    channels = input.size(1)
    height_out = _quantized_max_pool2d_out_dim(
        input.size(2), kernel_size[0], stride[0], padding[0], dilation[0], ceil_mode
    )
    width_out = _quantized_max_pool2d_out_dim(
        input.size(3), kernel_size[1], stride[1], padding[1], dilation[1], ceil_mode
    )

    return input.new_empty([batch, channels, height_out, width_out], dtype=input.dtype)


@register_fake("cadence::quantized_max_pool2d_nhwc")
def quantized_max_pool2d_nhwc_meta(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """
    Fake (shape-only) kernel for ``cadence::quantized_max_pool2d_nhwc``.

    Args:
        input: 4D tensor read as (N, H, W, C).
        kernel_size, stride, padding, dilation: 2-element [height, width] lists.
        ceil_mode: round the output extent up instead of down.

    Returns:
        An empty tensor of shape (N, H_out, W_out, C) with the input's dtype.
    """
    _check_quantized_max_pool2d_args(
        input, kernel_size, stride, padding, dilation, "N, H, W, C"
    )

    batch = input.size(0)
    channels = input.size(3)
    height_out = _quantized_max_pool2d_out_dim(
        input.size(1), kernel_size[0], stride[0], padding[0], dilation[0], ceil_mode
    )
    width_out = _quantized_max_pool2d_out_dim(
        input.size(2), kernel_size[1], stride[1], padding[1], dilation[1], ceil_mode
    )

    return input.new_empty([batch, height_out, width_out, channels], dtype=input.dtype)
def get_args_and_kwargs_max_pool2d(
    inputs_inputs: List[fx.Node],
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Build the call arguments for the quantized max_pool2d replacement op.

    The pooling parameters are read from ``op_node``; any parameter that is
    missing or falsy falls back to a default (stride falls back to the kernel
    size). Because max is order-preserving, the replacement op consumes the
    quantized input directly and needs no quantization parameters.

    Returns:
        ``(args, kwargs)`` where ``args`` holds the first entry of
        ``inputs_inputs`` and ``kwargs`` holds the pooling parameters.
    """

    def lookup(name: str, expected_type: Any) -> Any:
        # Thin wrapper so each parameter is fetched the same way.
        return get_arg(op_node, name, expected_type)

    window = lookup("kernel_size", Optional[list[int]]) or [1, 1]
    step = lookup("stride", Optional[list[int]]) or window
    pad = lookup("padding", Optional[list[int]]) or [0, 0]
    spacing = lookup("dilation", Optional[list[int]]) or [1, 1]
    round_up = lookup("ceil_mode", Optional[bool]) or False

    return (inputs_inputs[0],), {
        "kernel_size": window,
        "stride": step,
        "padding": pad,
        "dilation": spacing,
        "ceil_mode": round_up,
    }
class MaxPool2dPattern(QuantizationPattern):
    """
    Quantization pattern for ``aten.max_pool2d_with_indices``.

    Taking a maximum preserves ordering, so when input and output use the same
    scale/zero_point, pooling the quantized values yields the same result as
    quantizing the pooled float values. The output is therefore annotated to
    share the input's quantization parameters and no requantization is needed.
    """

    def partition_types(self) -> List[OpOverload]:
        """Return the aten overload this pattern matches."""
        return [torch.ops.aten.max_pool2d_with_indices.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """Return the anchors for, and the last node of, the first fused partition."""
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        pool_node = fused_partition[0].nodes[-1]

        # Every positional argument after the input (kernel_size, stride,
        # padding, dilation, ceil_mode) is treated as a literal.
        literal_anchors = [(pool_node, idx) for idx in range(1, len(pool_node.args))]
        # The output reuses the quantization parameters of the input edge.
        shared_qspec = SharedQuantizationSpec((pool_node.args[0], pool_node))

        anchors = PartitionAnchors(
            inputs=[(pool_node, 0)],
            weights=[],
            biases=[],
            literals=literal_anchors,
            output=[(pool_node, shared_qspec)],
        )
        return anchors, pool_node

    def replacement_op(self) -> OpOverload:
        """Return the Cadence op that replaces the matched partition."""
        return torch.ops.cadence.quantized_max_pool2d_nchw.default
@impl_tracked(m, "quantized_max_pool2d_nchw")
def quantized_max_pool2d_nchw(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """
    Reference implementation of quantized max pooling.

    Max is order-preserving: with a shared scale/zero_point, the maximum of the
    quantized values equals the quantized maximum of the dequantized values.
    The pooling is therefore applied to the integer values as they are, with
    no dequantize/requantize round trip.
    """
    pooling_options = {
        "kernel_size": kernel_size,
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "ceil_mode": ceil_mode,
    }
    # Pool the raw quantized values directly.
    return F.max_pool2d(input, **pooling_options)
@register_cadence_pass(CadencePassAttribute(opt_level=3))
class ReplaceMaxPool2dWithChannelLastMaxPool2dPass(RemoveOrReplacePassInterface):
    """
    Rewrite ``quantized_max_pool2d_nchw`` as ``quantized_max_pool2d_nhwc``.

    The input is permuted to channel-last before the pooling op and the pooled
    result is permuted back, so consumers of the original node still see the
    NCHW ordering.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        """Ops this pass rewrites."""
        return [exir_ops.edge.cadence.quantized_max_pool2d_nchw.default]

    def _permute_like(
        self, graph: torch.fx.Graph, node: torch.fx.Node, dims: list[int]
    ) -> torch.fx.Node:
        """Insert a permute_copy of ``node`` that reuses ``node``'s meta."""
        permuted = graph.call_function(
            exir_ops.edge.aten.permute_copy.default, (node, dims), {}
        )
        permuted.meta = node.meta
        return permuted

    def _change_nchw_to_nhwc(
        self, graph: torch.fx.Graph, node: torch.fx.Node
    ) -> torch.fx.Node:
        """Convert NCHW format to NHWC format."""
        return self._permute_like(graph, node, [0, 2, 3, 1])

    def _change_nhwc_to_nchw(
        self, graph: torch.fx.Graph, node: torch.fx.Node
    ) -> torch.fx.Node:
        """Convert NHWC format to NCHW format."""
        return self._permute_like(graph, node, [0, 3, 1, 2])

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """Replace ``node`` with permute -> NHWC pool -> permute; always returns True."""
        graph = node.graph
        source = cast(torch.fx.Node, node.args[0])

        with graph.inserting_before(node):
            # NCHW -> NHWC on the way in.
            channel_last_input = self._change_nchw_to_nhwc(graph, source)

            # Same pooling parameters (positional and keyword) as the original
            # node, only the input is swapped.
            channel_last_pool = graph.call_function(
                exir_ops.edge.cadence.quantized_max_pool2d_nhwc.default,
                (channel_last_input,) + tuple(node.args[1:]),
                node.kwargs,
            )
            channel_last_pool.meta = node.meta

            # NHWC -> NCHW on the way out.
            restored = self._change_nhwc_to_nchw(graph, channel_last_pool)

        # Redirect every consumer of the original node to the new output.
        node.replace_all_uses_with(restored)
        return True
def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
    """Build a one-op graph around max_pool2d_with_indices and return (graph_module, pool_node)."""
    pool_op = torch.ops.aten.max_pool2d_with_indices.default

    builder = GraphBuilder()
    # Input shape: (batch, channels, height, width)
    x = builder.placeholder("x", torch.randn(1, 3, 8, 8))
    # Positional args: (input, kernel_size, stride, padding, dilation, ceil_mode)
    pooled = builder.call_operator(
        op=pool_op,
        args=(x, [2, 2], [2, 2], [0, 0], [1, 1], False),
        meta=NodeMetadata(
            {"source_fn_stack": [("max_pool2d_with_indices", pool_op)]}
        ),
    )
    builder.output([pooled])
    gm = builder.get_graph_module()

    matches = gm.graph.find_nodes(op="call_function", target=pool_op)
    self.assertEqual(
        len(matches),
        1,
        "Should find exactly one max_pool2d_with_indices node",
    )
    return gm, matches[0]
+ """ + builder = GraphBuilder() + # Input shape: (batch, in_channels, length) + x = builder.placeholder("x", torch.randn(1, 3, 10)) + # Weight shape: (out_channels, in_channels, kernel_size) + weight = builder.placeholder("weight", torch.randn(6, 3, 3)) + conv1d = builder.call_operator( + op=torch.ops.aten.conv1d.default, + args=(x, weight), + meta=NodeMetadata( + {"source_fn_stack": [("conv1d", torch.ops.aten.conv1d.default)]} + ), + ) + relu = builder.call_operator( + op=torch.ops.aten.relu.default, + args=(conv1d,), + meta=NodeMetadata( + {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]} + ), + ) + builder.output([relu]) + gm = builder.get_graph_module() + + relu_nodes = gm.graph.find_nodes( + op="call_function", + target=torch.ops.aten.relu.default, + ) + self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node") + + conv1d_nodes = gm.graph.find_nodes( + op="call_function", + target=torch.ops.aten.conv1d.default, + ) + self.assertEqual(len(conv1d_nodes), 1, "Should find exactly one conv1d node") + + return gm, relu_nodes[0], conv1d_nodes[0] + @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES) def test_quantizer_annotation( self, @@ -532,7 +632,13 @@ def test_quantizer_annotation( # Verify output annotation (always on the output node) output_annotation: QuantizationAnnotation = output_node.meta[Q_ANNOTATION_KEY] self.assertTrue(output_annotation._annotated) - self.assertEqual(output_annotation.output_qspec, expected_output_qspec) + if isinstance(expected_output_qspec, type) and issubclass( + expected_output_qspec, SharedQuantizationSpec + ): + # For order-preserving ops like max_pool2d, verify output uses SharedQuantizationSpec + self.assertIsInstance(output_annotation.output_qspec, expected_output_qspec) + else: + self.assertEqual(output_annotation.output_qspec, expected_output_qspec) # Verify input annotations (on the input source node, which may differ for fused patterns) input_annotation: QuantizationAnnotation = 
class TestReplaceMaxPool2dWithChannelLastMaxPool2dPass(unittest.TestCase):
    def test_replace_max_pool2d_nchw_with_nhwc(self) -> None:
        """The pass swaps the NCHW pool for an NHWC pool wrapped in two permutes."""
        nchw_op = exir_ops.edge.cadence.quantized_max_pool2d_nchw.default
        nhwc_op = exir_ops.edge.cadence.quantized_max_pool2d_nhwc.default
        permute_op = exir_ops.edge.aten.permute_copy.default

        # Graph with a single quantized_max_pool2d_nchw node.
        x = torch.randint(0, 100, (1, 3, 8, 8), dtype=torch.int8)
        gm = single_op_builder(
            placeholders=(x,),
            op=nchw_op,
            args=(x, [2, 2], [2, 2], [0, 0], [1, 1], False),
        )
        self.assertEqual(count_node(gm, nchw_op), 1)
        self.assertEqual(count_node(gm, permute_op), 0)

        # Keep an untouched copy to compare against after the rewrite.
        original = copy.deepcopy(gm)

        result = ReplaceMaxPool2dWithChannelLastMaxPool2dPass().call(gm)
        self.assertTrue(result.modified)
        rewritten = result.graph_module

        # One NHWC pool, no NCHW pool, and two permutes
        # (input NCHW->NHWC, output NHWC->NCHW).
        for op, expected_count in ((nhwc_op, 1), (nchw_op, 0), (permute_op, 2)):
            self.assertEqual(count_node(rewritten, op), expected_count)

        # Validate numerical accuracy
        validate(
            original,
            rewritten,
            (x,),
            "ReplaceMaxPool2dWithChannelLastMaxPool2dPass",
        )
+ */ + +#include + +#include +#include +#include + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +namespace { + +template +void quantized_max_pool2d_nchw_impl( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + ET_UNUSED bool ceil_mode, + Tensor& output) { + const T* __restrict__ in_data = input.const_data_ptr(); + T* __restrict__ out_data = output.mutable_data_ptr(); + + // Input dimensions: [N, C, H, W] + const int64_t batch_size = input.size(0); + const int64_t channels = input.size(1); + const int64_t in_height = input.size(2); + const int64_t in_width = input.size(3); + + // Output dimensions + const int64_t out_height = output.size(2); + const int64_t out_width = output.size(3); + + // Pooling parameters + const int64_t kernel_h = kernel_size[0]; + const int64_t kernel_w = kernel_size[1]; + const int64_t stride_h = stride[0]; + const int64_t stride_w = stride[1]; + const int64_t pad_h = padding[0]; + const int64_t pad_w = padding[1]; + const int64_t dilation_h = dilation[0]; + const int64_t dilation_w = dilation[1]; + + // Iterate over batch and channels + for (int64_t n = 0; n < batch_size; ++n) { + for (int64_t c = 0; c < channels; ++c) { + // Iterate over output spatial dimensions + for (int64_t oh = 0; oh < out_height; ++oh) { + for (int64_t ow = 0; ow < out_width; ++ow) { + // Compute the input region for this output pixel + const int64_t ih_start = oh * stride_h - pad_h; + const int64_t iw_start = ow * stride_w - pad_w; + + // Initialize with minimum value for the type + T max_val = std::numeric_limits::lowest(); + + // Iterate over the kernel + for (int64_t kh = 0; kh < kernel_h; ++kh) { + for (int64_t kw = 0; kw < kernel_w; ++kw) { + const int64_t ih = ih_start + kh * dilation_h; + 
const int64_t iw = iw_start + kw * dilation_w; + + // Check bounds + if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) { + const int64_t in_idx = + ((n * channels + c) * in_height + ih) * in_width + iw; + max_val = std::max(max_val, in_data[in_idx]); + } + } + } + + // Write output + const int64_t out_idx = + ((n * channels + c) * out_height + oh) * out_width + ow; + out_data[out_idx] = max_val; + } + } + } + } +} + +} // namespace + +Tensor& quantized_max_pool2d_nchw_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode, + Tensor& output) { +#define typed_quantized_max_pool2d_nchw(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_max_pool2d_nchw_impl( \ + input, kernel_size, stride, padding, dilation, ceil_mode, output); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + // NOLINTBEGIN(clang-diagnostic-switch-enum) + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d_nchw) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + // NOLINTEND(clang-diagnostic-switch-enum) + +#undef typed_quantized_max_pool2d_nchw + return output; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d.h b/backends/cadence/generic/operators/op_quantized_max_pool2d.h new file mode 100644 index 00000000000..453dd5a2582 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_max_pool2d_nchw_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + ::executorch::aten::IntArrayRef kernel_size, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef dilation, + bool ceil_mode, + ::executorch::aten::Tensor& output); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp new file mode 100644 index 00000000000..d8f0d9e068b --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp @@ -0,0 +1,136 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +namespace { + +template +void quantized_max_pool2d_nhwc_impl( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + ET_UNUSED bool ceil_mode, + Tensor& output) { + const T* __restrict__ in_data = input.const_data_ptr(); + T* __restrict__ out_data = output.mutable_data_ptr(); + + // Input dimensions: [N, H, W, C] + const int64_t batch_size = input.size(0); + const int64_t in_height = input.size(1); + const int64_t in_width = input.size(2); + const int64_t channels = input.size(3); + + // Output dimensions: [N, H_out, W_out, C] + const int64_t out_height = output.size(1); + const int64_t out_width = output.size(2); + + // Pooling parameters + const int64_t kernel_h = kernel_size[0]; + const int64_t kernel_w = kernel_size[1]; + const int64_t stride_h = stride[0]; + const int64_t stride_w = stride[1]; + const int64_t pad_h = padding[0]; + const int64_t pad_w = padding[1]; + const int64_t dilation_h = dilation[0]; + const int64_t dilation_w = dilation[1]; + + for (int64_t n = 0; n < batch_size; ++n) { + for (int64_t oh = 0; oh < out_height; ++oh) { + for (int64_t ow = 0; ow < out_width; ++ow) { + const int64_t ih_start = oh * stride_h - pad_h; + const int64_t iw_start = ow * stride_w - pad_w; + + T* __restrict__ out_ptr = + out_data + ((n * out_height + oh) * out_width + ow) * channels; + + // Initialize all channels to the minimum value. + for (int64_t c = 0; c < channels; ++c) { + out_ptr[c] = std::numeric_limits::lowest(); + } + + // For each kernel position, compute element-wise max across all + // channels. 
The inner loop over channels is a stride-1 contiguous + // access in NHWC layout, enabling SIMD auto-vectorization. + for (int64_t kh = 0; kh < kernel_h; ++kh) { + const int64_t ih = ih_start + kh * dilation_h; + if (ih < 0 || ih >= in_height) { + continue; + } + for (int64_t kw = 0; kw < kernel_w; ++kw) { + const int64_t iw = iw_start + kw * dilation_w; + if (iw < 0 || iw >= in_width) { + continue; + } + + const T* __restrict__ in_ptr = + in_data + ((n * in_height + ih) * in_width + iw) * channels; + + for (int64_t c = 0; c < channels; ++c) { + out_ptr[c] = std::max(out_ptr[c], in_ptr[c]); + } + } + } + } + } + } +} + +} // namespace + +Tensor& quantized_max_pool2d_nhwc_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode, + Tensor& output) { +#define typed_quantized_max_pool2d_nhwc(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_max_pool2d_nhwc_impl( \ + input, kernel_size, stride, padding, dilation, ceil_mode, output); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + // NOLINTBEGIN(clang-diagnostic-switch-enum) + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d_nhwc) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + // NOLINTEND(clang-diagnostic-switch-enum) + +#undef typed_quantized_max_pool2d_nhwc + return output; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.h b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.h new file mode 100644 index 00000000000..2b0c02e4bb7 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_max_pool2d_nhwc_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + ::executorch::aten::IntArrayRef kernel_size, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef dilation, + bool ceil_mode, + ::executorch::aten::Tensor& output); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/targets.bzl b/backends/cadence/generic/operators/targets.bzl index faa63e4f46f..fa6708a188e 100644 --- a/backends/cadence/generic/operators/targets.bzl +++ b/backends/cadence/generic/operators/targets.bzl @@ -213,6 +213,30 @@ def define_common_targets(): visibility = ["PUBLIC"], ) + runtime.cxx_library( + name = "op_quantized_max_pool2d", + srcs = ["op_quantized_max_pool2d.cpp"], + exported_headers = ["op_quantized_max_pool2d.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ":cadence_type_util", + ], + visibility = ["PUBLIC"], + ) + + runtime.cxx_library( + name = "op_quantized_max_pool2d_nhwc", + srcs = ["op_quantized_max_pool2d_nhwc.cpp"], + exported_headers = ["op_quantized_max_pool2d_nhwc.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ":cadence_type_util", + ], + visibility = ["PUBLIC"], + ) + runtime.cxx_library( name = "op_quantized_matmul", srcs = ["op_quantized_matmul.cpp"],