diff --git a/backends/xnnpack/README.md b/backends/xnnpack/README.md index fe359b7adc9..6e1731799d3 100644 --- a/backends/xnnpack/README.md +++ b/backends/xnnpack/README.md @@ -7,8 +7,11 @@ mechanism for leveraging the XNNPACK library to accelerate operators running on CPU. ## Layout -- `runtime/` : Runtime logic used at inference. This contains all the cpp files - used to build the runtime graph and execute the XNNPACK model +- `cmake/` : CMake related files +- `operators`: the directory to store all of op visitors + - `node_visitor.py`: Implementation of serializing each lowerable operator + node + - ... - `partition/`: Partitioner is used to identify operators in model's graph that are suitable for lowering to XNNPACK delegate - `xnnpack_partitioner.py`: Contains partitioner that tags graph patterns @@ -16,10 +19,8 @@ CPU. - `configs.py`: Contains lists of op/modules for XNNPACK lowering - `passes/`: Contains passes which are used before preprocessing to prepare the graph for XNNPACK lowering -- `operators`: the directory to store all of op visitors - - `node_visitor.py`: Implementation of serializing each lowerable operator - node - - ... +- `runtime/` : Runtime logic used at inference. This contains all the cpp files + used to build the runtime graph and execute the XNNPACK model - `serialization/`: Contains files related to serializing the XNNPACK graph representation of the PyTorch model - `schema.fbs`: Flatbuffer schema of serialization format @@ -28,64 +29,107 @@ CPU. - `xnnpack_graph_serialize`: Implementation for serializing dataclasses from graph schema to flatbuffer - `test/`: Tests for XNNPACK Delegate +- `third-party/`: third-party libraries used by XNNPACK Delegate - `xnnpack_preprocess.py`: Contains preprocess implementation which is called by `to_backend` on the graph or subgraph of a model returning a preprocessed blob responsible for executing the graph or subgraph at runtime +## End to End Example + +To further understand the features of the XNNPACK Delegate and how to use it, consider the following end to end example with MobilenetV2. + +### Lowering a model to XNNPACK +```python +import torch +import torchvision.models as models + +from torch.export import export, ExportedProgram +from torchvision.models.mobilenetv2 import MobileNet_V2_Weights +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge +from executorch.exir.backend.backend_api import to_backend + + +mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval() +sample_inputs = (torch.randn(1, 3, 224, 224), ) + +exported_program: ExportedProgram = export(mobilenet_v2, sample_inputs) +edge: EdgeProgramManager = to_edge(exported_program) + +edge = edge.to_backend(XnnpackPartitioner()) +``` + +We will go through this example with the [MobileNetV2](https://pytorch.org/hub/pytorch_vision_mobilenet_v2/) pretrained model downloaded from the TorchVision library. The flow of lowering a model starts after exporting the model `to_edge`. We call the `to_backend` api with the `XnnpackPartitioner`. The partitioner identifies the subgraphs suitable for XNNPACK backend delegate to consume. Afterwards, the identified subgraphs will be serialized with the XNNPACK Delegate flatbuffer schema and each subgraph will be replaced with a call to the XNNPACK Delegate. + +```python +>>> print(edge.exported_program().graph_module) +GraphModule( + (lowered_module_0): LoweredBackendModule() + (lowered_module_1): LoweredBackendModule() +) + +def forward(self, arg314_1): + lowered_module_0 = self.lowered_module_0 + executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, arg314_1); lowered_module_0 = arg314_1 = None + getitem = executorch_call_delegate[0]; executorch_call_delegate = None + aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem, [1, 1280]); getitem = None + aten_clone_default = executorch_exir_dialects_edge__ops_aten_clone_default(aten_view_copy_default); aten_view_copy_default = None + lowered_module_1 = self.lowered_module_1 + executorch_call_delegate_1 = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, aten_clone_default); lowered_module_1 = aten_clone_default = None + getitem_1 = executorch_call_delegate_1[0]; executorch_call_delegate_1 = None + return (getitem_1,) +``` + +We print the graph after lowering above to show the new nodes that were inserted to call the XNNPACK Delegate. The subgraphs which are being delegated to XNNPACK are the first argument at each call site. It can be observed that the majority of `convolution-relu-add` blocks and `linear` blocks were able to be delegated to XNNPACK. We can also see the operators which were not able to be lowered to the XNNPACK delegate, such as `clone` and `view_copy`. + +```python +exec_prog = edge.to_executorch() + +with open("xnnpack_mobilenetv2.pte", "wb") as file: + exec_prog.write_to_file(file) +``` +After lowering to the XNNPACK Program, we can then prepare it for executorch and save the model as a `.pte` file. `.pte` is a binary format that stores the serialized ExecuTorch graph. + + +### Running the XNNPACK Model with CMake +After exporting the XNNPACK Delegated model, we can now try running it with example inputs using CMake. We can build and use the xnn_executor_runner, which is a sample wrapper for the ExecuTorch Runtime and XNNPACK Backend. We first begin by configuring the CMake build like such: +```bash +# cd to the root of executorch repo +cd executorch + +# Get a clean cmake-out directory +rm- -rf cmake-out +mkdir cmake-out + +# Configure cmake +cmake \ + -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_XNNPACK=ON \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + -DEXECUTORCH_ENABLE_LOGGING=1 \ + -DPYTHON_EXECUTABLE=python \ + -Bcmake-out . +``` +Then you can build the runtime componenets with + +```bash +cmake --build cmake-out -j9 --target install --config Release +``` + +Now you should be able to find the executable built at `./cmake-out/backends/xnnpack/xnn_executor_runner` you can run the executable with the model you generated as such +```bash +./cmake-out/backends/xnnpack/xnn_executor_runner --model_path=./mv2_xnnpack_fp32.pte +``` + ## Help & Improvements If you have problems or questions, or have suggestions for ways to make implementation and testing better, please reach out to the PyTorch Edge team or create an issue on [github](https://www.github.com/pytorch/executorch/issues). -## Contributing - -Please follow the following steps and guidelines when adding a new operator -implementation to this library. The goals of these guidelines are to -- Make it straightforward to add new XNNPACK operators. -- Ensure that the newly added operators are of high quality, and are easy to - maintain -- Make it easy for users to find available operator implementations, and to - trust in their quality and behavioral stability. - -### AoT and Serialization Overview -#### Serialization: -XNNPACK delegate uses flatbuffer to serialize its nodes and values. In order to -add -[preprocessing](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/xnnpack_preprocess.py) -support for a new operator, we must add the operator in both the flatbuffer -[schema](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/serialization/schema.fbs), -as well as the mirrored python [data -class](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/serialization/xnnpack_graph_schema.py). -These tables are based on the arguments to the XNNPACK Subgraph APIs. These -APIs can be found -[here](https://github.com/google/xnnpack/blob/master/include/xnnpack.h). We -essentially serialize all the static arguments we need to call `define_{new -operator}()`. - -#### AoT Preprocess: -To add logic to preprocess new operators for the XNNPACK Delegate, we can -create new node_visitors that perform the serialization of the new operator. An -example can be found [here](). The function of these node_visitors is to -serialize all the data we define to need in the schema above. - -#### AoT Partitioner: -XnnpackPartitioner is used to select the pattern (like the linear module -graph) in a big graph such that the selected nodes will be delegated to -XNNPACK. To support a new op (for example, sigmoid), add the corresponding op -or module to the -[config.py](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/partition/configs.py), -which captures the sigmoid op. - -#### How does it work? -- Tag the nodes: in the XNNPACK partitioner's config, which lists all ops that - are supported by the current XNNPACK backend in executorch. When call - `XnnpackPartitioner.partition()`, it will tag all the nodes that matches the - patterns listed in self.pattern -- Lower the nodes; when we call `to_backend(graph_module, XnnpackPartitioner)`, - it will loop through all the tagged nodes, and lower the group with the same - tag. - - -#### Adding Tests for newly minted operators -To test newly added operators, we can add unit tests in: -[tests](https://github.com/pytorch/executorch/tree/main/backends/xnnpack/test) + +## See Also +For more information about the XNNPACK Delegate, please check out the following resources: +- [ExecuTorch XNNPACK Delegate](https://pytorch.org/executorch/0.2/native-delegates-executorch-xnnpack-delegate.html) +- [Building and Running ExecuTorch with XNNPACK Backend](https://pytorch.org/executorch/0.2/native-delegates-executorch-xnnpack-delegate.html) diff --git a/docs/source/native-delegates-executorch-xnnpack-delegate.md b/docs/source/native-delegates-executorch-xnnpack-delegate.md index 12b2e9c2ba7..1d12daef9d8 100644 --- a/docs/source/native-delegates-executorch-xnnpack-delegate.md +++ b/docs/source/native-delegates-executorch-xnnpack-delegate.md @@ -74,16 +74,8 @@ Since weight packing creates an extra copy of the weights inside XNNPACK, We fre When executing the XNNPACK subgraphs, we prepare the tensor inputs and outputs and feed them to the XNNPACK runtime graph. After executing the runtime graph, the output pointers are filled with the computed tensors. #### **Profiling** -We have enabled basic profiling for XNNPACK delegate that can be enabled with the following compiler flag `-DENABLE_XNNPACK_PROFILING`. After running the model it will produce basic per-op and total timings. We provide an example of the profiling below. The timings listed are the average across runs, and the units are in microseconds. +We have enabled basic profiling for XNNPACK delegate that can be enabled with the following compiler flag `-DENABLE_XNNPACK_PROFILING`. With ExecuTorch's SDK integration, you can also now use the SDK tools to profile the model. You can follow the steps in [Using the ExecuTorch SDK to Profile a Model](./tutorials/sdk-integration-tutorial) on how to profile ExecuTorch models and use SDK's Inspector API to view XNNPACK's internal profiling information. -``` -Fully Connected (NC, F32) GEMM: 109.510002 -Total Time: 109.510002 -``` - -::::{note} -Profiling is a work in progress, and is planned to be integrated with [SDK Tools](sdk-delegate-integration.md) and Tensorboard. -:::: [comment]: <> (TODO: Refactor quantizer to a more official quantization doc) ## Quantization