
Arm backend: Change _is_ok_for_quantization to support output check #9795

Merged · 1 commit · Apr 1, 2025

backends/arm/quantizer/arm_quantizer_utils.py (26 additions, 3 deletions)
@@ -11,7 +11,7 @@
 # Utility functions for TOSAQuantizer
 #
 
-from typing import cast
+from typing import cast, Sequence
 
 import torch
 from torch._subclasses import FakeTensor
@@ -76,9 +76,32 @@ def is_large_scalar(node: Node, gm: GraphModule):
 
 
 def is_non_float_tensor(node: Node) -> bool:
-    """Check if the input is not a float tensor, so that we can skip quantization for the node
-    since observers only works with float Tensors
+    """Check if the output of a node has a data type other than `torch.float32`.
+
+    If the output is not `torch.float32`, quantization cannot be performed, as
+    observers only work with floating-point tensors.
+
+    Args:
+        node (Node): The node to check the output(s) for.
+
+    Returns:
+        bool: `True` if the data type is not float32, otherwise `False`.
+
+    Note:
+        - If `node.meta["val"]` is a `list`, the function returns `True` if **any**
+          element is **not** an instance of `FakeTensor` or does **not** have
+          `torch.float32` as its data type.
+        - If `node.meta["val"]` is missing or is not an instance of `FakeTensor`, the
+          function returns `True`.
     """
+    if "val" in node.meta and isinstance(node.meta["val"], Sequence):
+        return any(
+            not isinstance(fake_tensor, FakeTensor)
+            or fake_tensor.dtype != torch.float32
+            for fake_tensor in node.meta["val"]
+        )
+
     if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
         return True
 
     return node.meta["val"].dtype != torch.float32
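
To illustrate the new behavior, here is a minimal sketch (the traced function and node setup are hypothetical, not part of this PR) covering the cases the docstring describes. Under FakeTensorMode, torch.empty returns FakeTensors:

import torch
import torch.fx
from torch._subclasses.fake_tensor import FakeTensorMode

from executorch.backends.arm.quantizer.arm_quantizer_utils import is_non_float_tensor


def f(x):
    return x + 1


gm = torch.fx.symbolic_trace(f)
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")

with FakeTensorMode():
    # Single float32 output: not a non-float tensor, so quantization may proceed.
    add_node.meta["val"] = torch.empty(4, dtype=torch.float32)
    assert not is_non_float_tensor(add_node)

    # Multi-output node: any non-float32 element makes the check return True.
    add_node.meta["val"] = [
        torch.empty(4, dtype=torch.float32),
        torch.empty(4, dtype=torch.int32),
    ]
    assert is_non_float_tensor(add_node)

# Missing "val" metadata is also treated as non-quantizable (returns True).
del add_node.meta["val"]
assert is_non_float_tensor(add_node)
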
backends/arm/quantizer/quantization_annotator.py (49 additions, 17 deletions)
@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 import operator
 from dataclasses import dataclass
 from typing import Callable, List, Optional
@@ -11,13 +11,16 @@
 import torch.fx
 from executorch.backends.arm.quantizer import arm_quantizer_utils
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
+from executorch.backends.arm.tosa_utils import get_node_debug_info
 from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
 from torch.ao.quantization.quantizer.utils import (
     _annotate_input_qspec_map,
     _annotate_output_qspec,
 )
 from torch.fx import Node
 
+logger = logging.getLogger(__name__)
+
 
 @dataclass(frozen=True)
 class _QuantProperty:
@@ -45,19 +49,52 @@ def _as_list(x):
 
 
 def _is_ok_for_quantization(
-    node: Node, quant_property: _QuantProperty, gm: torch.fx.GraphModule
+    node: Node, quant_properties: _OpQuantProperties, gm: torch.fx.GraphModule
 ) -> bool:
-    if quant_property.optional and (
-        quant_property.index >= len(node.args)
-        or node.args[quant_property.index] is None
-    ):
-        return True
+    """Check if a node can be quantized.
+
+    A node can be quantized if:
+        - All inputs that are required for quantization are of type `float32`
+          and are not large scalar values.
+        - The output of the node itself is of type `float32` and is not a large scalar.
+
+    Args:
+        node (Node): The node being analyzed.
+        quant_properties (_OpQuantProperties): Contains quantization properties for
+            the node, including input and output quantization specifications.
+        gm (torch.fx.GraphModule): The graph module containing the computational graph.
+
+    Returns:
+        bool: `True` if the node can be quantized, otherwise `False`.
+    """
+    # Check output
+    if quant_properties.quant_output is not None:
+        if not arm_quantizer_utils.is_ok_for_quantization(node, gm):  # type: ignore[attr-defined]
+            logger.debug(
+                f"Could not quantize node due to output: "
+                f"{get_node_debug_info(node, gm)}"
+            )
 
-    for n_arg in _as_list(node.args[quant_property.index]):
-        assert isinstance(n_arg, Node)
-        if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm):  # type: ignore[attr-defined]
             return False
 
+    # Check inputs
+    for quant_property in quant_properties.quant_inputs:
+        if quant_property.optional and (
+            quant_property.index >= len(node.args)
+            or node.args[quant_property.index] is None
+        ):
+            continue
+
+        for n_arg in _as_list(node.args[quant_property.index]):
+            assert isinstance(n_arg, Node)
+            if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm):  # type: ignore[attr-defined]
+                logger.debug(
+                    f'could not quantize node due to input "{node}": '
+                    f"{get_node_debug_info(node, gm)}"
+                )
+
+                return False
+
     return True
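
Since the new messages are emitted with logger.debug and the logger is created via logging.getLogger(__name__), they are silent by default. A small snippet (assuming standard stdlib logging configuration, nothing backend-specific) to surface them:

import logging

# Surface the new "could not quantize node ..." diagnostics from the annotator.
logging.basicConfig()
logging.getLogger(
    "executorch.backends.arm.quantizer.quantization_annotator"
).setLevel(logging.DEBUG)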


@@ -355,14 +392,9 @@ def any_or_hardtanh_min_zero(n: Node):
         return quant_properties
 
     # Check that each inputs/outputs can be quantized properly with the
-    # provided QuantProperties
-    for quant_property in quant_properties.quant_inputs:
-        if not _is_ok_for_quantization(node, quant_property, gm):
-            return None  # type: ignore[return-value]
-
-    if quant_properties.quant_output is not None:
-        if not _is_ok_for_quantization(node, quant_properties.quant_output, gm):
-            return None  # type: ignore[return-value]
+    # provided quantization properties.
+    if not _is_ok_for_quantization(node, quant_properties, gm):
+        return None  # type: ignore[return-value]
 
     return quant_properties

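For context, the consolidated check runs during PT2E annotation. A minimal sketch of that flow (the model, TOSA profile string, and exact entry points are illustrative and may differ between ExecuTorch versions):

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    TOSAQuantizer,
    get_symmetric_quantization_config,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Linear(8, 8).eval()
example_inputs = (torch.randn(1, 8),)

quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-0.80+BI"))
quantizer.set_global(get_symmetric_quantization_config())

exported = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)  # annotation, including this PR's checks, happens here
prepared(*example_inputs)                     # calibrate observers
quantized = convert_pt2e(prepared)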