Skip to content

Commit 6f7f55f

Browse files
committed
[ET-VK] Documentation for Vulkan Delegate
ghstack-source-id: b6e86fb Pull Request resolved: #3113
1 parent 20bf0db commit 6f7f55f

File tree

5 files changed

+352
-0
lines changed

5 files changed

+352
-0
lines changed

backends/vulkan/README.md

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# ExecuTorch Vulkan Delegate
2+
3+
The ExecuTorch Vulkan delegate is a native GPU delegate for ExecuTorch that is
4+
built on top of the cross-platform Vulkan GPU API standard. It is primarily
5+
designed to leverage the GPU to accelerate model inference on Android devices,
6+
but can be used on any platform that supports an implementation of Vulkan:
7+
laptops, servers, and edge devices.
8+
9+
::::{note}
10+
The Vulkan delegate is currently under active development, and its components
11+
are subject to change.
12+
::::
13+
14+
## What is Vulkan?
15+
16+
Vulkan is a low-level GPU API specification developed as a successor to OpenGL.
17+
It is designed to offer developers a more explicit control over GPUs compared to
18+
previous specifications in order to reduce overhead and maximize the
19+
capabilities of the modern graphics hardware.
20+
21+
Vulkan has been widely adopted among GPU vendors, and most modern GPUs (both
22+
desktop and mobile) in the market support Vulkan. Vulkan is also included in
23+
Android from Android 7.0 onwards.
24+
25+
**Note that Vulkan is a GPU API, not a GPU Math Library**. That is to say it
26+
provides a way to execute compute and graphics operations on a GPU, but does not
27+
come with a built-in library of performant compute kernels.
28+
29+
## The Vulkan Compute Library
30+
31+
The ExecuTorch Vulkan Delegate is a wrapper around a standalone runtime known as
32+
the **Vulkan Compute Library**. The aim of the Vulkan Compute Library is to
33+
provide GPU implementations for PyTorch operators via GLSL compute shaders to be
34+
executed using Vulkan.
35+
36+
The Vulkan Compute Library is a fork/iteration of the [PyTorch Vulkan Backend](https://pytorch.org/tutorials/prototype/vulkan_workflow.html).
37+
The core components of the PyTorch Vulkan backend were forked into ExecuTorch
38+
and adapted for an AOT graph-mode style of model inference (as opposed to
39+
PyTorch which adopted an eager execution style of model inference).
40+
41+
The components of the Vulkan Compute Library are contained in the
42+
`executorch/backends/vulkan/runtime/` directory. The core components are listed
43+
and described below:
44+
45+
```
46+
runtime/
47+
├── api/ .................... Wrapper API around Vulkan to manage Vulkan objects
48+
└── graph/ .................. ComputeGraph class which implements graph mode inference
49+
└── ops/ ................ Base directory for operator implementations
50+
├── glsl/ ........... GLSL compute shaders
51+
│ ├── *.glsl
52+
│ └── conv2d.glsl
53+
└── impl/ ........... C++ code to dispatch GPU compute shaders
54+
├── *.cpp
55+
└── Conv2d.cpp
56+
```
57+
58+
## Features
59+
60+
The Vulkan delegate currently supports the following features:
61+
62+
* **Memory Planning**
63+
* Intermediate tensors whose lifetimes do not overlap will share memory allocations. This reduces the peak memory usage of model inference.
64+
* **Capability Based Partitioning**:
65+
* A graph can be partially lowered to the Vulkan delegate via a partitioner, which will identify nodes (i.e. operators) that are supported by the Vulkan delegate and lower only supported subgraphs
66+
* **Support for upper-bound dynamic shapes**:
67+
* Tensors can change shape between inferences as long as its current shape is smaller than the bounds specified during lowering
68+
69+
In addition to increasing operator coverage, the following features are
70+
currently in development:
71+
72+
* **Quantization Support**
73+
* We are currently working on support for 8-bit dynamic quantization, with plans to extend to other quantization schemes in the future.
74+
* **Memory Layout Management**
75+
* Memory layout is an important factor to optimizing performance. We plan to introduce graph passes to introduce memory layout transitions throughout a graph to optimize memory-layout sensitive operators such as Convolution and Matrix Multiplication.
76+
* **Selective Build**
77+
* We plan to make it possible to control build size by selecting which operators/shaders you want to build with
78+
79+
## End to End Example
80+
81+
To further understand the features of the Vulkan Delegate and how to use it,
82+
consider the following end to end example with MobileNet V2.
83+
84+
### Compile and lower a model to the Vulkan Delegate
85+
86+
Assuming ExecuTorch has been set up and installed, the following script can be
87+
used to produce a lowered MobileNet V2 model as `vulkan_mobilenetv2.pte`.
88+
89+
```
90+
import torch
91+
import torchvision.models as models
92+
93+
from torch.export import export, ExportedProgram
94+
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
95+
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
96+
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge
97+
from executorch.exir.backend.backend_api import to_backend
98+
99+
mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
100+
sample_inputs = (torch.randn(1, 3, 224, 224), )
101+
102+
exported_program: ExportedProgram = export(mobilenet_v2, sample_inputs)
103+
edge: EdgeProgramManager = to_edge(exported_program)
104+
105+
# Lower the model to Vulkan backend
106+
edge = edge.to_backend(VulkanPartitioner())
107+
108+
exec_prog = edge.to_executorch()
109+
110+
with open("vulkan_mobilenetv2.pte", "wb") as file:
111+
exec_prog.write_to_file(file)
112+
```
113+
114+
Like other ExecuTorch delegates, a model can be lowered to the Vulkan Delegate
115+
using the `to_backend()` API. The Vulkan Delegate implements the
116+
`VulkanPartitioner` class which identifies nodes (i.e. operators) in the graph
117+
that are supported by the Vulkan delegate, and separates compatible sections of
118+
the model to be executed on the GPU.
119+
120+
This means the a model can be lowered to the Vulkan delegate even if it contains
121+
some unsupported operators. This will just mean that only parts of the graph
122+
will be executed on the GPU.
123+
124+
125+
::::{note}
126+
The [Vulkan partitioner code](https://github.com/pytorch/executorch/blob/main/backends/vulkan/partitioner/vulkan_partitioner.py)
127+
can be inspected to examine which ops are currently implemented in the Vulkan
128+
delegate.
129+
::::
130+
131+
### Build Vulkan Delegate libraries
132+
133+
The easiest way to build and test the Vulkan Delegate is to build for Android
134+
and test on a local Android device. Android devices have built in support for
135+
Vulkan, and the Android NDK ships with a GLSL compiler, which is needed to
136+
compile the Vulkan Compute Library's GLSL compute shaders.
137+
138+
The Vulkan Delegate libraries can be built by setting `-DEXECUTORCH_BUILD_VULKAN=ON`
139+
when building with CMake.
140+
141+
First, make sure that you have the Android NDK installed - Android NDK r25c is
142+
recommended. The Android SDK should also be installed so that you have access
143+
to `adb`.
144+
145+
```shell
146+
# Recommended version is Android NDK r25c.
147+
export ANDROID_NDK=<path_to_ndk>
148+
# Select an appropriate Android ABI
149+
export ANDROID_ABI=arm64-v8a
150+
# All subsequent commands should be performed from ExecuTorch repo root
151+
cd <path_to_executorch_root>
152+
# Make sure adb works
153+
adb --version
154+
```
155+
156+
To build and install ExecuTorch libraries (for Android) with the Vulkan
157+
Delegate:
158+
159+
```shell
160+
# From executorch root directory
161+
(rm -rf cmake-android-out && \
162+
pp cmake . -DCMAKE_INSTALL_PREFIX=cmake-android-out \
163+
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
164+
-DANDROID_ABI=$ANDROID_ABI \
165+
-DEXECUTORCH_BUILD_VULKAN=ON \
166+
-DPYTHON_EXECUTABLE=python \
167+
-Bcmake-android-out && \
168+
cmake --build cmake-android-out -j16 --target install)
169+
```
170+
171+
### Run the Vulkan model on device
172+
173+
::::{note}
174+
Since operator support is currently limited, only binary arithmetic operators
175+
will run on the GPU. Expect inference to be slow as the majority of operators
176+
are being executed via Portable operators.
177+
::::
178+
179+
Now, the partially delegated model can be executed (partially) on your device's
180+
GPU!
181+
182+
```shell
183+
# Build a model runner binary linked with the Vulkan delegate libs
184+
cmake --build cmake-android-out --target vulkan_executor_runner -j32
185+
186+
# Push model to device
187+
adb push vulkan_mobilenetv2.pte /data/local/tmp/vulkan_mobilenetv2.pte
188+
# Push binary to device
189+
adb push cmake-android-out/backends/vulkan/vulkan_executor_runner /data/local/tmp/runner_bin
190+
191+
# Run the model
192+
adb shell /data/local/tmp/runner_bin --model_path /data/local/tmp/vulkan_mobilenetv2.pte
193+
```

backends/vulkan/docs/android_demo.md

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Building and Running ExecuTorch with the Vulkan Backend
2+
3+
The [ExecuTorch Vulkan Delegate](./native-delegates-executorch-vulkan-delegate.md)
4+
is a native GPU delegate for ExecuTorch.
5+
6+
<!----This will show a grid card on the page----->
7+
::::{grid} 2
8+
:::{grid-item-card} What you will learn in this tutorial:
9+
:class-card: card-content
10+
* How to export the Stories 110M parameter model with partial GPU delegation
11+
* How to execute the partially delegated model on Android
12+
:::
13+
:::{grid-item-card} Prerequisites:
14+
:class-card: card-prerequisites
15+
* Follow [**Setting up ExecuTorch**](./getting-started-setup.md)
16+
* Follow [**Setting up the ExecuTorch LLaMA Android Demo App**](./llm/llama-demo-android.md)
17+
:::
18+
::::
19+
20+
## Prerequisites
21+
22+
Note that all the steps below should be performed from the ExecuTorch repository
23+
root directory, and assumes that you have gone through the steps of setting up
24+
ExecuTorch.
25+
26+
You should also refer to the **Prerequisites** section of the [**Setting up the ExecuTorch LLaMA Android Demo App**](./llm/llama-demo-android.md)
27+
Tutorial in order to install the specified versions of the Android NDK and the
28+
Android SDK.
29+
30+
```shell
31+
# Recommended version is Android NDK r25c.
32+
export ANDROID_NDK=<path_to_ndk>
33+
# Select an appropriate Android ABI
34+
export ANDROID_ABI=arm64-v8a
35+
# All subsequent commands should be performed from ExecuTorch repo root
36+
cd <path_to_executorch_root>
37+
# Make sure adb works
38+
adb --version
39+
```
40+
41+
## Lowering the Stories 110M model to Vulkan
42+
43+
::::{note}
44+
The resultant model will only be partially delegated to the Vulkan backend. In
45+
particular, only binary arithmetic operators (`aten.add`, `aten.sub`,
46+
`aten.mul`, `aten.div`) and the matrix multiplication operator (`aten.mm`) will
47+
be executed on the GPU via the Vulkan delegate. The rest of the model will be
48+
executed using Portable operators. This is because the Vulkan delegate is still
49+
early in development and currently has limited operator coverage.
50+
::::
51+
52+
First, download `stories110M.pt` and `tokenizer.model` from Github:
53+
54+
```shell
55+
wget "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt"
56+
wget "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model"
57+
```
58+
59+
Next, create the params file:
60+
61+
```shell
62+
echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json
63+
```
64+
65+
Then, create a tokenizer binary file:
66+
67+
```shell
68+
python -m examples.models.llama2.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin
69+
```
70+
71+
Finally, export the `stories110M.pt` file into an ExecuTorch program:
72+
73+
```shell
74+
python -m examples.models.llama2.export_llama -c stories110M.pt -p params.json --vulkan
75+
```
76+
77+
A `vulkan_llama2.pte` file should have been created as a result of the last step.
78+
79+
Push the tokenizer binary and `vulkan_llama2.pte` onto your Android device:
80+
81+
```shell
82+
adb mkdir /data/local/tmp/llama/
83+
adb push tokenizer.bin /data/local/tmp/llama/
84+
adb push vulkan_llama2.pte /data/local/tmp/llama/
85+
```
86+
87+
## Build and Run the LLaMA runner binary on Android
88+
89+
First, build and install ExecuTorch libraries, then build the LLaMA runner
90+
binary using the Android NDK toolchain.
91+
92+
```shell
93+
(rm -rf cmake-android-out && \
94+
cmake . -DCMAKE_INSTALL_PREFIX=cmake-android-out \
95+
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
96+
-DANDROID_ABI=$ANDROID_ABI \
97+
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
98+
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
99+
-DEXECUTORCH_BUILD_VULKAN=ON \
100+
-DEXECUTORCH_BUILD_OPTIMIZED=ON \
101+
-DPYTHON_EXECUTABLE=python \
102+
-Bcmake-android-out && \
103+
cmake --build cmake-android-out -j16 --target install)
104+
105+
# Build LLaMA Runner library
106+
(rm -rf cmake-android-out/examples/models/llama2 && \
107+
cmake examples/models/llama2 \
108+
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
109+
-DANDROID_ABI=$ANDROID_ABI \
110+
-DCMAKE_INSTALL_PREFIX=cmake-android-out \
111+
-DPYTHON_EXECUTABLE=python \
112+
-Bcmake-android-out/examples/models/llama2 && \
113+
cmake --build cmake-android-out/examples/models/llama2 -j16)
114+
```
115+
116+
Finally, push and run the llama runner binary on your Android device.
117+
118+
```shell
119+
adb push cmake-android-out/examples/models/llama2/llama_main /data/local/tmp/llama_main
120+
121+
adb shell /data/local/tmp/llama_main \
122+
--model_path=/data/local/tmp/llama/vulkan_llama2.pte \
123+
--tokenizer_path=/data/local/tmp/llama/tokenizer.bin \
124+
--prompt "hi" \--temperature=0
125+
```
126+
127+
The following output will be produced:
128+
129+
```
130+
hippo named Hippy lived in a big pond. Hippy was a very happy hippo. He liked to play...
131+
```
132+
133+
## Running with the LLaMA Android Demo App
134+
135+
It is also possible to run the partially delegated Vulkan model inside the LLaMA
136+
Android demo app.
137+
138+
First, make some modifications to the Android app setup script to make sure that
139+
the Vulkan backend is built when building and installing ExecuTorch libraries:
140+
141+
```shell
142+
# Run from executorch root directory. You can also edit this in a code editor
143+
sed -i 's/-DEXECUTORCH_BUILD_XNNPACK=ON/-DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_VULKAN=ON/g' examples/demo-apps/android/LlamaDemo/setup.sh
144+
```
145+
146+
Then, Follow the instructions at [**Setting up the ExecuTorch LLaMA Android Demo App**](./llm/llama-demo-android.md)
147+
to build and run the demo application on your Android device. Once the app
148+
starts up, you can load and run the `vulkan_llama2.pte` model with the app.

docs/source/build-run-vulkan.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
```{include} ../../backends/vulkan/docs/android_demo.md

docs/source/index.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ Topics in this section will help you get started with ExecuTorch.
100100
demo-apps-android
101101
examples-end-to-end-to-lower-model-to-delegate
102102
tutorial-xnnpack-delegate-lowering
103+
build-run-vulkan
103104
..
104105
Alphabetical by backend name. Be sure to keep the same order in the
105106
customcarditem entries below.
@@ -183,6 +184,7 @@ Topics in this section will help you get started with ExecuTorch.
183184
:hidden:
184185

185186
native-delegates-executorch-xnnpack-delegate
187+
native-delegates-executorch-vulkan-delegate
186188
backend-delegates-integration
187189
backend-delegates-dependencies
188190

@@ -262,6 +264,13 @@ ExecuTorch tutorials.
262264
:link: tutorial-xnnpack-delegate-lowering.html
263265
:tags: Export,Backend,Delegation,Quantization,XNNPACK
264266

267+
.. customcarditem::
268+
:header: Building and Running ExecuTorch with Vulkan Backend
269+
:card_description: A tutorial that walks you through the process of building ExecuTorch with Vulkan Backend
270+
:image: _static/img/generic-pytorch-logo.png
271+
:link: build-run-vulkan.html
272+
:tags: Export,Backend,Delegation,Vulkan
273+
265274
..
266275
Alphabetical by backend name. Be sure to keep the same order in the Tutorials
267276
toctree entry above.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
```{include} ../../backends/vulkan/README.md

0 commit comments

Comments
 (0)