diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py
index 809b7298eba..6a2c8e205b4 100644
--- a/backends/qualcomm/quantizer/utils.py
+++ b/backends/qualcomm/quantizer/utils.py
@@ -9,6 +9,7 @@
 import torch
 from torch._ops import OpOverload
+from torch._subclasses import FakeTensor

 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -41,6 +42,18 @@ def decorator(annotator: Callable):
     return decorator


+def _is_input_float_tensor(node: Node) -> bool:
+    """Check whether the input is a float tensor. If it is not, we skip
+    quantization for the node, since observers only work with float Tensors.
+    """
+    if (
+        not isinstance(node, Node)
+        or "val" not in node.meta
+        or not isinstance(node.meta["val"], FakeTensor)
+    ):
+        return False
+    return node.meta["val"].dtype == torch.float32
+

 def _is_annotated(nodes: List[Node]):
     """
@@ -123,11 +136,11 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None
     input_qspec_map = {}
     input_act0 = node.args[0]
-    if isinstance(input_act0, Node):
+    if _is_input_float_tensor(input_act0):
         input_qspec_map[input_act0] = input_act_qspec

     input_act1 = node.args[1]
-    if isinstance(input_act1, Node):
+    if _is_input_float_tensor(input_act1):
         input_qspec_map[input_act1] = input_act_qspec

     node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
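
For reviewers, a minimal sketch (not part of the diff) of the behavior the new helper is expected to have. The import path executorch.backends.qualcomm.quantizer.utils and the FakeTensorMode/graph construction here are assumptions made purely for illustration, not part of the patch:

    import torch
    from torch._subclasses import FakeTensorMode

    # Assumption: the patched module is importable under this path.
    from executorch.backends.qualcomm.quantizer.utils import _is_input_float_tensor

    # Factory calls inside FakeTensorMode produce FakeTensors, which is what
    # the exported graph stores in node.meta["val"].
    with FakeTensorMode():
        float_val = torch.empty(2, 2, dtype=torch.float32)  # FakeTensor, float32
        int_val = torch.empty(2, 2, dtype=torch.int64)      # FakeTensor, int64

    # Hand-built placeholder nodes standing in for real graph inputs.
    graph = torch.fx.Graph()
    float_node = graph.placeholder("x")
    float_node.meta["val"] = float_val
    int_node = graph.placeholder("idx")
    int_node.meta["val"] = int_val

    assert _is_input_float_tensor(float_node)   # float32: gets an input qspec
    assert not _is_input_float_tensor(int_node) # non-float: skipped, observers
                                                # cannot observe integer tensors
    assert not _is_input_float_tensor(1.0)      # scalar arg, not a Node: skipped

This is why annotate_binary switches from isinstance(input_act, Node) to the new check: a binary op can take an integer tensor or a Python scalar as one operand, and attaching an observer to such an input would fail at calibration time.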