From 69c71244868795ab05773809b4037d557b858999 Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Wed, 12 Feb 2025 10:07:43 +0100 Subject: [PATCH] Arm backend: Change _is_ok_for_quantization to support output check _is_ok_for_quantization now checks the node itself as well to verify that the node can be quantized. Previously it was only checked by looking at the inputs to the node. This led to TestSplit failing, which is fixed with the change to `is_non_float_tensor` in `arm_quantizer_utils`, which now handles when node.meta["val"] is a `list` of `FakeTensor`. It traverses the list and checks if any of the elements are **not** a `FakeTensor`. If one element is not a `FakeTensor` the function will return `True`. Change-Id: I898cfea5d02a185fbfa30b18a013123c6d3670a5 Signed-off-by: Sebastian Larsson --- backends/arm/quantizer/arm_quantizer_utils.py | 29 +++++++- .../arm/quantizer/quantization_annotator.py | 66 ++++++++++++++----- 2 files changed, 75 insertions(+), 20 deletions(-) diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index d7710bc1989..0ce11b620a6 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -11,7 +11,7 @@ # Utility functions for TOSAQuantizer # -from typing import cast +from typing import cast, Sequence import torch from torch._subclasses import FakeTensor @@ -76,9 +76,32 @@ def is_large_scalar(node: Node, gm: GraphModule): def is_non_float_tensor(node: Node) -> bool: - """Check if the input is not a float tensor, so that we can skip quantization for the node - since observers only works with float Tensors + """Check if the output of a node has a data type other than `torch.float32`. + + If the output is not `torch.float32`, quantization cannot be performed, as + observers only work with floating-point tensors. + + Args: + node (Node): The node to check the output(s) for. + + Returns: + bool: `True` if the data type is not float32, otherwise `False`. + + Note: + - If `node.meta["val"]` is a `list`, the function returns `True` if **any** + element is **not** an instance of `FakeTensor` or does **not** have + `torch.float32` as its data type. + - If node.meta["val"] is missing or is not an instance of `FakeTensor`, the + function returns True. """ + if "val" in node.meta and isinstance(node.meta["val"], Sequence): + return any( + not isinstance(fake_tensor, FakeTensor) + or fake_tensor.dtype != torch.float32 + for fake_tensor in node.meta["val"] + ) + if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor): return True + return node.meta["val"].dtype != torch.float32 diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 1c6d05f2557..e9ed6be81f3 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging import operator from dataclasses import dataclass from typing import Callable, List, Optional @@ -11,6 +12,7 @@ import torch.fx from executorch.backends.arm.quantizer import arm_quantizer_utils from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig +from executorch.backends.arm.tosa_utils import get_node_debug_info from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec from torch.ao.quantization.quantizer.utils import ( _annotate_input_qspec_map, @@ -18,6 +20,8 @@ ) from torch.fx import Node +logger = logging.getLogger(__name__) + @dataclass(frozen=True) class _QuantProperty: @@ -45,19 +49,52 @@ def _as_list(x): def _is_ok_for_quantization( - node: Node, quant_property: _QuantProperty, gm: torch.fx.GraphModule + node: Node, quant_properties: _OpQuantProperties, gm: torch.fx.GraphModule ) -> bool: - if quant_property.optional and ( - quant_property.index >= len(node.args) - or node.args[quant_property.index] is None - ): - return True + """Check if a node can be quantized. + + A node can be quantized if: + - All inputs that are required for quantization are of type `float32` + and are not large scalar values. + - The output of the node itself is of type `float32` and is not a large scalar. + + Args: + node (Node): The node being analyzed. + quant_properties (_OpQuantProperties): Contains quantization properties for + the node, including input and output quantization specifications. + gm (torch.fx.GraphModule): The graph module containing the computational graph. + + Returns: + bool: `True` if the node can be quantized, otherwise `False`. + """ + # Check output + if quant_properties.quant_output is not None: + if not arm_quantizer_utils.is_ok_for_quantization(node, gm): # type: ignore[attr-defined] + logger.debug( + f"Could not quantize node due to output: " + f"{get_node_debug_info(node, gm)}" + ) - for n_arg in _as_list(node.args[quant_property.index]): - assert isinstance(n_arg, Node) - if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm): # type: ignore[attr-defined] return False + # Check inputs + for quant_property in quant_properties.quant_inputs: + if quant_property.optional and ( + quant_property.index >= len(node.args) + or node.args[quant_property.index] is None + ): + continue + + for n_arg in _as_list(node.args[quant_property.index]): + assert isinstance(n_arg, Node) + if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm): # type: ignore[attr-defined] + logger.debug( + f'could not quantize node due to input "{node}": ' + f"{get_node_debug_info(node, gm)}" + ) + + return False + return True @@ -355,14 +392,9 @@ def any_or_hardtanh_min_zero(n: Node): return quant_properties # Check that each inputs/outputs can be quantized properly with the - # provided QuantProperties - for quant_property in quant_properties.quant_inputs: - if not _is_ok_for_quantization(node, quant_property, gm): - return None # type: ignore[return-value] - - if quant_properties.quant_output is not None: - if not _is_ok_for_quantization(node, quant_properties.quant_output, gm): - return None # type: ignore[return-value] + # provided quantization properties. + if not _is_ok_for_quantization(node, quant_properties, gm): + return None # type: ignore[return-value] return quant_properties