Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions backends/cadence/aot/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,16 @@
- arg_meta: null
kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out

- func: cadence::quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::generic::quantized_max_pool2d_nchw_out

- func: cadence::quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::generic::quantized_max_pool2d_nhwc_out

- func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
Expand Down
97 changes: 96 additions & 1 deletion backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# pyre-strict

from math import prod
from math import ceil, prod
from typing import Callable, Optional, Tuple

import torch
Expand Down Expand Up @@ -213,6 +213,19 @@ def register_fake(
"quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor (a!)"
)

lib.define(
"quantized_max_pool2d_nchw(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
)
lib.define(
"quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_max_pool2d_nhwc(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
)
lib.define(
"quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
"quantized_conv2d_nhwc(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)"
)
Expand Down Expand Up @@ -2270,6 +2283,88 @@ def quantized_relu_asym8u_asym8u_per_tensor_meta(
return input.new_empty(input.size(), dtype=input.dtype)


@register_fake("cadence::quantized_max_pool2d_nchw")
def quantized_max_pool2d_nchw_meta(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """Fake (shape inference) kernel for cadence::quantized_max_pool2d_nchw.

    Computes the output spatial size with the same integer arithmetic ATen
    uses for pooling (pooling_output_shape), including the ceil_mode
    correction that drops an output element whose window would start
    entirely inside the right/bottom padding. Integer arithmetic also avoids
    float-rounding errors that `/` followed by ceil/int could introduce.

    Args:
        input: 4D tensor laid out as (N, C, H, W).
        kernel_size, stride, padding, dilation: two-element [h, w] lists.
        ceil_mode: use ceiling instead of floor division for the output size.

    Returns:
        An uninitialized tensor of shape (N, C, H_out, W_out) with input's dtype.
    """
    assert (
        len(kernel_size) == 2
    ), f"kernel_size must have 2 elements, got {len(kernel_size)}"
    assert len(stride) == 2, f"stride must have 2 elements, got {len(stride)}"
    assert len(padding) == 2, f"padding must have 2 elements, got {len(padding)}"
    assert len(dilation) == 2, f"dilation must have 2 elements, got {len(dilation)}"
    assert (
        len(input.size()) == 4
    ), f"input must be 4D (N, C, H, W), got {len(input.size())}D"

    def _pooled_size(in_size: int, dim: int) -> int:
        # Numerator of the standard pooling output-size formula.
        span = in_size + 2 * padding[dim] - dilation[dim] * (kernel_size[dim] - 1) - 1
        if ceil_mode:
            out = -(-span // stride[dim]) + 1  # ceiling division
            # ATen correction: the last window must start inside the input
            # (or its left/top padding), never wholly in the trailing padding.
            if (out - 1) * stride[dim] >= in_size + padding[dim]:
                out -= 1
        else:
            out = span // stride[dim] + 1
        return out

    batch = input.size(0)
    channels = input.size(1)
    height_out = _pooled_size(input.size(2), 0)
    width_out = _pooled_size(input.size(3), 1)

    return input.new_empty([batch, channels, height_out, width_out], dtype=input.dtype)


@register_fake("cadence::quantized_max_pool2d_nhwc")
def quantized_max_pool2d_nhwc_meta(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """Fake (shape inference) kernel for cadence::quantized_max_pool2d_nhwc.

    Same output-size computation as the NCHW variant, but the input is laid
    out as (N, H, W, C). Uses ATen's integer pooling arithmetic
    (pooling_output_shape), including the ceil_mode correction that drops an
    output element whose window would start entirely inside the trailing
    padding; this avoids float-rounding errors as well.

    Args:
        input: 4D tensor laid out as (N, H, W, C).
        kernel_size, stride, padding, dilation: two-element [h, w] lists.
        ceil_mode: use ceiling instead of floor division for the output size.

    Returns:
        An uninitialized tensor of shape (N, H_out, W_out, C) with input's dtype.
    """
    assert (
        len(kernel_size) == 2
    ), f"kernel_size must have 2 elements, got {len(kernel_size)}"
    assert len(stride) == 2, f"stride must have 2 elements, got {len(stride)}"
    assert len(padding) == 2, f"padding must have 2 elements, got {len(padding)}"
    assert len(dilation) == 2, f"dilation must have 2 elements, got {len(dilation)}"
    assert (
        len(input.size()) == 4
    ), f"input must be 4D (N, H, W, C), got {len(input.size())}D"

    def _pooled_size(in_size: int, dim: int) -> int:
        # Numerator of the standard pooling output-size formula.
        span = in_size + 2 * padding[dim] - dilation[dim] * (kernel_size[dim] - 1) - 1
        if ceil_mode:
            out = -(-span // stride[dim]) + 1  # ceiling division
            # ATen correction: the last window must start inside the input
            # (or its left/top padding), never wholly in the trailing padding.
            if (out - 1) * stride[dim] >= in_size + padding[dim]:
                out -= 1
        else:
            out = span // stride[dim] + 1
        return out

    batch = input.size(0)
    channels = input.size(3)
    height_out = _pooled_size(input.size(1), 0)
    width_out = _pooled_size(input.size(2), 1)

    return input.new_empty([batch, height_out, width_out, channels], dtype=input.dtype)


@register_fake("cadence::fully_connected")
def fully_connected_meta(
src: torch.Tensor,
Expand Down
2 changes: 2 additions & 0 deletions backends/cadence/aot/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from executorch.backends.cadence.aot.remove_ops import (
CadenceRemoveNops,
RemoveNopSliceOrViewOpPass,
RemovePermutesAroundElementwiseOps,
RemoveRedundantOps,
)
from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
Expand Down Expand Up @@ -89,6 +90,7 @@ def get_passes_in_default_order() -> list[Type[ExportPass]]:
CadenceSimplifyOpsInGraph.passes,
FinalizePipeline,
FuseFullThenReshapePass,
RemovePermutesAroundElementwiseOps,
FuseTransposeOrPermuteOpPairsPass,
RemoveNopSliceOrViewOpPass,
CompileTimeTypeDispatchPass,
Expand Down
51 changes: 50 additions & 1 deletion backends/cadence/aot/quantizer/fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

# pyre-strict

from typing import Any, cast, Dict, List, Tuple
import operator as op_module
from typing import Any, cast, Dict, List, Optional, Tuple

import torch
from executorch.backends.cadence.aot.compiler_utils import get_shape
from executorch.backends.cadence.aot.pass_utils import get_arg
from executorch.backends.cadence.aot.quantizer.patterns import (
AddmmPattern,
AddPattern,
Expand All @@ -24,6 +26,8 @@
LayerNormPattern,
LinearPattern,
MatmulPattern,
MaxPool2dPattern,
MaxPool2dWithoutIndicesPattern,
MixedW8A32ConvPattern,
MixedW8A32GruPattern,
MixedW8A32LinearPattern,
Expand Down Expand Up @@ -457,6 +461,34 @@ def get_args_and_kwargs_mixed_w8a32_conv(
return args, kwargs


def get_args_and_kwargs_max_pool2d(
    inputs_inputs: List[fx.Node],
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Build the args and kwargs for the quantized max_pool2d replacement op.

    Max pooling is order-preserving, so the max can be taken directly on
    quantized values with no requantization; only the pooling geometry has
    to be forwarded from the original node.
    """
    # Pull each pooling parameter off the original op node, substituting the
    # defaults when the argument was omitted.
    kernel_size = get_arg(op_node, "kernel_size", Optional[list[int]]) or [1, 1]
    kwargs = {
        "kernel_size": kernel_size,
        # stride defaults to the kernel size when not given.
        "stride": get_arg(op_node, "stride", Optional[list[int]]) or kernel_size,
        "padding": get_arg(op_node, "padding", Optional[list[int]]) or [0, 0],
        "dilation": get_arg(op_node, "dilation", Optional[list[int]]) or [1, 1],
        "ceil_mode": get_arg(op_node, "ceil_mode", Optional[bool]) or False,
    }
    return (inputs_inputs[0],), kwargs


def get_args_and_kwargs_mixed_w8a32_gru(
graph_module: GraphModule,
other_inputs: List[fx.Node],
Expand Down Expand Up @@ -549,6 +581,16 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901

assert op_node is not None, "op_node is None"
quant_node = list(op_node.users.keys())[0]
# For ops that return tuples (e.g., max_pool2d_with_indices),
# traverse through the getitem to find the actual quant node
if quant_node.target is op_module.getitem:
assert (
len(quant_node.args) >= 2 and quant_node.args[1] == 0
), f"Expected getitem[0] for the values output, but got getitem[{quant_node.args[1] if len(quant_node.args) >= 2 else '?'}]"
assert (
len(list(quant_node.users.keys())) > 0
), "getitem node has no users"
quant_node = list(quant_node.users.keys())[0]

with graph_module.graph.inserting_after(op_node):
args = tuple(
Expand Down Expand Up @@ -697,6 +739,13 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
dequants_biases,
op_node,
)
elif isinstance(
pattern, (MaxPool2dPattern, MaxPool2dWithoutIndicesPattern)
):
args, kwargs = get_args_and_kwargs_max_pool2d(
inputs_inputs,
op_node,
)

fused = graph_module.graph.call_function(
pattern.replacement_op(),
Expand Down
88 changes: 88 additions & 0 deletions backends/cadence/aot/quantizer/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,95 @@ def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_matmul.default


class MaxPool2dPattern(QuantizationPattern):
    """
    Quantization pattern for aten.max_pool2d_with_indices.

    Taking a max never changes the ordering of values, so pooling directly
    on quantized data equals quantizing max(dequant(a), dequant(b)) when
    the same scale/zero_point is used throughout. The input and output
    therefore share one set of quantization parameters and no
    requantization is needed.
    """

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.max_pool2d_with_indices.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        pool_node = fused_partition[0].nodes[-1]

        # Every argument after the input (kernel_size, stride, padding,
        # dilation, ceil_mode) is a literal.
        literal_anchors = [
            (pool_node, idx) for idx in range(1, len(pool_node.args))
        ]
        # Output reuses the input's quantization parameters (max is
        # order-preserving).
        shared_output_spec = SharedQuantizationSpec(
            (pool_node.args[0], pool_node)
        )
        anchors = PartitionAnchors(
            inputs=[(pool_node, 0)],
            weights=[],
            biases=[],
            literals=literal_anchors,
            output=[(pool_node, shared_output_spec)],
        )
        return anchors, pool_node

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_max_pool2d_nchw.default


class MaxPool2dWithoutIndicesPattern(QuantizationPattern):
    """
    Quantization pattern for aten.max_pool2d (no indices).

    Identical to MaxPool2dPattern except that it matches
    aten.max_pool2d.default, which returns a single tensor rather than a
    (values, indices) tuple.
    """

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.max_pool2d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        pool_node = fused_partition[0].nodes[-1]

        # Arguments 1..N (pooling geometry) are literals; the output shares
        # the input's quantization parameters since max is order-preserving.
        literal_anchors = [
            (pool_node, idx) for idx in range(1, len(pool_node.args))
        ]
        shared_output_spec = SharedQuantizationSpec(
            (pool_node.args[0], pool_node)
        )
        anchors = PartitionAnchors(
            inputs=[(pool_node, 0)],
            weights=[],
            biases=[],
            literals=literal_anchors,
            output=[(pool_node, shared_output_spec)],
        )
        return anchors, pool_node

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_max_pool2d_nchw.default


# This is a base class for ReLU, since it can be used with two different aten ops


class ReluBasePattern(QuantizationPattern):
@abstractmethod
def partition_types(self) -> List[OpOverload]:
Expand Down
4 changes: 4 additions & 0 deletions backends/cadence/aot/quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
LayerNormPattern,
LinearPattern,
MatmulPattern,
MaxPool2dPattern,
MaxPool2dWithoutIndicesPattern,
MixedW8A32ConvPattern,
MixedW8A32GruPattern,
MixedW8A32LinearPattern,
Expand Down Expand Up @@ -227,6 +229,8 @@ def get_cadence_default_quantizers() -> List[Quantizer]:
CadenceAtenQuantizer(Conv2dPattern(), qconfig_A8W8sym),
CadenceAtenQuantizer(LinearPattern(), qconfig_A8W8),
CadenceAtenQuantizer(MatmulPattern(), qconfig_A8W8),
CadenceAtenQuantizer(MaxPool2dPattern(), qconfig_A8W8),
CadenceAtenQuantizer(MaxPool2dWithoutIndicesPattern(), qconfig_A8W8),
CadenceAtenQuantizer(ReluPattern0(), qconfig_A8W8),
CadenceAtenQuantizer(ReluPattern1(), qconfig_A8W8),
]
Expand Down
Loading
Loading