diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index 9cde50b9c70..881d24bbb5e 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -118,3 +118,29 @@ def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig):
                 if "SDPA" in full_qualified_name:
                     annotate_matmul(node, quantization_config_16a8w)
                     annotate_matmul_input1(node.args[1], quantization_config_8a8w)
+
+
+def custom_annotate_matmul_16a8w(gm: torch.fx.GraphModule):
+    """
+    Annotate matmul op with 16a8w quantization config
+    """
+
+    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_spec = quantization_config.input_activation
+        input_qspec_map[input_act] = input_spec
+        input_act1 = node.args[1]
+        input_spec1 = quantization_config.weight
+        input_qspec_map[input_act1] = input_spec1
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    # Annotate 16a8w for matmul op to get better performance
+    quantization_config_16a8w = get_16a8w_qnn_ptq_config()
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
+            annotate_matmul(node, quantization_config_16a8w)
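
A minimal usage sketch of how the new hook might be wired into the PT2E flow, not part of this patch. It assumes `QnnQuantizer.add_custom_quant_annotations` as the registration point and the standard `prepare_pt2e`/`convert_pt2e` flow; `model` and `example_inputs` are placeholders, and the export entry point varies by torch version:

```python
import torch
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    custom_annotate_matmul_16a8w,
)
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

quantizer = QnnQuantizer()
# Register the new hook so every aten.matmul node gets the 16a8w annotation
# on top of the quantizer's default annotations.
quantizer.add_custom_quant_annotations((custom_annotate_matmul_16a8w,))

# Export entry point depends on the torch release (older versions used
# capture_pre_autograd_graph); `model` and `example_inputs` are placeholders.
graph_module = torch.export.export_for_training(model, example_inputs).module()

prepared = prepare_pt2e(graph_module, quantizer)
prepared(*example_inputs)  # calibration run to collect observer statistics
quantized = convert_pt2e(prepared)
```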