Skip to content

Commit 37b1ecc

Browse files
committed
Clean up a few more broken links and sections in new doc flow
1 parent 7a1e3b1 commit 37b1ecc

File tree

6 files changed

+228
-19
lines changed

6 files changed

+228
-19
lines changed
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# XNNPACK Delegate Internals
2+
3+
This is a high-level overview of the ExecuTorch XNNPACK backend delegate. This high performance delegate is aimed to reduce CPU inference latency for ExecuTorch models. We will provide a brief introduction to the XNNPACK library and explore the delegate’s overall architecture and intended use cases.
4+
5+
## What is XNNPACK?
6+
XNNPACK is a library of highly-optimized neural network operators for ARM, x86, and WebAssembly architectures in Android, iOS, Windows, Linux, and macOS environments. It is an open source project, you can find more information about it on [github](https://github.com/google/XNNPACK).
7+
8+
## What are ExecuTorch delegates?
9+
A delegate is an entry point for backends to process and execute parts of the ExecuTorch program. Delegated portions of ExecuTorch models hand off execution to backends. The XNNPACK backend delegate is one of many available in ExecuTorch. It leverages the XNNPACK third-party library to accelerate ExecuTorch programs efficiently across a variety of CPUs. More detailed information on the delegates and developing your own delegates is available [here](compiler-delegate-and-partitioner.md). It is recommended that you get familiar with that content before continuing on to the Architecture section.
10+
11+
## Architecture
12+
![High Level XNNPACK delegate Architecture](./xnnpack-delegate-architecture.png)
13+
14+
### Ahead-of-time
15+
In the ExecuTorch export flow, lowering to the XNNPACK delegate happens at the `to_backend()` stage. In this stage, the model is partitioned by the `XnnpackPartitioner`. Partitioned sections of the graph are converted to a XNNPACK specific graph represenationed and then serialized via flatbuffer. The serialized flatbuffer is then ready to be deserialized and executed by the XNNPACK backend at runtime.
16+
17+
![ExecuTorch XNNPACK delegate Export Flow](./xnnpack-et-flow-diagram.png)
18+
19+
#### Partitioner
20+
The partitioner is implemented by backend delegates to mark nodes suitable for lowering. The `XnnpackPartitioner` lowers using node targets and module metadata. Some more references for partitioners can be found [here](compiler-delegate-and-partitioner.md)
21+
22+
##### Module-based partitioning
23+
24+
`source_fn_stack` is embedded in the node’s metadata and gives information on where these nodes come from. For example, modules like `torch.nn.Linear` when captured and exported `to_edge` generate groups of nodes for their computation. The group of nodes associated with computing the linear module then has a `source_fn_stack` of `torch.nn.Linear. Partitioning based on `source_fn_stack` allows us to identify groups of nodes which are lowerable via XNNPACK.
25+
26+
For example after capturing `torch.nn.Linear` you would find the following key in the metadata for the addmm node associated with linear:
27+
```python
28+
>>> print(linear_node.meta["source_fn_stack"])
29+
'source_fn_stack': ('fn', <class 'torch.nn.modules.linear.Linear'>)
30+
```
31+
32+
33+
##### Op-based partitioning
34+
35+
The `XnnpackPartitioner` also partitions using op targets. It traverses the graph and identifies individual nodes which are lowerable to XNNPACK. A drawback to module-based partitioning is that operators which come from [decompositions](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py) may be skipped. For example, an operator like `torch.nn.Hardsigmoid` is decomposed into add, muls, divs, and clamps. While hardsigmoid is not lowerable, we can lower the decomposed ops. Relying on `source_fn_stack` metadata would skip these lowerables because they belong to a non-lowerable module, so in order to improve model performance, we greedily lower operators based on the op targets as well as the `source_fn_stack`.
36+
37+
##### Passes
38+
39+
Before any serialization, we apply passes on the subgraphs to prepare the graph. These passes are essentially graph transformations that help improve the performance of the delegate. We give an overview of the most significant passes and their function below. For a description of all passes see [here](https://github.com/pytorch/executorch/tree/main/backends/xnnpack/_passes):
40+
41+
* Channels Last Reshape
42+
* ExecuTorch tensors tend to be contiguous before passing them into delegates, while XNNPACK only accepts channels-last memory layout. This pass minimizes the number of permutation operators inserted to pass in channels-last memory format.
43+
* Conv1d to Conv2d
44+
* Allows us to delegate Conv1d nodes by transforming them to Conv2d
45+
* Conv and BN Fusion
46+
* Fuses batch norm operations with the previous convolution node
47+
48+
#### Serialiazation
49+
After partitioning the lowerable subgraphs from the model, The XNNPACK delegate pre-processes these subgraphs and serializes them via flatbuffer for the XNNPACK backend.
50+
51+
52+
##### Serialization Schema
53+
54+
The XNNPACK delegate uses flatbuffer for serialization. In order to improve runtime performance, the XNNPACK delegate’s flatbuffer [schema](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/serialization/schema.fbs) mirrors the XNNPACK Library’s graph level API calls. The serialized data are arguments to XNNPACK’s APIs, so that at runtime, the XNNPACK execution graph can efficiently be created with successive calls to XNNPACK’s APIs.
55+
56+
### Runtime
57+
The XNNPACK backend’s runtime interfaces with the ExecuTorch runtime through the custom `init` and `execute` function. Each delegated subgraph is contained in an individually serialized XNNPACK blob. When the model is initialized, ExecuTorch calls `init` on all XNNPACK Blobs to load the subgraph from serialized flatbuffer. After, when the model is executed, each subgraph is executed via the backend through the custom `execute` function. To read more about how delegate runtimes interface with ExecuTorch, refer to this [resource](compiler-delegate-and-partitioner.md).
58+
59+
60+
#### **XNNPACK Library**
61+
XNNPACK delegate supports CPU's on multiple platforms; more information on the supported hardware architectures can be found on the XNNPACK Library’s [README](https://github.com/google/XNNPACK).
62+
63+
#### **Init**
64+
When calling XNNPACK delegate’s `init`, we deserialize the preprocessed blobs via flatbuffer. We define the nodes (operators) and edges (intermediate tensors) to build the XNNPACK execution graph using the information we serialized ahead-of-time. As we mentioned earlier, the majority of processing has been done ahead-of-time, so that at runtime we can just call the XNNPACK APIs with the serialized arguments in succession. As we define static data into the execution graph, XNNPACK performs weight packing at runtime to prepare static data like weights and biases for efficient execution. After creating the execution graph, we create the runtime object and pass it on to `execute`.
65+
66+
Since weight packing creates an extra copy of the weights inside XNNPACK, We free the original copy of the weights inside the preprocessed XNNPACK Blob, this allows us to remove some of the memory overhead.
67+
68+
69+
#### **Execute**
70+
When executing the XNNPACK subgraphs, we prepare the tensor inputs and outputs and feed them to the XNNPACK runtime graph. After executing the runtime graph, the output pointers are filled with the computed tensors.
71+
72+
#### **Profiling**
73+
We have enabled basic profiling for the XNNPACK delegate that can be enabled with the compiler flag `-DEXECUTORCH_ENABLE_EVENT_TRACER` (add `-DENABLE_XNNPACK_PROFILING` for additional details). With ExecuTorch's Developer Tools integration, you can also now use the Developer Tools to profile the model. You can follow the steps in [Using the ExecuTorch Developer Tools to Profile a Model](./tutorials/devtools-integration-tutorial) on how to profile ExecuTorch models and use Developer Tools' Inspector API to view XNNPACK's internal profiling information. An example implementation is available in the `xnn_executor_runner` (see [tutorial here](tutorial-xnnpack-delegate-lowering.md#profiling)).
74+
75+
76+
[comment]: <> (TODO: Refactor quantizer to a more official quantization doc)
77+
## Quantization
78+
The XNNPACK delegate can also be used as a backend to execute symmetrically quantized models. For quantized model delegation, we quantize models using the `XNNPACKQuantizer`. `Quantizers` are backend specific, which means the `XNNPACKQuantizer` is configured to quantize models to leverage the quantized operators offered by the XNNPACK Library. We will not go over the details of how to implement your custom quantizer, you can follow the docs [here](https://pytorch.org/tutorials/prototype/pt2e_quantizer.html) to do so. However, we will provide a brief overview of how to quantize the model to leverage quantized execution of the XNNPACK delegate.
79+
80+
### Configuring the XNNPACKQuantizer
81+
82+
```python
83+
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
84+
XNNPACKQuantizer,
85+
get_symmetric_quantization_config,
86+
)
87+
quantizer = XNNPACKQuantizer()
88+
quantizer.set_global(get_symmetric_quantization_config())
89+
```
90+
Here we initialize the `XNNPACKQuantizer` and set the quantization config to be symmetrically quantized. Symmetric quantization is when weights are symmetrically quantized with `qmin = -127` and `qmax = 127`, which forces the quantization zeropoints to be zero. `get_symmetric_quantization_config()` can be configured with the following arguments:
91+
* `is_per_channel`
92+
* Weights are quantized across channels
93+
* `is_qat`
94+
* Quantize aware training
95+
* `is_dynamic`
96+
* Dynamic quantization
97+
98+
We can then configure the `XNNPACKQuantizer` as we wish. We set the following configs below as an example:
99+
```python
100+
quantizer.set_global(quantization_config)
101+
.set_object_type(torch.nn.Conv2d, quantization_config) # can configure by module type
102+
.set_object_type(torch.nn.functional.linear, quantization_config) # or torch functional op typea
103+
.set_module_name("foo.bar", quantization_config) # or by module fully qualified name
104+
```
105+
106+
### Quantizing your model with the XNNPACKQuantizer
107+
After configuring our quantizer, we are now ready to quantize our model
108+
```python
109+
from torch.export import export_for_training
110+
111+
exported_model = export_for_training(model_to_quantize, example_inputs).module()
112+
prepared_model = prepare_pt2e(exported_model, quantizer)
113+
print(prepared_model.graph)
114+
```
115+
Prepare performs some Conv2d-BN fusion, and inserts quantization observers in the appropriate places. For Post-Training Quantization, we generally calibrate our model after this step. We run sample examples through the `prepared_model` to observe the statistics of the Tensors to calculate the quantization parameters.
116+
117+
Finally, we convert our model here:
118+
```python
119+
quantized_model = convert_pt2e(prepared_model)
120+
print(quantized_model)
121+
```
122+
You will now see the Q/DQ representation of the model, which means `torch.ops.quantized_decomposed.dequantize_per_tensor` are inserted at quantized operator inputs and `torch.ops.quantized_decomposed.quantize_per_tensor` are inserted at operator outputs. Example:
123+
124+
```python
125+
def _qdq_quantized_linear(
126+
x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
127+
weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
128+
bias_fp32,
129+
out_scale, out_zero_point, out_quant_min, out_quant_max
130+
):
131+
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
132+
x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
133+
weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
134+
weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
135+
out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
136+
out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
137+
out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
138+
return out_i8
139+
```
140+
141+
142+
You can read more indepth explanations on PyTorch 2 quantization [here](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html).
143+
144+
## See Also
145+
- [Integrating XNNPACK Delegate Android App](demo-apps-android.md)
146+
- [Complete the Lowering to XNNPACK Tutorial](tutorial-xnnpack-delegate-lowering.md)

docs/source/backends-xnnpack.md

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,61 @@ The XNNPACK partitioner API allows for configuration of the model delegation to
5151

5252
### Quantization
5353

54-
Placeholder - document available quantization flows (PT2E + ao), schemes, and operators.
54+
The XNNPACK delegate can also be used as a backend to execute symmetrically quantized models. To quantize a PyTorch model for the XNNPACK backend, use the `XNNPACKQuantizer`. `Quantizers` are backend specific, which means the `XNNPACKQuantizer` is configured to quantize models to leverage the quantized operators offered by the XNNPACK Library.
55+
56+
### Configuring the XNNPACKQuantizer
57+
58+
```python
59+
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
60+
XNNPACKQuantizer,
61+
get_symmetric_quantization_config,
62+
)
63+
quantizer = XNNPACKQuantizer()
64+
quantizer.set_global(get_symmetric_quantization_config())
65+
```
66+
Here, the `XNNPACKQuantizer` is configured for symmetric quantization, indicating that the quantized zero point is set to zero with `qmin = -127` and `qmax = 127`. `get_symmetric_quantization_config()` can be configured with the following arguments:
67+
* `is_per_channel`
68+
* Weights are quantized across channels
69+
* `is_qat`
70+
* Quantize aware training
71+
* `is_dynamic`
72+
* Dynamic quantization
73+
74+
```python
75+
quantizer.set_global(quantization_config)
76+
.set_object_type(torch.nn.Conv2d, quantization_config) # can configure by module type
77+
.set_object_type(torch.nn.functional.linear, quantization_config) # or torch functional op typea
78+
.set_module_name("foo.bar", quantization_config) # or by module fully qualified name
79+
```
80+
81+
#### Quantizing a model with the XNNPACKQuantizer
82+
After configuring the quantizer, the model can be quantized by via the `prepare_pt2e` and `convert_pt2e` APIs.
83+
```python
84+
from torch.export import export_for_training
85+
86+
exported_model = export_for_training(model_to_quantize, example_inputs).module()
87+
prepared_model = prepare_pt2e(exported_model, quantizer)
88+
89+
for cal_sample in cal_samples: # Replace with representative model inputs
90+
prepared_model(cal_sample) # Calibrate
91+
92+
quantized_model = convert_pt2e(prepared_model)
93+
```
94+
For static, post-training quantization (PTQ), the post-prepare\_pt2e model should beS run with a representative set of samples, which are used to determine the quantization parameters.
95+
96+
After `convert_pt2e`, the model can be exported and lowered using the normal ExecuTorch XNNPACK flow. For more information on PyTorch 2 quantization [here](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html).
97+
98+
### Testing the Model
99+
100+
After generating the XNNPACK-delegated .pte, the model can be tested from Python using the ExecuTorch runtime python bindings. This can be used to sanity check the model and evaluate numerical accuracy. See [Testing the Model](using-executorch-export.md#testing-the-model) for more information.
55101

56102
## Runtime Integration
57103

104+
To run the model on-device, use the standard ExecuTorch runtime APIs. See [Running on Device](getting-started.md#running-on-device) for more information.
105+
58106
The XNNPACK delegate is included by default in the published Android, iOS, and pip packages. When building from source, pass `-DEXECUTORCH_BUILD_XNNPACK=ON` when configuring the CMake build to compile the XNNPACK backend.
59107

60108
To link against the backend, add the `xnnpack_backend` CMake target as a build dependency, or link directly against `libxnnpack_backend`. Due to the use of static registration, it may be necessary to link with whole-archive. This can typically be done by passing the following flags: `-Wl,--whole-archive libxnnpack_backend.a -Wl,--no-whole-archive`.
61109

62110
No additional steps are necessary to use the backend beyond linking the target. Any XNNPACK-delegated .pte file will automatically run on the registered backend.
63111

64-
### Runner
65-
66-
To test XNNPACK models on a development machine, the repository includes a runner binary, which can run XNNPACK delegated models. It is built by default when building the XNNPACK backend. The runner can be invoked with the following command, assuming that the CMake build directory is named cmake-out. Note that the XNNPACK delegate is also available by default from the Python runtime bindings (see [Testing the Model](using-executorch-export.md#testing-the-model) for more information).
67-
```
68-
./cmake-out/backends/xnnpack/xnn_executor_runner --model_path=./mv2_xnnpack.pte
69-
```

docs/source/index.rst

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,15 @@ Topics in this section will help you get started with ExecuTorch.
9797
using-executorch-building-from-source
9898
using-executorch-faqs
9999

100+
.. toctree::
101+
:glob:
102+
:maxdepth: 1
103+
:caption: Examples
104+
:hidden:
105+
106+
demo-apps-android.md
107+
demo-apps-ios.md
108+
100109
.. toctree::
101110
:glob:
102111
:maxdepth: 1
@@ -199,6 +208,7 @@ Topics in this section will help you get started with ExecuTorch.
199208
:hidden:
200209

201210
backend-delegates-integration
211+
backend-delegates-xnnpack-reference
202212
backend-delegates-dependencies
203213
compiler-delegate-and-partitioner
204214
debug-backend-delegate
@@ -315,7 +325,7 @@ ExecuTorch tutorials.
315325
:header: Building and Running ExecuTorch with Vulkan Backend
316326
:card_description: A tutorial that walks you through the process of building ExecuTorch with Vulkan Backend
317327
:image: _static/img/generic-pytorch-logo.png
318-
:link: build-run-vulkan.html
328+
:link: backends-vulkan.html
319329
:tags: Export,Backend,Delegation,Vulkan
320330

321331
..
@@ -333,35 +343,35 @@ ExecuTorch tutorials.
333343
:header: Building and Running ExecuTorch with CoreML Backend
334344
:card_description: A tutorial that walks you through the process of building ExecuTorch with CoreML Backend
335345
:image: _static/img/generic-pytorch-logo.png
336-
:link: build-run-coreml.html
346+
:link: backends-coreml.html
337347
:tags: Export,Backend,Delegation,CoreML
338348

339349
.. customcarditem::
340350
:header: Building and Running ExecuTorch with MediaTek Backend
341351
:card_description: A tutorial that walks you through the process of building ExecuTorch with MediaTek Backend
342352
:image: _static/img/generic-pytorch-logo.png
343-
:link: build-run-mediatek-backend.html
353+
:link: backends-mediatek-backend.html
344354
:tags: Export,Backend,Delegation,MediaTek
345355

346356
.. customcarditem::
347357
:header: Building and Running ExecuTorch with MPS Backend
348358
:card_description: A tutorial that walks you through the process of building ExecuTorch with MPSGraph Backend
349359
:image: _static/img/generic-pytorch-logo.png
350-
:link: build-run-mps.html
360+
:link: backends-mps.html
351361
:tags: Export,Backend,Delegation,MPS,MPSGraph
352362

353363
.. customcarditem::
354364
:header: Building and Running ExecuTorch with Qualcomm AI Engine Direct Backend
355365
:card_description: A tutorial that walks you through the process of building ExecuTorch with Qualcomm AI Engine Direct Backend
356366
:image: _static/img/generic-pytorch-logo.png
357-
:link: build-run-qualcomm-ai-engine-direct-backend.html
367+
:link: backends-qualcomm-ai-engine-direct-backend.html
358368
:tags: Export,Backend,Delegation,QNN
359369

360370
.. customcarditem::
361371
:header: Building and Running ExecuTorch on Xtensa HiFi4 DSP
362372
:card_description: A tutorial that walks you through the process of building ExecuTorch for an Xtensa Hifi4 DSP using custom operators
363373
:image: _static/img/generic-pytorch-logo.png
364-
:link: build-run-xtensa.html
374+
:link: backends-cadence.html
365375
:tags: Export,Custom-Operators,DSP,Xtensa
366376

367377
.. customcardend::

docs/source/using-executorch-building-from-source.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,12 @@ cmake --build cmake-out -j9
179179

180180
First, generate an `add.pte` or other ExecuTorch program file using the
181181
instructions as described in
182-
[Setting up ExecuTorch](getting-started-setup.md#building-a-runtime).
182+
[Preparing a Model](getting-started.md#preparing-the-model).
183183

184184
Then, pass it to the command line tool:
185185

186186
```bash
187-
./cmake-out/executor_runner --model_path path/to/add.pte
187+
./cmake-out/executor_runner --model_path path/to/model.pte
188188
```
189189

190190
If it worked, you should see the message "Model executed successfully" followed

0 commit comments

Comments
 (0)