
Commit 736b813

Update on "Docs for lower smaller models to mps/coreml/qnn"

Differential Revision: [D56340028](https://our.internmc.facebook.com/intern/diff/D56340028/) [ghstack-poisoned]

2 parents 691bc59 + 4c3f49a

File tree: 10 files changed (+249, -24 lines)

backends/qualcomm/builders/node_visitor.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
     QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
 }
 QNN_TENSOR_TYPE_MAP = {
+    torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
     torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
     torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8,
     torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16,

backends/qualcomm/partition/common_defs.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@
     exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.index.Tensor,
     exir_ops.edge.aten.full.default,
+    exir_ops.edge.aten.slice_scatter.default,
+    exir_ops.edge.aten.index_put.default,
 ]
 
 allow_list_operator = [
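
The two ops added to the deny list here typically appear when a KV cache is written in place during export, so listing them keeps the QNN partitioner from claiming nodes it cannot yet lower. A minimal sketch of where `aten.index_put` comes from (shapes and names are hypothetical, not the actual llama KV-cache code):

```python
# Hypothetical sketch: in-place KV-cache writes like this would typically lower
# to aten.index_put in the exported graph, which is why the op is kept on the
# not_supported_operator list for the QNN partition for now.
import torch

cache = torch.zeros(1, 4, 32, 16)      # (bsz, n_heads, max_seq_len, head_dim)
input_pos = torch.tensor([5])          # position being written this step
new_k = torch.randn(1, 4, 1, 16)       # freshly computed key for that position

cache[:, :, input_pos] = new_k         # advanced-index assignment -> aten.index_put
```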

examples/models/llama2/README.md

Lines changed: 22 additions & 5 deletions
@@ -24,11 +24,12 @@ For Llama3, we can use the same process. Note that it's only supported in the Ex
 ## Quantization:
 We employed 4-bit groupwise per token dynamic quantization of all the linear layers of the model. Dynamic quantization refers to quantizating activations dynamically, such that quantization parameters for activations are calculated, from min/max range, at runtime. Here we quantized activations with 8bits (signed integer). Furthermore, weights are statically quantized. In our case weights were per-channel groupwise quantized with 4bit signed integer. For more information refer to this [page](https://github.com/pytorch-labs/ao/).
 
-We evaluated WikiText perplexity using [LM Eval](https://github.com/EleutherAI/lm-evaluation-harness). Below are the results for two different groupsizes.
+We evaluated WikiText perplexity using [LM Eval](https://github.com/EleutherAI/lm-evaluation-harness). Below are the results for two different groupsizes, with max_seq_len 2048, and 1000 samples.
 
-|Llama 2 | Baseline (FP32) | Groupwise 4-bit (128) | Groupwise 4-bit (256)
+|Model | Baseline (FP32) | Groupwise 4-bit (128) | Groupwise 4-bit (256)
 |--------|-----------------| ---------------------- | ---------------
-|Wikitext Perplexity | 9.16 | 10.2 | 10.7
+|Llama 2 7B | 9.2 | 10.2 | 10.7
+|Llama 3 8B | 7.9 | 9.4 | 9.7
 
 Note that groupsize less than 128 was not enabled, since such model were still too large. This is because our current efforts have focused on enabling FP32 and support for FP16 is under way. What this implies for model size is that 1) embedding table is in FP32 and 2) quantized weights scales are FP32.
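
For intuition, here is a minimal, self-contained sketch of the groupwise weight scheme described above: symmetric int4 weights with one scale per output-channel group. The group size, shapes, and helper names are illustrative only; the real flow uses the torchao/ExecuTorch quantizers referenced in the README.

```python
# Minimal sketch of groupwise symmetric int4 weight quantization (illustrative
# only; not the production quantization path).
import torch

def quantize_groupwise_int4(w: torch.Tensor, group_size: int):
    out_ch, in_ch = w.shape
    wg = w.reshape(out_ch, in_ch // group_size, group_size)
    # One scale per (output channel, group), derived from the max magnitude.
    scales = wg.abs().amax(dim=-1, keepdim=True) / 7.0  # int4 range is [-8, 7]
    q = torch.clamp(torch.round(wg / scales), -8, 7).to(torch.int8)
    return q, scales

def dequantize_groupwise_int4(q: torch.Tensor, scales: torch.Tensor):
    return (q.float() * scales).reshape(q.shape[0], -1)

w = torch.randn(4, 16)
q, scales = quantize_groupwise_int4(w, group_size=8)
w_hat = dequantize_groupwise_int4(q, scales)
print((w - w_hat).abs().max())  # small reconstruction error
```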

@@ -56,7 +57,7 @@ Performance was measured on Samsung Galaxy S22, S24, One Plus 12 and iPhone 15 m
 - For Llama7b, your device may require at least 32GB RAM. If this is a constraint for you, please try the smaller stories model.
 
 ## Step 1: Setup
-1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch
+1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch. For installation run `./install_requirements.sh --pybind xnnpack`
 2. Run `examples/models/llama2/install_requirements.sh` to install a few dependencies.
 
 ## Step 2: Prepare model
@@ -102,6 +103,16 @@ If you want to deploy and run a smaller model for educational purposes. From `ex
 python -m examples.models.llama2.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin
 ```
 
+### Option C: Download and export Llama3 8B model
+
+You can export and run the original Llama3 8B model.
+
+1. Llama3 pretrained parameters can be downloaded from [Meta's official llama3 repository](https://github.com/meta-llama/llama3/).
+
+2. Export model and generate `.pte` file
+```
+python -m examples.models.llama2.export_llama --checkpoint <consolidated.00.pth> -p <params.json> -d=fp32 -X -qmode 8da4w -kv --use_sdpa_with_kv_cache --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte" --group_size 128 --metadata '{"get_bos_id":128000, "get_eos_id":128001}' --embedding-quantize 4,32
+```
 
 ## (Optional) Finetuning
@@ -147,6 +158,7 @@ The Wikitext results generated above used: `{max_seq_len: 2048, limit: 1000}`
     -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
     -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
     -DEXECUTORCH_BUILD_XNNPACK=ON \
+    -DEXECUTORCH_BUILD_QUANTIZED=ON \
     -DEXECUTORCH_BUILD_OPTIMIZED=ON \
     -DEXECUTORCH_BUILD_CUSTOM=ON \
     -Bcmake-out .
@@ -162,17 +174,22 @@ The Wikitext results generated above used: `{max_seq_len: 2048, limit: 1000}`
     -DEXECUTORCH_BUILD_CUSTOM=ON \
     -DEXECUTORCH_BUILD_OPTIMIZED=ON \
     -DEXECUTORCH_BUILD_XNNPACK=ON \
+    -DEXECUTORCH_BUILD_QUANTIZED=ON \
     -Bcmake-out/examples/models/llama2 \
     examples/models/llama2
 
 cmake --build cmake-out/examples/models/llama2 -j16 --config Release
 ```
 
+For Llama3, add `-DEXECUTORCH_USE_TIKTOKEN=ON` option when building the llama runner.
+
 3. Run model. Run options available [here](https://github.com/pytorch/executorch/blob/main/examples/models/llama2/main.cpp#L18-L40).
 ```
 cmake-out/examples/models/llama2/llama_main --model_path=<model pte file> --tokenizer_path=<tokenizer.bin> --prompt=<prompt>
 ```
 
+For Llama3, you can pass the original `tokenizer.model` (without converting to `.bin` file).
+
 ## Step 5: Run benchmark on Android phone
 
 **1. Build llama runner binary for Android**
@@ -280,7 +297,7 @@ This example tries to reuse the Python code, with minimal modifications to make
 ```
 git clean -xfd
 pip uninstall executorch
-./install_requirements.sh <options>
+./install_requirements.sh --pybind xnnpack
 
 rm -rf cmake-out
 ```

examples/models/llama2/eval_llama_lib.py

Lines changed: 3 additions & 4 deletions
@@ -42,12 +42,11 @@ def __init__(
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
     ):
-        super().__init__()
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        super().__init__(device=device)
         self._model = model
         self._tokenizer = tokenizer
-        self._device = (
-            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-        )
+        self._device = torch.device(device)
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
 
     @property

examples/models/llama2/export_llama_lib.py

Lines changed: 146 additions & 11 deletions
@@ -9,6 +9,7 @@
 import argparse
 import copy
 import logging
+import math
 import os
 import shlex

@@ -143,6 +144,80 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
     return module
 
 
+class SDPASimple(torch.nn.Module):
+
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        dim: int,
+        head_dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.dim = dim
+        self.head_dim = head_dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+        mask,
+    ):
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        attn_mask = mask[None, None, input_pos]
+
+        k = k.repeat_interleave(self.n_rep, dim=1)
+        v = v.repeat_interleave(self.n_rep, dim=1)
+        scale_factor = 1 / math.sqrt(q.size(-1))
+        attn_weight = q @ k.transpose(-2, -1) * scale_factor
+        attn_weight += attn_mask
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        y = attn_weight @ v
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
+def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep),
+            )
+        else:
+            replace_sdpa_with_simple_sdpa(child)
+    return module
+
+
+def replace_causal_mask(module: torch.nn.Module):
+    for buffer_fqn_name, buffer in module.named_buffers():
+        buffer_name = buffer_fqn_name.split(".")[-1]
+        if buffer_name == "mask":
+            max_seq_len = buffer.shape[-1]
+            mask = torch.full(
+                (max_seq_len, max_seq_len),
+                float("-inf"),
+                device="cpu",
+            )
+
+            mask = torch.triu(mask, diagonal=1)
+            module.register_buffer(buffer_name, mask)
+    for _, child in module.named_children():
+        replace_causal_mask(child)
+    return module
+
+
 def quantize(
     model: torch.nn.Module,
     qmode: str,
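
The decomposed attention in `SDPASimple` is the standard softmax(QK^T/sqrt(d) + mask)V computation, just written with ops that the QNN backend can lower. Below is a self-contained parity check against PyTorch's fused op, assuming plain tensors with no KV cache and `n_rep == 1` (no grouped-query repeat); it is a sketch, not the repository's test_simple_sdpa test.

```python
# Editorial sketch: verify that the decomposed attention math used by SDPASimple
# matches torch's fused scaled_dot_product_attention on plain tensors.
import math
import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim = 1, 4, 8, 16
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, seqlen, head_dim)
v = torch.randn(bsz, n_heads, seqlen, head_dim)

# Causal mask built the same way as replace_causal_mask: -inf above the diagonal.
mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)

# Decomposed attention, as in SDPASimple.forward.
scale = 1 / math.sqrt(head_dim)
attn_weight = torch.softmax(q @ k.transpose(-2, -1) * scale + mask, dim=-1)
y_simple = attn_weight @ v

# Fused reference.
y_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

assert torch.allclose(y_simple, y_fused, atol=1e-5)
```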
@@ -280,6 +355,13 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--pt2e_quantize",
         default=None,
+        choices=[
+            "xnnpack_dynamic",
+            "xnnpack_dynamic_qc4",
+            "qnn_8a8w",
+            "qnn_16a16w",
+            "qnn_16a4w",
+        ],
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )
     parser.add_argument(
@@ -539,7 +621,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
         bitwidth = int(bitwidth)
         transforms.append(
             lambda model: EmbeddingQuantHandler(
-                model, bitwidth=bitwidth, group_size=group_size
+                model,
+                bitwidth=bitwidth,
+                group_size=group_size,
+                packed=(bitwidth == 4),
             ).quantized_model()
         )
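
The new `packed=(bitwidth == 4)` argument requests 4-bit packing of the quantized embedding table, so two int4 values share one byte instead of each occupying an int8. A rough sketch of what that packing means (the bit layout here is illustrative, not the actual kernel format):

```python
# Editorial sketch of packed int4 storage: two 4-bit values per byte, halving
# storage versus one int8 per value. Packing scheme is illustrative only.
import torch

vals = torch.tensor([-3, 7, 0, -8], dtype=torch.int8)  # int4 range [-8, 7]
u = (vals + 8).to(torch.uint8)                          # shift to [0, 15]
packed = (u[0::2] << 4) | u[1::2]                       # two nibbles per byte

hi = (packed >> 4).to(torch.int8) - 8                   # unpack and shift back
lo = (packed & 0xF).to(torch.int8) - 8
assert torch.equal(torch.stack([hi, lo], dim=1).flatten(), vals)
```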

@@ -549,6 +634,9 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     if args.use_sdpa_with_kv_cache:
         transforms.append(replace_sdpa_with_custom_op)
 
+    if args.qnn and args.use_kv_cache:
+        transforms.append(replace_sdpa_with_simple_sdpa)
+        transforms.append(replace_causal_mask)
     return (
         load_llama_model(
             modelname=modelname,
@@ -572,13 +660,16 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
     # export_to_edge
     pt2e_quant_params = _get_pt2e_quantization_params(args)
     quantizers = get_pt2e_quantizers(pt2e_quant_params, args)
-    if args.qnn:
-        assert (
-            args.quantization_mode is None
-        ), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
+    quant_dtype = None
+    if args.qnn and args.pt2e_quantize:
         try:
             # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer`
-            from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
+            from executorch.backends.qualcomm.quantizer.quantizer import (
+                get_16a4w_qnn_ptq_config,
+                get_default_16bit_qnn_ptq_config,
+                QnnQuantizer,
+                QuantDtype,
+            )
 
             # reset quantizers and pt2e_quant_params from xnnpack backend
             pt2e_quant_params = None
@@ -588,10 +679,41 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
                 "Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html"
             )
 
+        backend, quant_config = args.pt2e_quantize.split("_")
+        assert (
+            backend == "qnn"
+        ), f"The quantization config is for backend {backend} instead of qnn."
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
         qnn_quantizer = QnnQuantizer()
         # more custom quantization are supported including 16a4w etc. default to 8bit quantized
         custom_annotations = ()
+        if quant_config == "8a8w":
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            quant_dtype = QuantDtype.use_8a8w
+            pass
+        elif quant_config == "16a16w":
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            quant_dtype = QuantDtype.use_16a16w
+            qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            qnn_quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config())
+        elif quant_config == "16a4w":
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            quant_dtype = QuantDtype.use_16a4w
+            qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            qnn_quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config())
+            qnn_quantizer.set_per_channel_weight_dtype(
+                weight_dtype_for_16bit_act="int4"
+            )
+        else:
+            raise AssertionError(
+                f"No support for quant type {quant_config}. Support 8a8w, 16a16w and 16a4w."
+            )
+
+        assert (
+            args.quantization_mode is None
+        ), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
         qnn_quantizer.add_custom_quant_annotations(custom_annotations)
         quantizers.append(qnn_quantizer)
@@ -708,25 +830,38 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
                 "Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html"
             )
 
-        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
-        backend_options = generate_htp_compiler_spec(use_fp16=False)
+        use_fp16 = True
+        skip_node_op_set = {}
+        if args.pt2e_quantize:
+            use_fp16 = False
+            # TODO: fix the lowering error without skipping nodes
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            if quant_dtype == QuantDtype.use_8a8w:
+                raise NotImplementedError("8a8w for llama is still under development")
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            elif quant_dtype == QuantDtype.use_16a16w:
+                raise NotImplementedError("16a16w for llama is still under development")
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            elif quant_dtype == QuantDtype.use_16a4w:
+                raise NotImplementedError("16a4w for llama is still under development")
         partitioners.append(
             # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
             QnnPartitioner(
                 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
                 generate_qnn_executorch_compiler_spec(
                     # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
                     soc_model=QcomChipset.SM8650,  # default to SM8650
-                    backend_options=backend_options,
+                    # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+                    backend_options=generate_htp_compiler_spec(use_fp16=use_fp16),
                     debug=False,
                     saver=False,
                 ),
                 skip_node_id_set={},
-                skip_node_op_set={},
+                skip_node_op_set=skip_node_op_set,
            )
         )
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
-        _transform(builder_exported_to_edge.export_program())
+        _transform(builder_exported_to_edge.edge_manager.exported_program())
 
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:

examples/models/llama2/tests/TARGETS

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
+
+oncall("executorch")
+
+python_unittest(
+    name = "test_simple_sdpa",
+    srcs = [
+        "test_simple_sdpa.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:export_library",
+        "//executorch/examples/models/llama2:llama_transformer",
+    ],
+)
