
Model quantization test exception when using LiteRT for MTK mobile phone deployment #53

@Bradywuli


Dear authors,
I tried using the examples in litert-samples/v2/ to convert and compile a model for MTK mobile deployment. I found that the LiteRT_AOT_Compilation_Tutorial.ipynb notebook in v2/colab does not quantize the model: the weights stay float32, yet the model can still be deployed on device with an inference time of roughly 12 ms. According to the page, that speed should mean it is running on the NPU:

    2. kotlin_npu
    An advanced implementation with Neural Processing Unit (NPU) support for significantly faster inference on compatible devices.
    Features:
    - CPU, GPU, and NPU delegate support
    - Optimized for Qualcomm and MediaTek NPU hardware
    - Same 21-class segmentation as the CPU/GPU version
    - Requires enrollment in the Early Access Program
    Performance on Samsung S25 Ultra:
    - CPU: 120-140 ms per frame
    - GPU: 40-50 ms per frame
    - NPU: 6-12 ms per frame (10-20x faster than CPU!)

Does the NPU support float32 operations? That is not the main issue, though. Below are the two quantization strategies I have tried for the model.

  1. First, when I calibrate the .tflite model in the AI pack directly with ai_edge_quantizer, I get the error "ValueError: Invalid signature_key provided: 'serving_default'". This likely means the .tflite model has no signature key, while ai_edge_quantizer's calibration requires one. How can I resolve this? Second, in LiteRT_AOT_Compilation_Tutorial.ipynb I added the following lines:

from ai_edge_torch.generative.quantize import quant_recipes

quant_config = quant_recipes.full_int8_weight_only_recipe()
compiled_models = ai_edge_torch.experimental_add_compilation_backend().convert(
    channel_last_selfie_segmentation.eval(), sample_input, quant_config=quant_config)

After this change, only the Qualcomm result is displayed: the Qualcomm TFLite model is quantized successfully, but no MTK model is generated at all. What is the reason for this? Do different backends require different quantization flags? (The other quantization settings in quant_recipes also failed.)
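For reference on what a weight-only int8 recipe computes: each weight tensor is stored as int8 values plus a float scale, while activations stay float. Below is a minimal pure-Python sketch of the symmetric quantization arithmetic involved — illustrative only, not the ai_edge_torch implementation, and real recipes typically use per-channel rather than per-tensor scales:

```python
def quantize_int8_symmetric(weights):
    # Symmetric per-tensor int8: the scale maps the largest magnitude to 127.
    scale = max(abs(w) for w in weights) / 127.0
    q = [max(-128, min(127, round(w / scale))) for w in weights]
    return q, scale

def dequantize_int8(q, scale):
    # Reconstruction used at inference time: w ≈ q * scale.
    return [v * scale for v in q]

w = [0.5, -1.27, 0.02, 1.0]
q, scale = quantize_int8_symmetric(w)
print(q)      # quantized int8 values: [50, -127, 2, 100]
print(scale)  # ≈ 0.01
```

A quantized .tflite model should therefore show int8 weight tensors with attached scale metadata; float32 weight tensors mean the recipe was not applied.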

  2. In addition, I used PT2EQuantizer for quantization. The code is as follows:

from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
from ai_edge_torch.quantize.quant_config import QuantConfig

pt2e_quantizer = PT2EQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False))
channel_last_selfie_segmentation = ai_edge_torch.to_channel_last_io(
    selfie_segmentation, args=[0], outputs=[0])
sample_input = (torch.randn(1, 256, 256, 3),)
compiled_models = ai_edge_torch.experimental_add_compilation_backend().convert(
    channel_last_selfie_segmentation.eval(), sample_input,
    quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer))

This time both the Qualcomm and MTK models were generated, but neither was quantized! What could be the reason?

In addition, I also followed the Quantization section in docs/pytorch_converter/README.md:

from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch._export import capture_pre_autograd_graph

from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
from ai_edge_torch.quantize.quant_config import QuantConfig

pt2e_quantizer = PT2EQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)

pt2e_torch_model = capture_pre_autograd_graph(torch_model, sample_args)
pt2e_torch_model = prepare_pt2e(pt2e_torch_model, pt2e_quantizer)

# Run the prepared model with sample input data to make sure the internal
# observers are populated with correct values
pt2e_torch_model(*sample_args)

# Convert the prepared model to a quantized model
pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)

# Convert to an ai_edge_torch model
pt2e_drq_model = ai_edge_torch.convert(
    pt2e_torch_model, sample_args,
    quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer))
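For context, the is_dynamic=True setting above requests dynamic-range quantization: weights are quantized ahead of time, while each activation tensor gets its scale computed from the live data at inference time. A minimal pure-Python sketch of that runtime step — illustrative only, not ai_edge_torch or PyTorch code:

```python
def dynamic_quantize_activation(x):
    # Dynamic-range step: derive the int8 scale from the tensor that is
    # actually flowing through the op. The scale is recomputed on every call.
    scale = max(abs(v) for v in x) / 127.0
    if scale == 0.0:
        scale = 1.0  # all-zero tensor: any scale reproduces it exactly
    q = [max(-128, min(127, round(v / scale))) for v in x]
    return q, scale

# Each inference derives its own scale from the data it sees:
q1, s1 = dynamic_quantize_activation([1.27, -0.254, 0.0])
q2, s2 = dynamic_quantize_activation([2.54, -0.254, 0.0])
print(q1, s1)  # [127, -25, 0], scale ≈ 0.01
print(q2, s2)  # [127, -13, 0], scale ≈ 0.02
```

Because the activation scales are computed at runtime, a dynamically quantized model still has float inputs and outputs; only the weight tensors look quantized when the .tflite file is inspected.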

However, I found that this snippet no longer works in my PyTorch version: capture_pre_autograd_graph cannot be found in torch._export (it appears to have been removed), so the imports fail before the quantization steps can run.

So, if my settings are incorrect, please correct me and provide the correct, detailed code for this quantization step. Thank you very much!
