|
6 | 6 |
|
7 | 7 | # pyre-strict |
8 | 8 |
|
9 | | -from typing import Any, cast, Dict, List, Tuple |
| 9 | +import operator as op_module |
| 10 | +from typing import Any, cast, Dict, List, Optional, Tuple |
10 | 11 |
|
11 | 12 | import torch |
12 | 13 | from executorch.backends.cadence.aot.compiler_utils import get_shape |
| 14 | +from executorch.backends.cadence.aot.pass_utils import get_arg |
13 | 15 | from executorch.backends.cadence.aot.quantizer.patterns import ( |
14 | 16 | AddmmPattern, |
15 | 17 | AddPattern, |
|
24 | 26 | LayerNormPattern, |
25 | 27 | LinearPattern, |
26 | 28 | MatmulPattern, |
| 29 | + MaxPool2dPattern, |
| 30 | + MaxPool2dWithoutIndicesPattern, |
27 | 31 | MixedW8A32ConvPattern, |
28 | 32 | MixedW8A32GruPattern, |
29 | 33 | MixedW8A32LinearPattern, |
@@ -457,6 +461,34 @@ def get_args_and_kwargs_mixed_w8a32_conv( |
457 | 461 | return args, kwargs |
458 | 462 |
|
459 | 463 |
|
| 464 | +def get_args_and_kwargs_max_pool2d( |
| 465 | + inputs_inputs: List[fx.Node], |
| 466 | + op_node: fx.Node, |
| 467 | +) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]: |
| 468 | + """ |
| 469 | + Returns the args and kwargs for the max_pool2d replacement op. |
| 470 | +
|
| 471 | + Max pooling is order-preserving, so we can perform the max operation |
| 472 | + directly on quantized values without any requantization. |
| 473 | + """ |
| 474 | + # Get the pooling parameters from the original op node using get_arg |
| 475 | + kernel_size = get_arg(op_node, "kernel_size", Optional[list[int]]) or [1, 1] |
| 476 | + stride = get_arg(op_node, "stride", Optional[list[int]]) or kernel_size |
| 477 | + padding = get_arg(op_node, "padding", Optional[list[int]]) or [0, 0] |
| 478 | + dilation = get_arg(op_node, "dilation", Optional[list[int]]) or [1, 1] |
| 479 | + ceil_mode = get_arg(op_node, "ceil_mode", Optional[bool]) or False |
| 480 | + |
| 481 | + args = (inputs_inputs[0],) |
| 482 | + kwargs = { |
| 483 | + "kernel_size": kernel_size, |
| 484 | + "stride": stride, |
| 485 | + "padding": padding, |
| 486 | + "dilation": dilation, |
| 487 | + "ceil_mode": ceil_mode, |
| 488 | + } |
| 489 | + return args, kwargs |
| 490 | + |
| 491 | + |
460 | 492 | def get_args_and_kwargs_mixed_w8a32_gru( |
461 | 493 | graph_module: GraphModule, |
462 | 494 | other_inputs: List[fx.Node], |
@@ -549,6 +581,16 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 |
549 | 581 |
|
550 | 582 | assert op_node is not None, "op_node is None" |
551 | 583 | quant_node = list(op_node.users.keys())[0] |
| 584 | + # For ops that return tuples (e.g., max_pool2d_with_indices), |
| 585 | + # traverse through the getitem to find the actual quant node |
| 586 | + if quant_node.target is op_module.getitem: |
| 587 | + assert ( |
| 588 | + len(quant_node.args) >= 2 and quant_node.args[1] == 0 |
| 589 | + ), f"Expected getitem[0] for the values output, but got getitem[{quant_node.args[1] if len(quant_node.args) >= 2 else '?'}]" |
| 590 | + assert ( |
| 591 | + len(list(quant_node.users.keys())) > 0 |
| 592 | + ), "getitem node has no users" |
| 593 | + quant_node = list(quant_node.users.keys())[0] |
552 | 594 |
|
553 | 595 | with graph_module.graph.inserting_after(op_node): |
554 | 596 | args = tuple( |
@@ -697,6 +739,13 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 |
697 | 739 | dequants_biases, |
698 | 740 | op_node, |
699 | 741 | ) |
| 742 | + elif isinstance( |
| 743 | + pattern, (MaxPool2dPattern, MaxPool2dWithoutIndicesPattern) |
| 744 | + ): |
| 745 | + args, kwargs = get_args_and_kwargs_max_pool2d( |
| 746 | + inputs_inputs, |
| 747 | + op_node, |
| 748 | + ) |
700 | 749 |
|
701 | 750 | fused = graph_module.graph.call_function( |
702 | 751 | pattern.replacement_op(), |
|
0 commit comments