diff --git a/backends/arm/operator_support/convolution_support.py b/backends/arm/operator_support/convolution_support.py index 0d0a32200e8..b07ae82f98f 100644 --- a/backends/arm/operator_support/convolution_support.py +++ b/backends/arm/operator_support/convolution_support.py @@ -34,6 +34,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): for pad in output_padding: if pad != 0: + self.reporter.report_reject( + node, "Convolutions with non-zero output padding not implemented." + ) return False # Hardware specific constraints @@ -56,19 +59,33 @@ def _is_node_supported_u55(self, node: fx.Node): # Depthwise convolution for dim in shape_in[1:]: if not 1 <= dim <= 65536: + self.reporter.report_reject( + node, + f"Depthwise convolution must have CWH <= 65536, got {dim})", + ) return False else: # Convolution if not 1 <= C_in <= 65536: + self.reporter.report_reject( + node, f"Convolution must have C <= 65536, got {C_in})" + ) return False kernel_w = kernel[2] kernel_h = kernel[3] if len(kernel) > 3 else 1 # Kernel condition misses constraint on sum of absolute weights if not 1 <= kernel_h <= 64 or not 1 <= kernel_w * kernel_h <= 4096: + self.reporter.report_reject( + node, + f"Convolution needs to have kernel_y<=64, kernel_x*kernel_y<=4096, got kernel ({kernel_w}, {kernel_h})", + ) return False if not self._stride_condition(node): + self.reporter.report_reject( + node, "Failed condition on stride, pad and dilation combination." + ) return False return True diff --git a/backends/arm/operator_support/pool_2d_support.py b/backends/arm/operator_support/pool_2d_support.py index c1dd143a4fc..8291ede8ad9 100644 --- a/backends/arm/operator_support/pool_2d_support.py +++ b/backends/arm/operator_support/pool_2d_support.py @@ -54,12 +54,35 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): if len(node.args) > 3: # Padding case if not all(1 <= k <= 8 for k in kernel): + self.reporter.report_reject( + node, f"Avgpool2d with padding needs kernel dims < 8, got {kernel}" + ) return False else: if not kernel_check(kernel): + self.reporter.report_reject( + node, + f"Avgpool2d needs kernel_y < 256, kernel_x*kernel_y<=65536, got {kernel}", + ) return False - return dim_check(shape) and shape[0] == 1 and stride_check(stride) + if not dim_check(shape): + self.reporter.report_reject( + node, + f"Avgpool2d needs N == 1, rest dims <= 65536, got shape {list(shape)}", + ) + return False + if not stride_check(stride): + self.reporter.report_reject( + node, f"Avgpool2d needs stride <= 3, got {stride}" + ) + return False + if not shape[0] == 1: + self.reporter.report_reject( + node, f"Avgpool2d needs N==1, got N=={shape[0]}" + ) + return False + return True @register_tosa_support_check @@ -82,4 +105,21 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): kernel = cast(tuple[int, int], node.args[1]) stride = cast(tuple[int, int], node.args[2]) - return kernel_check(kernel) and dim_check(shape) and stride_check(stride) + if not kernel_check(kernel): + self.reporter.report_reject( + node, + f"Maxpool2d needs kernel_y < 256, kernel_x*kernel_y<=65536, got {kernel}", + ) + return False + if not dim_check(shape): + self.reporter.report_reject( + node, + f"Maxpool2d needs N == 1, rest dims <= 65536, got shape {list(shape)}", + ) + return False + if not stride_check(stride): + self.reporter.report_reject( + node, f"Maxpool2d needs stride <= 3, got {stride}" + ) + return False + return True diff --git a/backends/arm/operator_support/reduce_sum_support.py b/backends/arm/operator_support/reduce_sum_support.py index 8345d69caaa..37a71d7264c 100644 --- a/backends/arm/operator_support/reduce_sum_support.py +++ b/backends/arm/operator_support/reduce_sum_support.py @@ -34,6 +34,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): for dim in dim_list: if not 1 <= input_shape[dim] <= 65536: + self.reporter.report_reject( + node, f"sum needs dims < 65536, got shape {input_shape}" + ) return False # We can't be certain of which dim is the last in memory yet, @@ -45,7 +48,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): for length in input_shape[dim + 1 :]: post_R_product *= length if not 1 <= pre_R_product <= 65536: + self.reporter.report_reject(node, "Failed dim check") return False if not 1 <= post_R_product <= 65536: + self.reporter.report_reject(node, "Failed dim check") return False return True diff --git a/backends/arm/operator_support/to_copy_support.py b/backends/arm/operator_support/to_copy_support.py index c81c8e58a29..7926b3dc053 100644 --- a/backends/arm/operator_support/to_copy_support.py +++ b/backends/arm/operator_support/to_copy_support.py @@ -75,9 +75,6 @@ def is_node_tosa_supported( ) -> bool: assert node.target in self.targets - if tosa_spec not in self.tosa_specs: - return False - assert tosa_spec.support_integer() supported_dtypes = ( self.ALL_SUPPORTED_TYPES @@ -97,9 +94,9 @@ def is_node_tosa_supported( assert isinstance(input_val, torch._subclasses.FakeTensor) input_dtype = input_val.dtype if input_dtype not in supported_dtypes: - logger.info( - f"Input dtype {input_val.dtype} is not supported in " - f"{node.target.name()}." # type: ignore[union-attr] # pyre-ignore[16] + self.reporter.report_reject( + node, + f"Input dtype {input_val.dtype} is not supported in {node.target}.", ) return False @@ -107,20 +104,22 @@ def is_node_tosa_supported( output_val = node.meta["val"] assert isinstance(output_val, torch._subclasses.FakeTensor) if output_val.dtype not in supported_dtypes[input_dtype]: - logger.info( + self.reporter.report_reject( + node, f"Output dtype {output_val.dtype} is not supported in " - f"{node.target.name()} for input dtype {input_dtype}. " # type: ignore[union-attr] # pyre-ignore[16] + f"{node.target} for input dtype {input_dtype}. " f"Supported output types: " - f"{''.join(str(t) for t in supported_dtypes[input_dtype])}" + f"{''.join(str(t) for t in supported_dtypes[input_dtype])}", ) return False # Check memory format (to_copy) if "memory_format" in node.kwargs: if node.kwargs["memory_format"] in (torch.preserve_format,): - logger.info( + self.reporter.report_reject( + node, f"Argument 'memory_format' is not supported for " - f"{node.target.name()} right now." # type: ignore[union-attr] # pyre-ignore[16] + f"{node.target} right now.", ) return False @@ -129,9 +128,10 @@ def is_node_tosa_supported( dim_order = node.kwargs["dim_order"] # pyre-ignore[6] if dim_order != list(range(len(dim_order))): # type: ignore[arg-type] - logger.info( + self.reporter.report_reject( + node, f"Argument {dim_order=} is not supported for " - f"{node.target.name()} right now." # type: ignore[union-attr] # pyre-ignore[16] + f"{node.target} right now.", ) return False diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 223b5d40ea1..dfd8024e4b3 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -19,6 +19,7 @@ ) from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification from executorch.exir import ExportedProgram +from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops from torch.export.graph_signature import InputKind from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase @@ -30,8 +31,9 @@ class SupportedTOSAOperatorCheck(OperatorSupportBase): Supported OP for TOSA lowering """ - def __init__(self, tosa_spec: TosaSpecification): + def __init__(self, tosa_spec: TosaSpecification, reporter: WhyNoPartitionReporter): self.tosa_spec = tosa_spec + self.reporter = reporter # Should be populated by subclass implementation tosa_specs: list[TosaSpecification] = [] @@ -86,23 +88,42 @@ def get_registered_tosa_support_checks( def tosa_support_factory( tosa_spec: TosaSpecification, exported_program: ExportedProgram, + reporter: WhyNoPartitionReporter, additional_checks: Optional[Sequence[OperatorSupportBase]] = None, ) -> OperatorSupportBase: - negative_checks: list[OperatorSupportBase] = [CheckInt64Inputs(exported_program)] + """Generates an OperatorSupport class depending on the given `tosa_spec`. + Additional checks can be supplied to avoid partitioning additional nodes. + """ + # Postive checks: Add nodes to partitioning + positive_checks: list[OperatorSupportBase] = [ + BaseTOSASupportList(), + *[ + check(tosa_spec, reporter) + for check in get_registered_tosa_support_checks(tosa_spec) + ], + ] + + # Negative checks: Remove nodes from partitioning + negative_checks: list[OperatorSupportBase] = [ + CheckInt64Inputs(exported_program, reporter), + *[ + reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}") + for check in (additional_checks if additional_checks else []) + ], + ] + if not tosa_spec.support_float(): - negative_checks.append(NeedsDecompositionCheck()) - negative_checks.append(CheckProperQuantization()) - negative_checks.append(EthosU55NotSupported(tosa_spec)) + negative_checks.append(NeedsDecompositionCheck(reporter)) + negative_checks.append(CheckProperQuantization(reporter)) + if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset: + negative_checks.append(EthosU55NotSupported(reporter)) + return chain( - any_chain( - BaseTOSASupportList(), - *( - check(tosa_spec) - for check in get_registered_tosa_support_checks(tosa_spec) - ), + reporter.wrap_check( + any_chain(*positive_checks), + "Not included in BaseTOSASupportList or a registered tosa_support_check", ), *negative_checks, - *additional_checks if additional_checks else [], ) @@ -186,39 +207,39 @@ def is_node_supported( class EthosU55NotSupported(OperatorSupportBase): """ - Certain operators are not supported on U55. These are listed in `unsupported` in - is_node_supported(). + Certain operators are not supported on U55. These are listed in `unsupported_ops`. """ - def __init__(self, tosa_spec: TosaSpecification): - self.tosa_spec = tosa_spec + unsupported_ops = [ + exir_ops.edge.aten.any.default, + exir_ops.edge.aten.any.dim, + exir_ops.edge.aten.any.dims, + exir_ops.edge.aten.bitwise_and.Tensor, + exir_ops.edge.aten.bitwise_or.Tensor, + exir_ops.edge.aten.bitwise_xor.Tensor, + exir_ops.edge.aten.logical_and.default, + exir_ops.edge.aten.logical_or.default, + exir_ops.edge.aten.logical_xor.default, + exir_ops.edge.aten.logical_not.default, + exir_ops.edge.aten.amax.default, + exir_ops.edge.aten.amin.default, + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.lt.Tensor, + ] + + def __init__(self, reporter: WhyNoPartitionReporter): + self.reporter = reporter def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: - unsupported_ops = [ - exir_ops.edge.aten.any.default, - exir_ops.edge.aten.any.dim, - exir_ops.edge.aten.any.dims, - exir_ops.edge.aten.bitwise_and.Tensor, - exir_ops.edge.aten.bitwise_or.Tensor, - exir_ops.edge.aten.bitwise_xor.Tensor, - exir_ops.edge.aten.logical_and.default, - exir_ops.edge.aten.logical_or.default, - exir_ops.edge.aten.logical_xor.default, - exir_ops.edge.aten.logical_not.default, - exir_ops.edge.aten.amax.default, - exir_ops.edge.aten.amin.default, - exir_ops.edge.aten.eq.Tensor, - exir_ops.edge.aten.ge.Tensor, - exir_ops.edge.aten.gt.Tensor, - exir_ops.edge.aten.le.Tensor, - exir_ops.edge.aten.lt.Tensor, - ] - if node.target in unsupported_ops: - return False + if node.target in self.unsupported_ops: + self.reporter.report_reject(node, "Op is not supported on U55.") + return False return True @@ -230,6 +251,9 @@ class NeedsDecompositionCheck(OperatorSupportBase): that need to be decomposed. """ + def __init__(self, reporter: WhyNoPartitionReporter): + self.reporter = reporter + def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: @@ -238,22 +262,27 @@ def is_node_supported( return True if node.target == exir_ops.edge.aten.mean.dim: dim = node.args[1] - return dim == [-1, -2] - needs_decomp = node.target in [ - exir_ops.edge.aten.div.Tensor, - exir_ops.edge.aten._native_batch_norm_legit_no_training.default, - exir_ops.edge.aten.native_layer_norm.default, - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten._softmax.default, - exir_ops.edge.aten._log_softmax.default, - exir_ops.edge.aten.var.correction, - exir_ops.edge.aten.var.dim, - exir_ops.edge.aten.add.Scalar, - exir_ops.edge.aten.sub.Scalar, - exir_ops.edge.aten.mul.Scalar, - exir_ops.edge.aten.div.Scalar, - ] - return not needs_decomp + needs_decomp = dim != [-1, -2] + else: + needs_decomp = node.target in [ + exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten._native_batch_norm_legit_no_training.default, + exir_ops.edge.aten.native_layer_norm.default, + exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten._softmax.default, + exir_ops.edge.aten._log_softmax.default, + exir_ops.edge.aten.var.correction, + exir_ops.edge.aten.var.dim, + exir_ops.edge.aten.add.Scalar, + exir_ops.edge.aten.sub.Scalar, + exir_ops.edge.aten.mul.Scalar, + exir_ops.edge.aten.div.Scalar, + ] + if needs_decomp: + self.reporter.report_reject(node, "Needs to be decomposed.") + return False + else: + return True class CheckProperQuantization(OperatorSupportBase): @@ -266,6 +295,9 @@ class CheckProperQuantization(OperatorSupportBase): dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + def __init__(self, reporter: WhyNoPartitionReporter): + self.reporter = reporter + def _is_matmul_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ): @@ -294,14 +326,23 @@ def _is_matmul_node_supported( for input_node in matched_partition.input_nodes ) if not input_quantized: + self.reporter.report_reject( + node, "One or more matmul inputs were not quantized." + ) return False output_quantized = all( output_node_user.target == self.q_op for output_node_user in matched_partition.output_nodes[0].users ) if not output_quantized: + self.reporter.report_reject( + node, "One or more matmul outputs were not quantized." + ) return False else: + self.reporter.report_reject( + node, "Node did not match any matmul source partition." + ) return False return True @@ -367,6 +408,7 @@ def is_node_supported( ) if not input_quantized: + self.reporter.report_reject(node, "One or more inputs were not quantized.") return False all_q_users = all( @@ -376,18 +418,22 @@ def is_node_supported( output_quantized = output_quantized or all_q_users or not is_floating_point if not output_quantized: + self.reporter.report_reject(node, "One or more outputs were not quantized.") return False return True class CheckInt64Inputs(OperatorSupportBase): - def __init__(self, exported_program: ExportedProgram): + def __init__( + self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter + ): self.input_names = [ spec.arg.name for spec in exported_program.graph_signature.input_specs if spec.kind == InputKind.USER_INPUT ] + self.reporter = reporter super().__init__() def is_node_supported( @@ -402,5 +448,9 @@ def is_node_supported( ): tensor = get_first_fake_tensor(input_node) if tensor.dtype == torch.int64: + self.reporter.report_reject( + node, + f"Had int64 input {input_node.name} that couldn't be handled.", + ) return False return True diff --git a/backends/arm/test/misc/test_custom_partition.py b/backends/arm/test/misc/test_custom_partition.py index 8d73e1c7836..00bc4d306ae 100644 --- a/backends/arm/test/misc/test_custom_partition.py +++ b/backends/arm/test/misc/test_custom_partition.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging + import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester @@ -37,7 +39,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): return self.nested(a, b) -def test_single_reject(): +def test_single_reject(caplog): + caplog.set_level(logging.INFO) + module = CustomPartitioning() inputs = module.inputs compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI") @@ -57,6 +61,7 @@ def test_single_reject(): .run_method_and_compare_outputs(inputs=inputs) ) assert check.has_rejected_node() + assert "Rejected by DontPartition" in caplog.text def test_multiple_reject(): @@ -83,7 +88,9 @@ def test_multiple_reject(): assert check.has_rejected_node() -def test_torch_op_reject(): +def test_torch_op_reject(caplog): + caplog.set_level(logging.INFO) + module = CustomPartitioning() inputs = module.inputs compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI") @@ -103,6 +110,7 @@ def test_torch_op_reject(): .run_method_and_compare_outputs(inputs=inputs) ) assert check.has_rejected_node() + assert "Rejected by DontPartition" in caplog.text def test_string_op_reject(): @@ -128,7 +136,9 @@ def test_string_op_reject(): assert check.has_rejected_node() -def test_name_reject(): +def test_name_reject(caplog): + caplog.set_level(logging.INFO) + module = CustomPartitioning() inputs = module.inputs compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI") @@ -148,6 +158,7 @@ def test_name_reject(): .run_method_and_compare_outputs(inputs=inputs) ) assert check.has_rejected_node() + assert "Rejected by DontPartitionName" in caplog.text def test_module_reject(): @@ -172,7 +183,9 @@ def test_module_reject(): assert check.has_rejected_node() -def test_inexact_module_reject(): +def test_inexact_module_reject(caplog): + caplog.set_level(logging.INFO) + module = NestedModule() inputs = module.inputs compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI") @@ -192,6 +205,7 @@ def test_inexact_module_reject(): .run_method_and_compare_outputs(inputs=inputs) ) assert check.has_rejected_node() + assert "Rejected by DontPartitionModule" in caplog.text def test_module_instance_reject(): diff --git a/backends/arm/tosa_partitioner.py b/backends/arm/tosa_partitioner.py index 228998d82f5..a53bf6fc725 100644 --- a/backends/arm/tosa_partitioner.py +++ b/backends/arm/tosa_partitioner.py @@ -25,7 +25,7 @@ Partitioner, PartitionResult, ) -from executorch.exir.backend.utils import tag_constant_data +from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) -logger.setLevel(logging.WARNING) +logger.setLevel(logging.INFO) TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1" if TOSA_DBG_VERBOSE: logging.basicConfig(level=logging.INFO) @@ -78,9 +78,13 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # no logger.info(f"Partitioning for {self.delegation_spec.backend_id}: {tosa_spec}") + reporter = WhyNoPartitionReporter() + operator_support = tosa_support_factory( + tosa_spec, exported_program, reporter, self.additional_checks + ) capability_partitioner = CapabilityBasedPartitioner( exported_program.graph_module, - tosa_support_factory(tosa_spec, exported_program, self.additional_checks), + operator_support, allows_single_node_partition=True, ) partition_list = capability_partitioner.propose_partitions() @@ -119,14 +123,17 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: if is_partitioned(input): continue if get_first_fake_tensor(input).dtype.is_floating_point: - logger.info( - f"Not partitioning {node.name} becuase input {input.name} has floating point dtype." + reporter.report_reject( + node, + f"Was first node in partition and input {input.name} had fp dtype.", ) del node.meta["delegation_tag"] break tag_constant_data(exported_program) - + logger.info(f"The following nodes were rejected for {tosa_spec}:") + logger.info("\n" + reporter.get_table_report()) + logger.info("(Placeholders and outputs are not included in this list)") return PartitionResult( tagged_exported_program=exported_program, partition_tags=partition_tags ) diff --git a/exir/backend/utils.py b/exir/backend/utils.py index 9487c59a848..eb9aeb19756 100644 --- a/exir/backend/utils.py +++ b/exir/backend/utils.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -8,7 +9,7 @@ import logging import operator -from collections import defaultdict +from collections import defaultdict, OrderedDict from functools import lru_cache from typing import Dict, Iterable, List, Optional, Set, Tuple, Union @@ -22,9 +23,11 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.lowered_backend_module import create_submodule_from_nodes +from tabulate import tabulate from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param from torch.fx.experimental.symbolic_shapes import has_free_symbols from torch.fx.node import Node +from torch.fx.passes.operator_support import OperatorSupportBase from torch.fx.passes.utils.source_matcher_utils import SourcePartition T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default @@ -569,3 +572,90 @@ def __call__(self, node: torch.fx.Node, reason: str) -> None: def __str__(self) -> str: return f"WhyNoPartition: Node {self.node} was not partitioned because {self.reason}." + + +class WhyNoPartitionReporter: + """ + Helper class for partitioners to gather why nodes were not lowered in a single report. + If a node is reported multiple times, only the first report is included. + + Example usage: + + # In your backend partitioner file(s) + reporter = WhyNoPartitionReporter() + + # hypothetical function that checks if a node can be lowered + if not can_be_lowered(node): + reporter.report_reject(node, "This node was not lowered because ...") + + # Back in partitioner + logger.info(reporter.get_table_report()) + """ + + def __init__(self): + self._rejected_nodes: OrderedDict[torch.fx.Node, str] = ( + OrderedDict() + ) # {Rejected node: reason} + + def report_reject(self, node: torch.fx.Node, reason: str): + """Report a node that was rejected from a partition, along with a reason for why.""" + if node not in self._rejected_nodes: + self._rejected_nodes[node] = reason + + def get_table_report(self) -> str: + """Returns a string containing a table listing all rejected nodes. + The table looks something like this: + ╒══════════════════════════╤══════════════════════════╤═════════════════════════════════════╤═════════════════════════════════════╕ + │ Node name │ Target │ Torch func │ Reason │ + ╞══════════════════════════╪══════════════════════════╪═════════════════════════════════════╪═════════════════════════════════════╡ + │ aten_convolution_default │ aten.convolution.default │ ('conv2d_1', 'builtin_function_or_m │ Convolution needs to have │ + │ │ │ ethod.conv2d') │ kernel_y<=64, │ + │ │ │ │ kernel_x*kernel_y<=4096, got kernel │ + │ │ │ │ (2, 65) │ + ╘══════════════════════════╧══════════════════════════╧═════════════════════════════════════╧═════════════════════════════════════╛ + """ + reject_report = [] + for node in self._rejected_nodes: + if node.op == "placeholder" or node.op == "output": + continue + if not (target := getattr(node.target, "_op", None)): + target = node.target + torch_fn = node.meta.get("torch_fn", "-") + reject_report.append( + [node.name, target, torch_fn, self._rejected_nodes[node]] + ) + if len(reject_report) > 0: + return tabulate( + reject_report, + ["Node name", "Target", "Torch func", "Reason"], + tablefmt="fancy_grid", + maxcolwidths=35, + ) + else: + return "No nodes rejected." + + def wrap_check( + self, operator_support: OperatorSupportBase, message: str + ) -> OperatorSupportBase: + """Wrap the operator_support, reporting rejects with the specified message.""" + return ReportRejected(operator_support, self, message) + + +class ReportRejected(OperatorSupportBase): + """Class for wrapping a OperatorSupportBase, reporting rejects with the specified message to `reporter`.""" + + def __init__( + self, + operator_support: OperatorSupportBase, + reporter: WhyNoPartitionReporter, + message, + ): + self.operator_support = operator_support + self.reporter = reporter + self.message = message + + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + is_supported = self.operator_support.is_node_supported(submodules, node) + if not is_supported: + self.reporter.report_reject(node, self.message) + return is_supported