diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 3dae32f882e..3f40dc56737 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -29,6 +29,7 @@ QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16, } QNN_TENSOR_TYPE_MAP = { + torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8, torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16, diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index b06a5766a63..36a2986f09a 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -13,6 +13,8 @@ exir_ops.edge.aten.clone.default, exir_ops.edge.aten.index.Tensor, exir_ops.edge.aten.full.default, + exir_ops.edge.aten.slice_scatter.default, + exir_ops.edge.aten.index_put.default, ] allow_list_operator = [ diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 18a7950d58b..bc9f1cef513 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -355,6 +355,13 @@ def build_args_parser() -> argparse.ArgumentParser: parser.add_argument( "--pt2e_quantize", default=None, + choices=[ + "xnnpack_dynamic", + "xnnpack_dynamic_qc4", + "qnn_8a8w", + "qnn_16a16w", + "qnn_16a4w", + ], help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.", ) parser.add_argument( @@ -615,6 +622,9 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager: if args.use_sdpa_with_kv_cache: transforms.append(replace_sdpa_with_custom_op) + if args.qnn and args.use_kv_cache: + transforms.append(replace_sdpa_with_simple_sdpa) + transforms.append(replace_causal_mask) return ( load_llama_model( checkpoint=checkpoint_path, @@ -636,13 +646,16 @@ def _export_llama(modelname, args) -> str: # noqa: C901 # export_to_edge pt2e_quant_params = _get_pt2e_quantization_params(args) quantizers = get_pt2e_quantizers(pt2e_quant_params, args) - if args.qnn: - assert ( - args.quantization_mode is None - ), "Currently qnn backend only supports QnnQuantizer via pt2e flow" + quant_dtype = None + if args.qnn and args.pt2e_quantize: try: # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer` - from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer + from executorch.backends.qualcomm.quantizer.quantizer import ( + get_16a4w_qnn_ptq_config, + get_default_16bit_qnn_ptq_config, + QnnQuantizer, + QuantDtype, + ) # reset quantizers and pt2e_quant_params from xnnpack backend pt2e_quant_params = None @@ -652,10 +665,41 @@ def _export_llama(modelname, args) -> str: # noqa: C901 "Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html" ) + backend, quant_config = args.pt2e_quantize.split("_") + assert ( + backend == "qnn" + ), f"The quantization config is for backend {backend} instead of qnn." # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. qnn_quantizer = QnnQuantizer() # more custom quantization are supported including 16a4w etc. default to 8bit quantized custom_annotations = () + if quant_config == "8a8w": + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + quant_dtype = QuantDtype.use_8a8w + pass + elif quant_config == "16a16w": + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + quant_dtype = QuantDtype.use_16a16w + qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + qnn_quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config()) + elif quant_config == "16a4w": + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + quant_dtype = QuantDtype.use_16a4w + qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + qnn_quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config()) + qnn_quantizer.set_per_channel_weight_dtype( + weight_dtype_for_16bit_act="int4" + ) + else: + raise AssertionError( + f"No support for quant type {quant_config}. Support 8a8w, 16a16w and 16a4w." + ) + + assert ( + args.quantization_mode is None + ), "Currently qnn backend only supports QnnQuantizer via pt2e flow" qnn_quantizer.add_custom_quant_annotations(custom_annotations) quantizers.append(qnn_quantizer) @@ -769,8 +813,20 @@ def _export_llama(modelname, args) -> str: # noqa: C901 "Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html" ) - # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm` - backend_options = generate_htp_compiler_spec(use_fp16=False) + use_fp16 = True + skip_node_op_set = {} + if args.pt2e_quantize: + use_fp16 = False + # TODO: fix the lowering error without skipping nodes + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + if quant_dtype == QuantDtype.use_8a8w: + raise NotImplementedError("8a8w for llama is still under development") + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + elif quant_dtype == QuantDtype.use_16a16w: + raise NotImplementedError("16a16w for llama is still under development") + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + elif quant_dtype == QuantDtype.use_16a4w: + raise NotImplementedError("16a4w for llama is still under development") partitioners.append( # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm` QnnPartitioner( @@ -778,16 +834,17 @@ def _export_llama(modelname, args) -> str: # noqa: C901 generate_qnn_executorch_compiler_spec( # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. soc_model=QcomChipset.SM8650, # default to SM8650 - backend_options=backend_options, + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + backend_options=generate_htp_compiler_spec(use_fp16=use_fp16), debug=False, saver=False, ), skip_node_id_set={}, - skip_node_op_set={}, + skip_node_op_set=skip_node_op_set, ) ) # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm` - _transform(builder_exported_to_edge.export_program()) + _transform(builder_exported_to_edge.edge_manager.exported_program()) if args.generate_etrecord: if not builder_exported_to_edge.edge_manager: