diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index cf5d1b6e6f..d3148c9542 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -35,6 +35,7 @@
 )
 
 from executorch.extension.llm.export.quantizer_lib import (
+    get_coreml_quantizer,
     get_pt2e_quantization_params,
     get_pt2e_quantizers,
     get_qnn_quantizer,
@@ -128,6 +129,11 @@ def build_args_parser() -> argparse.ArgumentParser:
             "qnn_8a8w",
             "qnn_16a16w",
             "qnn_16a4w",
+            "coreml_c4w",
+            "coreml_8a_c8w",
+            "coreml_8a_c4w",
+            "coreml_baseline_8a_c8w",
+            "coreml_baseline_8a_c4w",
         ],
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )
@@ -416,6 +422,10 @@ def get_quantizer_and_quant_params(args):
             args.pt2e_quantize, args.quantization_mode
         )
         quantizers.append(qnn_quantizer)
+    if args.coreml and args.pt2e_quantize:
+        assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
+        coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize)
+        quantizers.append(coreml_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
 
@@ -469,7 +479,10 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
         modelname = f"mps_{modelname}"
 
     if args.coreml:
-        partitioners.append(get_coreml_partitioner(args.use_kv_cache))
+        coreml_partitioner = get_coreml_partitioner(
+            args.use_kv_cache, args.pt2e_quantize
+        )
+        partitioners.append(coreml_partitioner)
         modelname = f"coreml_{modelname}"
 
     if args.qnn:
diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py
index e8f32e4b47..bcbeeeee15 100644
--- a/extension/llm/export/partitioner_lib.py
+++ b/extension/llm/export/partitioner_lib.py
@@ -55,7 +55,9 @@ def get_mps_partitioner(use_kv_cache: bool = False):
     return MPSPartitioner(compile_specs)
 
 
-def get_coreml_partitioner(use_kv_cache: bool = False):
+def get_coreml_partitioner(
+    use_kv_cache: bool = False, pt2e_quantize: Optional[str] = None
+):
     assert (
         use_kv_cache is True
     ), "CoreML backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment"
@@ -72,7 +74,26 @@ def get_coreml_partitioner(use_kv_cache: bool = False):
             "Please install the CoreML backend follwing https://pytorch.org/executorch/main/build-run-coreml.html"
         )
 
+    minimum_deployment_target = ct.target.iOS15
+    # In Core ML, quantization is introduced in iOS 16
+    if pt2e_quantize is not None:
+        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16)
+    # In Core ML, 8-bit activation quantization is introduced in iOS 17
+    if pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
+        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS17)
+    # In Core ML, 4-bit weight compression is introduced in iOS 18
+    if pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w"):
+        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
+    # In Core ML, stateful execution is introduced in iOS 18
+    # TODO (https://github.com/pytorch/executorch/issues/4209)
+    # For now, since mutable buffer is kept in executorch runtime,
+    # state is out of place and can be handled by older iOS.
+    # Once mutable buffer can be handed over to delegate, i.e. state becomes in-place, we will have
+    # if use_kv_cache:
+    #     minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
+
     compile_specs = CoreMLBackend.generate_compile_specs(
+        minimum_deployment_target=minimum_deployment_target,
         compute_precision=ct.precision(ct.precision.FLOAT16.value),
         # using `ComputeUnit.ALL` can increase the model load time, default to `ComputeUnit.CPU_AND_GPU`
         compute_unit=ct.ComputeUnit[ct.ComputeUnit.CPU_AND_GPU.name.upper()],
diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py
index 441f673302..fe6ad1c201 100644
--- a/extension/llm/export/quantizer_lib.py
+++ b/extension/llm/export/quantizer_lib.py
@@ -193,3 +193,52 @@ def get_qnn_quantizer(
     ), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
     qnn_quantizer.add_custom_quant_annotations(custom_annotations)
     return qnn_quantizer, quant_dtype
+
+
+def get_coreml_quantizer(pt2e_quantize: str):
+    try:
+        from coremltools.optimize.torch.quantization.quantization_config import (
+            LinearQuantizerConfig,
+            QuantizationScheme,
+        )
+
+        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.apple.coreml.quantizer`.
+        from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
+    except ImportError:
+        raise ImportError(
+            "Please install the CoreML backend following https://pytorch.org/executorch/main/build-run-coreml.html"
+        )
+
+    if pt2e_quantize == "coreml_8a_c8w":
+        config = LinearQuantizerConfig.from_dict(
+            {
+                "global_config": {
+                    "quantization_scheme": QuantizationScheme.affine,
+                    "activation_dtype": torch.quint8,
+                    "weight_dtype": torch.qint8,
+                    "weight_per_channel": True,
+                }
+            }
+        )
+        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `apple`.
+        quantizer = CoreMLQuantizer(config)
+
+    elif pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w"):
+        raise NotImplementedError("4-bit Core ML quantizer is still under development")
+
+    elif pt2e_quantize == "coreml_baseline_8a_c8w":
+        config = get_symmetric_quantization_config(
+            is_per_channel=True, is_dynamic=False
+        )
+        quantizer = XNNPACKQuantizer().set_global(config)
+
+    elif pt2e_quantize == "coreml_baseline_8a_c4w":
+        config = get_symmetric_quantization_config(
+            is_per_channel=True, is_dynamic=False, weight_qmin=-8, weight_qmax=7
+        )
+        quantizer = XNNPACKQuantizer().set_global(config)
+
+    else:
+        raise ValueError(f"Unsupported Core ML quantizer specification {pt2e_quantize}")
+
+    return quantizer
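
For reference, a minimal usage sketch of the pieces added above (not part of the patch): the function names, signatures, and the quantizer choice string come from the diff; the module path for get_coreml_partitioner is inferred from the file it lives in, and the argparse/CLI wiring is omitted. It assumes coremltools and the ExecuTorch Core ML backend are installed.

    from executorch.extension.llm.export.partitioner_lib import get_coreml_partitioner
    from executorch.extension.llm.export.quantizer_lib import get_coreml_quantizer

    # One of the new pt2e_quantize choices registered in build_args_parser above.
    pt2e_quantize = "coreml_8a_c8w"

    # PT2E quantizer for the Core ML backend; the 4-bit variants still raise NotImplementedError.
    quantizer = get_coreml_quantizer(pt2e_quantize)

    # The partitioner derives the minimum deployment target from the quantization choice
    # (iOS17 here, since 8-bit activations are requested); use_kv_cache must be True.
    partitioner = get_coreml_partitioner(use_kv_cache=True, pt2e_quantize=pt2e_quantize)

In the export flow, the quantizer is then appended to quantizers in get_quantizer_and_quant_params and the partitioner to partitioners in _export_llama, as shown in the first diff hunk set.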