Commit 357e94e

Update on "qnn end to end flow"
Patch a few changes, including:
- support bool tensor type
- support fp16 and fix the 8w8a quantization
- add two non-supported ops (slice_scatter and index_put) in common_defs.py

The stories model works end to end.

AOT, fp16:
```
python -m examples.models.llama2.export_llama -kv --qnn -c stories110M.pt -p params.json
```
Quantized:
```
python -m examples.models.llama2.export_llama -kv --qnn --pt2e_quantize -c stories110M.pt -p params.json
```
Runtime:
```
/llama_main --model_path=llama2_fp16_qnn_2.21.pte --tokenizer_path=tokenizer.bin --prompt="Once"
```
Output:
```
Once upon a time, there was a boy named Tim. Tim had a pet dog named Max. Max was a big, strong dog. They liked to play and run in the park. One day, Tim and Max went to the park to play. They saw a cat. The cat was up in a tree. Max wanted to help the cat. He tried to climb the tree, but he could not. Then, something unexpected happened. Max started to climb the tree! He was very strong. Max helped the cat come down. The cat was happy. Tim was so proud of his pet.
```
The stories model is too small and sensitive to quantization.

Differential Revision: [D56119738](https://our.internmc.facebook.com/intern/diff/D56119738/)

[ghstack-poisoned]
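For the common_defs.py change mentioned in the message (that file is not part of this commit's diff), a minimal sketch of what marking the two ops as unsupported might look like, assuming the QNN partitioner keeps a `not_supported_operator` list; the list name and import path are assumptions:

```python
# Hypothetical sketch of the common_defs.py change described above;
# the file is not in this diff, and the list name is assumed.
from executorch.exir.dialects._ops import ops as exir_ops

not_supported_operator = [
    exir_ops.edge.aten.slice_scatter.default,
    exir_ops.edge.aten.index_put.default,
]
```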
2 parents e46b6b6 + 96586a7 commit 357e94e

2 files changed: +5 −2 lines
examples/models/llama2/export_llama_lib.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -621,7 +621,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
             bitwidth = int(bitwidth)
             transforms.append(
                 lambda model: EmbeddingQuantHandler(
-                    model, bitwidth=bitwidth, group_size=group_size
+                    model,
+                    bitwidth=bitwidth,
+                    group_size=group_size,
+                    packed=(bitwidth == 4),
                 ).quantized_model()
             )
```
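The new `packed=(bitwidth == 4)` argument enables byte-packing only where it pays off: two 4-bit values fit in one byte, whereas 8-bit values already fill a byte. A minimal standalone sketch of the packing idea, not EmbeddingQuantHandler's actual implementation:

```python
import torch

def pack_int4(qweight: torch.Tensor) -> torch.Tensor:
    """Pack pairs of 4-bit values (stored as uint8 in [0, 15]) into bytes.

    Illustrative only; the real EmbeddingQuantHandler packing may differ.
    """
    assert qweight.dtype == torch.uint8 and qweight.shape[-1] % 2 == 0
    low = qweight[..., 0::2]   # even columns -> low nibble
    high = qweight[..., 1::2]  # odd columns  -> high nibble
    return low | (high << 4)   # half as many bytes as input values

# Example: a 2x8 table of 4-bit values packs into 2x4 bytes.
q = torch.randint(0, 16, (2, 8), dtype=torch.uint8)
assert pack_int4(q).shape == (2, 4)
```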

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -189,7 +189,7 @@ def embedding_byte_dtype_out_meta(

 quantized_decomposed_lib.define(
     "embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
-    "int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)",
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
 )
```
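The schema fix moves `dtype` behind the `*` marker so that, like `out`, it becomes keyword-only, matching the convention that out-variant ops take their trailing optionals by keyword. In plain-Python terms (an analogy for the schema, not the actual operator binding):

```python
# Before: dtype was positional-or-keyword; only out was keyword-only.
def dtype_out_before(weight, scales, zero_points, qmin, qmax, indices,
                     dtype=None, *, out):
    ...

# After: both dtype and out are keyword-only; callers must spell out
# op(..., dtype=..., out=...) and can no longer pass dtype positionally.
def dtype_out_after(weight, scales, zero_points, qmin, qmax, indices,
                    *, dtype=None, out):
    ...
```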
