Commit c3160d2

qnn end to end flow
Pull Request resolved: #3038

Patch a few changes, including:

- support the bool tensor type
- support fp16 and fix the 8w8a quantization
- add two unsupported ops (slice_scatter and index_put) to common_defs.py

The stories model works end to end.

AOT, fp16:

```
python -m examples.models.llama2.export_llama -kv --qnn -c stories110M.pt -p params.json
```

AOT, quantized:

```
python -m examples.models.llama2.export_llama -kv --qnn --pt2e_quantize -c stories110M.pt -p params.json
```

Runtime:

```
/llama_main --model_path=llama2_fp16_qnn_2.21.pte --tokenizer_path=tokenizer.bin --prompt="Once"
```

Output:

```
Once upon a time, there was a boy named Tim. Tim had a pet dog named Max. Max was a big, strong dog. They liked to play and run in the park. One day, Tim and Max went to the park to play. They saw a cat. The cat was up in a tree. Max wanted to help the cat. He tried to climb the tree, but he could not. Then, something unexpected happened. Max started to climb the tree! He was very strong. Max helped the cat come down. The cat was happy. Tim was so proud of his pet.
```

The stories model is too small and sensitive to quantization.

ghstack-source-id: 222466043
@exported-using-ghexport

Differential Revision: [D56119738](https://our.internmc.facebook.com/intern/diff/D56119738/)
1 parent: 682f291

3 files changed (+20 −6 lines)
backends/qualcomm/builders/node_visitor.py (+1)

```diff
@@ -29,6 +29,7 @@
     QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
 }
 QNN_TENSOR_TYPE_MAP = {
+    torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
     torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
     torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8,
     torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16,
```

backends/qualcomm/partition/common_defs.py (+2)

```diff
@@ -13,6 +13,8 @@
     exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.index.Tensor,
     exir_ops.edge.aten.full.default,
+    exir_ops.edge.aten.slice_scatter.default,
+    exir_ops.edge.aten.index_put.default,
 ]

 allow_list_operator = [
```
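These two ops go into the not-supported list, which the partitioner consults so that matching nodes stay on CPU instead of being delegated. A rough sketch of that pattern under assumed names (`is_node_supported` follows the generic torch.fx partitioner convention; the real QnnPartitioner plumbing differs):

```python
import torch
import torch.fx as fx

# Denylist in the spirit of common_defs.py; plain strings stand in for the
# exir_ops edge op targets so the example is self-contained.
not_supported_operator = {
    "aten.slice_scatter.default",
    "aten.index_put.default",
}

def is_node_supported(node: fx.Node) -> bool:
    # Nodes matching the denylist are not claimed by the backend,
    # so they fall back to CPU instead of failing inside the delegate.
    return str(node.target) not in not_supported_operator

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

gm = fx.symbolic_trace(Toy())
print([n.name for n in gm.graph.nodes if is_node_supported(n)])
```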

examples/models/llama2/export_llama_lib.py (+17 −6)

```diff
@@ -19,6 +19,7 @@

 import pkg_resources
 import torch
+import torch.nn.functional as F
 from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
     XnnpackDynamicallyQuantizedPartitioner,
```
```diff
@@ -34,7 +35,6 @@
 from executorch.sdk.etrecord import generate_etrecord
 from executorch.util.activation_memory_profiler import generate_memory_trace
 from sentencepiece import SentencePieceProcessor
-from torch.nn import functional as F

 from .builder import DType, LlamaEdgeManager, load_llama_model, WeightType
 from .quant_lib import _get_pt2e_quantization_params, get_pt2e_quantizers
```
```diff
@@ -607,6 +607,8 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     if args.use_sdpa_with_kv_cache:
         transforms.append(replace_sdpa_with_custom_op)

+    if args.qnn and args.use_kv_cache:
+        transforms.append(replace_sdpa_with_simple_sdpa)
     return (
         load_llama_model(
             checkpoint=checkpoint_path,
```
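This new transform swaps the fused SDPA custom op for a decomposed attention when exporting for QNN with a KV cache, since the fused op has no QNN lowering. A minimal sketch of what such a "simple" SDPA computes, written only with primitive ops a backend can typically lower (shapes and the function name here are assumptions; the real `replace_sdpa_with_simple_sdpa` transform lives in the llama2 example sources):

```python
import math
import torch
import torch.nn.functional as F

def simple_sdpa(q, k, v, mask=None):
    # Scaled dot-product attention decomposed into matmul + softmax.
    # q, k, v: (batch, heads, seq_len, head_dim).
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores + mask  # additive mask (-inf on disallowed positions)
    return F.softmax(scores, dim=-1) @ v

q = k = v = torch.randn(1, 8, 16, 64)
assert simple_sdpa(q, k, v).shape == (1, 8, 16, 64)
```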
```diff
@@ -629,7 +631,7 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
     # export_to_edge
     pt2e_quant_params = _get_pt2e_quantization_params(args)
     quantizers = get_pt2e_quantizers(pt2e_quant_params, args)
-    if args.qnn:
+    if args.qnn and args.pt2e_quantize:
         assert (
             args.quantization_mode is None
         ), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
```
```diff
@@ -647,7 +649,9 @@ def _export_llama(modelname, args) -> str:  # noqa: C901

         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
         qnn_quantizer = QnnQuantizer()
-        # more custom quantization are supported including 16a4w etc. default to 8bit quantized
+        logging.info(
+            "More custom quantization are supported including 16a4w etc. default to 8bit quantized"
+        )
         custom_annotations = ()
         qnn_quantizer.add_custom_quant_annotations(custom_annotations)
         quantizers.append(qnn_quantizer)
```
```diff
@@ -763,14 +767,21 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
         )

         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
-        backend_options = generate_htp_compiler_spec(use_fp16=False)
+        use_fp16 = False if args.pt2e_quantize else True
+        if use_fp16:
+            logging.info("Using fp16 for QNN backend, expect performance degradation")
+        backend_options = generate_htp_compiler_spec(use_fp16=use_fp16)
+        soc_model = QcomChipset.SM8650
+        logging.info(
+            f"Default to soc {soc_model}, other available options can be found in {QcomChipset}"
+        )
         partitioners.append(
             # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
             QnnPartitioner(
                 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
                 generate_qnn_executorch_compiler_spec(
                     # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
-                    soc_model=QcomChipset.SM8650,  # default to SM8650
+                    soc_model=soc_model,  # default to SM8650
                     backend_options=backend_options,
                     debug=False,
                     saver=False,
```
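Net effect of this hunk: precision now follows the CLI, fp16 when `--pt2e_quantize` is absent and quantized 8w8a when it is passed, matching the two AOT commands in the commit message. A small sketch of that flag-to-precision mapping (the argparse setup here is a stand-in for export_llama's real parser, kept to the two relevant flags):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--qnn", action="store_true")
parser.add_argument("--pt2e_quantize", action="store_true")

for argv in (["--qnn"], ["--qnn", "--pt2e_quantize"]):
    args = parser.parse_args(argv)
    use_fp16 = not args.pt2e_quantize  # same result as the ternary in the diff
    print(argv, "-> use_fp16 =", use_fp16)
```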
```diff
@@ -780,7 +791,7 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
             )
         )
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
-        _transform(builder_exported_to_edge.export_program())
+        _transform(builder_exported_to_edge.edge_manager.exported_program())

     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
```
