Skip to content

Commit 568421c

Browse files
mcremon-meta authored and facebook-github-bot committed
Add quantized max_pool2d operator (#18202)
Summary: As titled. Useful especially for resnet backbones, to reduce quant/dequant pressure. Reviewed By: DrJessop Differential Revision: D96683989
1 parent 75c85e7 commit 568421c

File tree

10 files changed

+505
-3
lines changed

10 files changed

+505
-3
lines changed

backends/cadence/aot/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,11 @@
309309
- arg_meta: null
310310
kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out
311311

312+
- func: cadence::quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
313+
kernels:
314+
- arg_meta: null
315+
kernel_name: impl::generic::quantized_max_pool2d_out
316+
312317
- func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
313318
kernels:
314319
- arg_meta: null

backends/cadence/aot/ops_registrations.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-strict
88

9-
from math import prod
9+
from math import ceil, prod
1010
from typing import Callable, Optional, Tuple
1111

1212
import torch
@@ -213,6 +213,13 @@ def register_fake(
213213
"quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor (a!)"
214214
)
215215

216+
lib.define(
217+
"quantized_max_pool2d(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
218+
)
219+
lib.define(
220+
"quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
221+
)
222+
216223
lib.define(
217224
"quantized_conv2d_nhwc(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)"
218225
)
@@ -2270,6 +2277,47 @@ def quantized_relu_asym8u_asym8u_per_tensor_meta(
22702277
return input.new_empty(input.size(), dtype=input.dtype)
22712278

22722279

2280+
@register_fake("cadence::quantized_max_pool2d")
def quantized_max_pool2d_meta(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """
    Shape-inference (fake) kernel for cadence::quantized_max_pool2d.

    Computes the output spatial size with the standard pooling formula
        out = floor_or_ceil((in + 2*pad - dil*(k-1) - 1) / stride) + 1
    and returns an empty (N, C, H_out, W_out) tensor with the input's dtype.
    Max pooling is performed directly on quantized values, so the dtype and
    quantization parameters are unchanged.
    """
    assert (
        len(kernel_size) == 2
    ), f"kernel_size must have 2 elements, got {len(kernel_size)}"
    assert len(stride) == 2, f"stride must have 2 elements, got {len(stride)}"
    assert len(padding) == 2, f"padding must have 2 elements, got {len(padding)}"
    assert len(dilation) == 2, f"dilation must have 2 elements, got {len(dilation)}"
    assert (
        len(input.size()) == 4
    ), f"input must be 4D (N, C, H, W), got {len(input.size())}D"

    batch = input.size(0)
    channels = input.size(1)

    def _pooled_dim(size: int, dim: int) -> int:
        # Extent available to the sliding window along this dimension.
        numer = size + 2 * padding[dim] - dilation[dim] * (kernel_size[dim] - 1) - 1
        # Exact integer floor/ceil division instead of float `/` followed by
        # int()/math.ceil(): avoids float rounding error on large sizes and
        # stays symbolic-shape friendly.
        if ceil_mode:
            return -(-numer // stride[dim]) + 1
        return numer // stride[dim] + 1

    height_out = _pooled_dim(input.size(2), 0)
    width_out = _pooled_dim(input.size(3), 1)

    # NOTE(review): aten pooling additionally shrinks the ceil_mode output by
    # one when the last window would start entirely inside the right/bottom
    # padding ((out - 1) * stride >= in + pad). Confirm whether the Cadence
    # kernel applies the same rule before relying on ceil_mode with nonzero
    # padding.
    return input.new_empty([batch, channels, height_out, width_out], dtype=input.dtype)
2319+
2320+
22732321
@register_fake("cadence::fully_connected")
22742322
def fully_connected_meta(
22752323
src: torch.Tensor,

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66

77
# pyre-strict
88

9-
from typing import Any, cast, Dict, List, Tuple
9+
import operator as op_module
10+
from typing import Any, cast, Dict, List, Optional, Tuple
1011

1112
import torch
1213
from executorch.backends.cadence.aot.compiler_utils import get_shape
14+
from executorch.backends.cadence.aot.pass_utils import get_arg
1315
from executorch.backends.cadence.aot.quantizer.patterns import (
1416
AddmmPattern,
1517
AddPattern,
@@ -24,6 +26,8 @@
2426
LayerNormPattern,
2527
LinearPattern,
2628
MatmulPattern,
29+
MaxPool2dPattern,
30+
MaxPool2dWithoutIndicesPattern,
2731
MixedW8A32ConvPattern,
2832
MixedW8A32GruPattern,
2933
MixedW8A32LinearPattern,
@@ -457,6 +461,34 @@ def get_args_and_kwargs_mixed_w8a32_conv(
457461
return args, kwargs
458462

459463

464+
def get_args_and_kwargs_max_pool2d(
    inputs_inputs: List[fx.Node],
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    """
    Build the args and kwargs for the quantized max_pool2d replacement op.

    Max pooling only selects values (it never mixes them), so it can run
    directly on the quantized integers with no requantization step.
    """
    # Pull the pooling parameters off the original aten node, substituting
    # the aten defaults when an argument was omitted (stride defaults to the
    # kernel size, per the aten signature).
    window = get_arg(op_node, "kernel_size", Optional[list[int]]) or [1, 1]
    kwargs: Dict[str, ArgsType] = {
        "kernel_size": window,
        "stride": get_arg(op_node, "stride", Optional[list[int]]) or window,
        "padding": get_arg(op_node, "padding", Optional[list[int]]) or [0, 0],
        "dilation": get_arg(op_node, "dilation", Optional[list[int]]) or [1, 1],
        "ceil_mode": get_arg(op_node, "ceil_mode", Optional[bool]) or False,
    }
    # The only positional argument is the (quantized) input tensor node.
    return (inputs_inputs[0],), kwargs
490+
491+
460492
def get_args_and_kwargs_mixed_w8a32_gru(
461493
graph_module: GraphModule,
462494
other_inputs: List[fx.Node],
@@ -549,6 +581,16 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
549581

550582
assert op_node is not None, "op_node is None"
551583
quant_node = list(op_node.users.keys())[0]
584+
# For ops that return tuples (e.g., max_pool2d_with_indices),
585+
# traverse through the getitem to find the actual quant node
586+
if quant_node.target is op_module.getitem:
587+
assert (
588+
len(quant_node.args) >= 2 and quant_node.args[1] == 0
589+
), f"Expected getitem[0] for the values output, but got getitem[{quant_node.args[1] if len(quant_node.args) >= 2 else '?'}]"
590+
assert (
591+
len(list(quant_node.users.keys())) > 0
592+
), "getitem node has no users"
593+
quant_node = list(quant_node.users.keys())[0]
552594

553595
with graph_module.graph.inserting_after(op_node):
554596
args = tuple(
@@ -697,6 +739,13 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
697739
dequants_biases,
698740
op_node,
699741
)
742+
elif isinstance(
743+
pattern, (MaxPool2dPattern, MaxPool2dWithoutIndicesPattern)
744+
):
745+
args, kwargs = get_args_and_kwargs_max_pool2d(
746+
inputs_inputs,
747+
op_node,
748+
)
700749

701750
fused = graph_module.graph.call_function(
702751
pattern.replacement_op(),

backends/cadence/aot/quantizer/patterns.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,90 @@ def replacement_op(self) -> OpOverload:
417417
return torch.ops.cadence.quantized_matmul.default
418418

419419

420+
class MaxPool2dPattern(QuantizationPattern):
    """
    Pattern for quantized max pooling (with indices variant).

    Max pooling is order-preserving: max(a, b) computed on quantized values
    equals quantize(max(dequant(a), dequant(b))) when both sides use the same
    scale/zero_point. Pooling can therefore run directly on quantized values
    with no requantization, and input and output share quantization params.
    """

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.max_pool2d_with_indices.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]

        # Everything after the input tensor (kernel_size, stride, padding,
        # dilation, ceil_mode) is a literal argument.
        literal_anchors = [(node, idx) for idx in range(1, len(node.args))]

        # Output reuses the input's quantization spec since max is
        # order-preserving.
        shared_spec = SharedQuantizationSpec((node.args[0], node))
        anchors = PartitionAnchors(
            inputs=[(node, 0)],
            weights=[],
            biases=[],
            literals=literal_anchors,
            output=[(node, shared_spec)],
        )
        return anchors, node

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_max_pool2d.default
463+
464+
465+
class MaxPool2dWithoutIndicesPattern(MaxPool2dPattern):
    """
    Pattern for quantized max pooling (without indices variant).

    Matches aten.max_pool2d.default, which returns a single tensor instead of
    a (values, indices) tuple. The anchor logic and replacement op are
    identical to MaxPool2dPattern's and are inherited from it; only the
    matched aten op differs.
    """

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.max_pool2d.default]
502+
503+
420504
# This is a base class for ReLU, since it can be used with two different aten ops
421505
class ReluBasePattern(QuantizationPattern):
422506
@abstractmethod

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
LayerNormPattern,
2525
LinearPattern,
2626
MatmulPattern,
27+
MaxPool2dPattern,
28+
MaxPool2dWithoutIndicesPattern,
2729
MixedW8A32ConvPattern,
2830
MixedW8A32GruPattern,
2931
MixedW8A32LinearPattern,
@@ -227,6 +229,8 @@ def get_cadence_default_quantizers() -> List[Quantizer]:
227229
CadenceAtenQuantizer(Conv2dPattern(), qconfig_A8W8sym),
228230
CadenceAtenQuantizer(LinearPattern(), qconfig_A8W8),
229231
CadenceAtenQuantizer(MatmulPattern(), qconfig_A8W8),
232+
CadenceAtenQuantizer(MaxPool2dPattern(), qconfig_A8W8),
233+
CadenceAtenQuantizer(MaxPool2dWithoutIndicesPattern(), qconfig_A8W8),
230234
CadenceAtenQuantizer(ReluPattern0(), qconfig_A8W8),
231235
CadenceAtenQuantizer(ReluPattern1(), qconfig_A8W8),
232236
]

backends/cadence/aot/ref_implementations.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,6 +1868,35 @@ def rms_norm(
18681868
return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)
18691869

18701870

1871+
@impl_tracked(m, "quantized_max_pool2d")
def quantized_max_pool2d(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """
    Reference implementation of quantized max pooling.

    Quantization is monotonic and max pooling only selects values (it never
    combines them), so taking the max over the raw quantized integers yields
    exactly what dequantize -> max_pool2d -> requantize would produce under a
    shared scale/zero_point. The pool therefore runs straight on the stored
    values with no dequant/requant round trip.
    """
    return F.max_pool2d(
        input, kernel_size, stride, padding, dilation, ceil_mode
    )
1898+
1899+
18711900
@impl_tracked(m, "where_Scalar")
18721901
def where_Scalar(
18731902
condition: torch.Tensor,

0 commit comments

Comments
 (0)