|
3 | 3 |
|
4 | 4 | At the last stage of [ExecuTorch model exporting](./export-overview.md), we lower the operators in the dialect to the _out variants_ of the [core ATen operators](./ir-ops-set-definition.md). Then we serialize these operator names into the model artifact. During runtime execution, for each operator name we will need to find the actual _kernels_, i.e., the C++ functions that do the heavy-lifting calculations and return results.
|
5 | 5 |
|
6 |
| -Portable kernel library is the in-house default kernel library, it’s easy to use and portable for most of the target backends. However it’s not optimized for performance, because it’s not specialized for any certain target. Therefore we provide kernel registration APIs for ExecuTorch users to easily register their own optimized kernels. |
| 6 | +## Kernel Libraries |
| 7 | +### First-party kernel libraries: |
7 | 8 |
|
| 9 | +**[Portable kernel library](https://github.com/pytorch/executorch/tree/main/kernels/portable)** is the in-house default kernel library that covers most of the core ATen operators. It’s easy to use/read and is written in portable C++17. However it’s not optimized for performance, because it’s not specialized for any certain target. Therefore we provide kernel registration APIs for ExecuTorch users to easily register their own optimized kernels. |
8 | 10 |
|
9 |
| -## Design Principles |
| 11 | +**[Optimized kernel library](https://github.com/pytorch/executorch/tree/main/kernels/optimized)** specializes on performance for some of the operators, leveraging existing third party libraries such as [EigenBLAS](https://gitlab.com/libeigen/eigen). This works best along with the portable kernel library, with a good balance on portability and performance. One example of combining these two libraries can be found [here](https://github.com/pytorch/executorch/blob/main/configurations/CMakeLists.txt). |
10 | 12 |
|
11 |
| -**What do we support?** On the operator coverage side, the kernel registration APIs allow users to register kernels for all core ATen ops as well as custom ops, as long as the custom ops schemas are specified. |
| 13 | +**[Quantized kernel library](https://github.com/pytorch/executorch/tree/main/kernels/quantized)** implements operators for quantization and dequantization. These are out of core ATen operators but are vital to most of the production use cases. |
12 | 14 |
|
13 |
| -Notice that we also support partial kernels, for example the kernel only supports a subset of tensor dtypes and/or dim orders. |
| 15 | +### Custom kernel libraries: |
14 | 16 |
|
15 |
| -**Kernel contract**: kernels need to comply with the following requirements: |
| 17 | +**Custom kernels implementing core ATen ops**. Even though we don't have an internal example for custom kernels for core ATen ops, the optimized kernel library can be viewed as a good example. We have optimized [`add.out`](https://github.com/pytorch/executorch/blob/main/kernels/optimized/cpu/op_add.cpp) and a portable [`add.out`](https://github.com/pytorch/executorch/blob/main/kernels/portable/cpu/op_add.cpp). When user is combining these two libraries, we provide APIs to choose which kernel to use for `add.out`. In order to author and use custom kernels implementing core ATen ops, using the [YAML based approach](#yaml-entry-for-core-aten-op-out-variant) is recommended, because it provides full fledged support on |
| 18 | + 1. combining kernel libraries and define fallback kernels; |
| 19 | + 2. using selective build to minimize the kernel size. |
| 20 | + |
| 21 | +A **[Custom operator](https://github.com/pytorch/executorch/tree/main/extension/llm/custom_ops)** is any operator that an ExecuTorch user defines outside of PyTorch's [`native_functions.yaml`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml). |
| 22 | + |
| 23 | +## Operator & Kernel Contract |
| 24 | + |
| 25 | +All the kernels mentioned above, whether they are in-house or customized, should comply with the following requirements: |
16 | 26 |
|
17 | 27 | * Match the calling convention derived from operator schema. The kernel registration API will generate headers for the custom kernels as references.
|
18 |
| -* Satisfy the dtype constraints defined in edge dialect. For tensors with certain dtypes as arguments, the result of a custom kernel needs to match the expected dtypes. The constraints are available in edge dialect ops. |
19 |
| -* Gives correct result. We will provide a testing framework to automatically test the custom kernels. |
| 28 | +* Satisfy the dtype constraints defined in edge dialect. For tensors with certain dtypes as arguments, the result of a custom kernel needs to match the expected dtypes. The constraints are available in edge dialect ops. |
| 29 | +* Give correct result. We will provide a testing framework to automatically test the custom kernels. |
| 30 | + |
| 31 | + |
| 32 | +## APIs |
| 33 | + |
| 34 | +These are the APIs available to register kernels/custom kernels/custom ops into ExecuTorch: |
| 35 | + |
| 36 | +* [YAML Entry API](#yaml-entry-api-high-level-architecture) |
| 37 | + - [for core ATen op with custom kernels](#yaml-entry-api-for-core-aten-op-out-variant) |
| 38 | + - [for custom ops](#yaml-entry-api-for-custom-ops) |
| 39 | + - [CMake Macros](#cmake-macros) |
| 40 | +* C++ API |
| 41 | + - [for custom ops](#c-api-for-custom-ops) |
| 42 | + - [CMake Example](#compile-and-link-the-custom-kernel) |
| 43 | + |
| 44 | +If it's not clear which API to use, please see [Best Practices](#custom-ops-api-best-practices). |
| 45 | + |
20 | 46 |
|
21 | 47 |
|
22 |
| -## High Level Architecture |
| 48 | +### YAML Entry API High Level Architecture |
23 | 49 |
|
24 | 50 | 
|
25 | 51 |
|
26 | 52 | ExecuTorch users are asked to provide:
|
27 | 53 |
|
28 | 54 | 1. the custom kernel library with C++ implementations
|
29 | 55 |
|
30 |
| -2. a yaml file associated with the library that describes what operators are being implemented by this library. For partial kernels, the yaml file also contains information on the dtypes and dim orders supported by the kernel. More details in the API section. |
| 56 | +2. a YAML file associated with the library that describes what operators are being implemented by this library. For partial kernels, the yaml file also contains information on the dtypes and dim orders supported by the kernel. More details in the API section. |
31 | 57 |
|
32 | 58 |
|
33 |
| -### Workflow |
| 59 | +### YAML Entry API Workflow |
34 | 60 |
|
35 | 61 | At build time, the yaml files associated with kernel libraries will be passed to the _kernel resolver_ along with the model op info (see selective build doc) and the outcome is a mapping between a combination of operator names and tensor metadata, to kernel symbols. Then codegen tools will use this mapping to generate C++ bindings that connect the kernels to ExecuTorch runtime. ExecuTorch users need to link this generated library into their application to use these kernels.
|
36 | 62 |
|
37 | 63 | At static object initialization time, kernels will be registered into the ExecuTorch kernel registry.
|
38 | 64 |
|
39 | 65 | At runtime initialization stage, ExecuTorch will use the operator name and argument metadata as a key to lookup for the kernels. For example, with “aten::add.out” and inputs being float tensors with dim order (0, 1, 2, 3), ExecuTorch will go into the kernel registry and lookup for a kernel that matches the name and the input metadata.
|
40 | 66 |
|
41 |
| - |
42 |
| -## APIs |
43 |
| - |
44 |
| -There are two sets of APIs: yaml files that describe kernel - operator mappings and codegen tools to consume these mappings. |
45 |
| - |
46 |
| - |
47 |
| -### Yaml Entry for Core ATen Op Out Variant |
| 67 | +### YAML Entry API for Core ATen Op Out Variant |
48 | 68 |
|
49 | 69 | Top level attributes:
|
50 | 70 |
|
51 |
| - |
52 |
| - |
53 | 71 | * `op` (if the operator appears in `native_functions.yaml`) or `func` for custom operator. The value for this key needs to be the full operator name (including overload name) for `op` key, or a full operator schema (namespace, operator name, operator overload name and schema string), if we are describing a custom operator. For schema syntax please refer to this [instruction](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md).
|
54 | 72 | * `kernels`: defines kernel information. It consists of `arg_meta` and `kernel_name`, which are bound together to describe "for input tensors with these metadata, use this kernel".
|
55 | 73 | * `type_alias`(optional): we are giving aliases to possible dtype options. `T0: [Double, Float]` means `T0` can be one of `Double` or `Float`.
|
@@ -86,86 +104,9 @@ ATen operator with a dtype/dim order specialized kernel (works for `Double` dtyp
|
86 | 104 | kernel_name: torch::executor::add_out
|
87 | 105 |
|
88 | 106 | ```
|
89 |
| -### Custom Ops C++ API |
90 |
| - |
91 |
| -For a custom kernel that implements a custom operator, we provides 2 ways to register it into ExecuTorch runtime: |
92 |
| -1. Using `EXECUTORCH_LIBRARY` and `WRAP_TO_ATEN` C++ macros, covered by this section. |
93 |
| -2. Using `functions.yaml` and codegen'd C++ libraries, covered by [next section](#custom-ops-yaml-entry). |
94 |
| - |
95 |
| -Please refer to [Custom Ops Best Practices](#custom-ops-api-best-practices) on which API to use. |
96 |
| - |
97 |
| -The first option requires C++17 and doesn't have selective build support yet, but it's faster than the second option where we have to go through yaml authoring and build system tweaking. |
98 |
| - |
99 |
| -The first option is particularly suitable for fast prototyping but can also be used in production. |
100 |
| - |
101 |
| -Similar to `TORCH_LIBRARY`, `EXECUTORCH_LIBRARY` takes the operator name and the C++ function name and register them into ExecuTorch runtime. |
102 |
| - |
103 |
| -#### Prepare custom kernel implementation |
104 |
| - |
105 |
| -Define your custom operator schema for both functional variant (used in AOT compilation) and out variant (used in ExecuTorch runtime). The schema needs to follow PyTorch ATen convention (see native_functions.yaml). For example: |
106 |
| - |
107 |
| -```yaml |
108 |
| -custom_linear(Tensor weight, Tensor input, Tensor(?) bias) -> Tensor |
109 |
| -custom_linear.out(Tensor weight, Tensor input, Tensor(?) bias, *, Tensor(a!) out) -> Tensor(a!) |
110 |
| -``` |
111 |
| - |
112 |
| -Then write your custom kernel according to the schema using ExecuTorch types, along with APIs to register to ExecuTorch runtime: |
113 |
| - |
114 |
| - |
115 |
| -```c++ |
116 |
| -// custom_linear.h/custom_linear.cpp |
117 |
| -#include <executorch/runtime/kernel/kernel_includes.h> |
118 |
| -Tensor& custom_linear_out(const Tensor& weight, const Tensor& input, optional<Tensor> bias, Tensor& out) { |
119 |
| - // calculation |
120 |
| - return out; |
121 |
| -} |
122 |
| -``` |
123 |
| -#### Use a C++ macro to register it into PyTorch & ExecuTorch |
124 |
| - |
125 |
| -Append the following line in the example above: |
126 |
| -```c++ |
127 |
| -// custom_linear.h/custom_linear.cpp |
128 |
| -// opset namespace myop |
129 |
| -EXECUTORCH_LIBRARY(myop, "custom_linear.out", custom_linear_out); |
130 |
| -``` |
131 |
| - |
132 |
| -Now we need to write some wrapper for this op to show up in PyTorch, but don’t worry we don’t need to rewrite the kernel. Create a separate .cpp for this purpose: |
133 |
| - |
134 |
| -```c++ |
135 |
| -// custom_linear_pytorch.cpp |
136 |
| -#include "custom_linear.h" |
137 |
| -#include <torch/library.h> |
138 |
| -
|
139 |
| -at::Tensor custom_linear(const at::Tensor& weight, const at::Tensor& input, std::optional<at::Tensor> bias) { |
140 |
| - // initialize out |
141 |
| - at::Tensor out = at::empty({weight.size(1), input.size(1)}); |
142 |
| - // wrap kernel in custom_linear.cpp into ATen kernel |
143 |
| - WRAP_TO_ATEN(custom_linear_out, 3)(weight, input, bias, out); |
144 |
| - return out; |
145 |
| -} |
146 |
| -// standard API to register ops into PyTorch |
147 |
| -TORCH_LIBRARY(myop, m) { |
148 |
| - m.def("custom_linear(Tensor weight, Tensor input, Tensor(?) bias) -> Tensor", custom_linear); |
149 |
| - m.def("custom_linear.out(Tensor weight, Tensor input, Tensor(?) bias, *, Tensor(a!) out) -> Tensor(a!)", WRAP_TO_ATEN(custom_linear_out, 3)); |
150 |
| -} |
151 |
| -``` |
152 |
| - |
153 |
| -#### Compile and link the custom kernel |
154 |
| - |
155 |
| -Link it into ExecuTorch runtime: In our `CMakeLists.txt`` that builds the binary/application, we just need to add custom_linear.h/cpp into the binary target. We can build a dynamically loaded library (.so or .dylib) and link it as well. |
156 |
| - |
157 |
| -Link it into PyTorch runtime: We need to package custom_linear.h, custom_linear.cpp and custom_linear_pytorch.cpp into a dynamically loaded library (.so or .dylib) and load it into our python environment. One way of doing this is: |
158 |
| - |
159 |
| -```python |
160 |
| -import torch |
161 |
| -torch.ops.load_library("libcustom_linear.so/dylib") |
162 |
| -
|
163 |
| -# Now we have access to the custom op, backed by kernel implemented in custom_linear.cpp. |
164 |
| -op = torch.ops.myop.custom_linear.default |
165 |
| -``` |
166 | 107 |
|
167 | 108 |
|
168 |
| -### Custom Ops Yaml Entry |
| 109 | +### YAML Entry API for Custom Ops |
169 | 110 |
|
170 | 111 | As mentioned above, this option provides more support in terms of selective build and features such as merging operator libraries.
|
171 | 112 |
|
@@ -215,14 +156,11 @@ ExecuTorch does not support all of the argument types that core PyTorch supports
|
215 | 156 | * List<Optional<Type>>
|
216 | 157 | * Optional<List<Type>>
|
217 | 158 |
|
218 |
| - |
219 |
| -### Build Tool Macros |
| 159 | +#### CMake Macros |
220 | 160 |
|
221 | 161 | We provide build time macros to help users to build their kernel registration library. The macro takes the yaml file describing the kernel library as well as model operator metadata, and packages the generated C++ bindings into a C++ library. The macro is available on CMake.
|
222 | 162 |
|
223 | 163 |
|
224 |
| -#### CMake |
225 |
| - |
226 | 164 | `generate_bindings_for_kernels(FUNCTIONS_YAML functions_yaml CUSTOM_OPS_YAML custom_ops_yaml)` takes a yaml file for core ATen op out variants and also a yaml file for custom ops, generate C++ bindings for kernel registration. It also depends on the selective build artifact generated by `gen_selected_ops()`, see selective build doc for more information. Then `gen_operators_lib` will package those bindings to be a C++ library. As an example:
|
227 | 165 | ```cmake
|
228 | 166 | # SELECT_OPS_LIST: aten::add.out,aten::mm.out
|
@@ -263,6 +201,103 @@ And out fallback:
|
263 | 201 |
|
264 | 202 | The merged yaml will have the entry in functions.yaml.
|
265 | 203 |
|
| 204 | +### C++ API for Custom Ops |
| 205 | + |
| 206 | +Unlike the YAML entry API, the C++ API only uses C++ macros `EXECUTORCH_LIBRARY` and `WRAP_TO_ATEN` for kernel registration, also without selective build support. It makes this API faster in terms of development speed, since users don't have to do YAML authoring and build system tweaking. |
| 207 | + |
| 208 | +Please refer to [Custom Ops Best Practices](#custom-ops-api-best-practices) on which API to use. |
| 209 | + |
| 210 | +Similar to [`TORCH_LIBRARY`](https://pytorch.org/cppdocs/library.html#library_8h_1a0bd5fb09d25dfb58e750d712fc5afb84) in PyTorch, `EXECUTORCH_LIBRARY` takes the operator name and the C++ function name and register them into ExecuTorch runtime. |
| 211 | + |
| 212 | +#### Prepare custom kernel implementation |
| 213 | + |
| 214 | +Define your custom operator schema for both functional variant (used in AOT compilation) and out variant (used in ExecuTorch runtime). The schema needs to follow PyTorch ATen convention (see `native_functions.yaml`). For example: |
| 215 | + |
| 216 | +```yaml |
| 217 | +custom_linear(Tensor weight, Tensor input, Tensor(?) bias) -> Tensor |
| 218 | +custom_linear.out(Tensor weight, Tensor input, Tensor(?) bias, *, Tensor(a!) out) -> Tensor(a!) |
| 219 | +``` |
| 220 | + |
| 221 | +Then write your custom kernel according to the schema using ExecuTorch types, along with APIs to register to ExecuTorch runtime: |
| 222 | + |
| 223 | + |
| 224 | +```c++ |
| 225 | +// custom_linear.h/custom_linear.cpp |
| 226 | +#include <executorch/runtime/kernel/kernel_includes.h> |
| 227 | +Tensor& custom_linear_out(const Tensor& weight, const Tensor& input, optional<Tensor> bias, Tensor& out) { |
| 228 | + // calculation |
| 229 | + return out; |
| 230 | +} |
| 231 | +``` |
| 232 | +#### Use a C++ macro to register it into ExecuTorch |
| 233 | + |
| 234 | +Append the following line in the example above: |
| 235 | +```c++ |
| 236 | +// custom_linear.h/custom_linear.cpp |
| 237 | +// opset namespace myop |
| 238 | +EXECUTORCH_LIBRARY(myop, "custom_linear.out", custom_linear_out); |
| 239 | +``` |
| 240 | + |
| 241 | +Now we need to write some wrapper for this op to show up in PyTorch, but don’t worry we don’t need to rewrite the kernel. Create a separate .cpp for this purpose: |
| 242 | + |
| 243 | +```c++ |
| 244 | +// custom_linear_pytorch.cpp |
| 245 | +#include "custom_linear.h" |
| 246 | +#include <torch/library.h> |
| 247 | +
|
| 248 | +at::Tensor custom_linear(const at::Tensor& weight, const at::Tensor& input, std::optional<at::Tensor> bias) { |
| 249 | + // initialize out |
| 250 | + at::Tensor out = at::empty({weight.size(1), input.size(1)}); |
| 251 | + // wrap kernel in custom_linear.cpp into ATen kernel |
| 252 | + WRAP_TO_ATEN(custom_linear_out, 3)(weight, input, bias, out); |
| 253 | + return out; |
| 254 | +} |
| 255 | +// standard API to register ops into PyTorch |
| 256 | +TORCH_LIBRARY(myop, m) { |
| 257 | + m.def("custom_linear(Tensor weight, Tensor input, Tensor(?) bias) -> Tensor", custom_linear); |
| 258 | + m.def("custom_linear.out(Tensor weight, Tensor input, Tensor(?) bias, *, Tensor(a!) out) -> Tensor(a!)", WRAP_TO_ATEN(custom_linear_out, 3)); |
| 259 | +} |
| 260 | +``` |
| 261 | + |
| 262 | +#### Compile and link the custom kernel |
| 263 | + |
| 264 | +Link it into ExecuTorch runtime: In our `CMakeLists.txt` that builds the binary/application, we need to add custom_linear.h/cpp into the binary target. We can build a dynamically loaded library (.so or .dylib) and link it as well. |
| 265 | + |
| 266 | +Here's an example to do it: |
| 267 | + |
| 268 | +```cmake |
| 269 | +# For target_link_options_shared_lib |
| 270 | +include(${EXECUTORCH_ROOT}/build/Utils.cmake) |
| 271 | +
|
| 272 | +# Add a custom op library |
| 273 | +add_library(custom_op_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/custom_op.cpp) |
| 274 | +
|
| 275 | +# Include the header |
| 276 | +target_include_directory(custom_op_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) |
| 277 | +
|
| 278 | +# Link ExecuTorch library |
| 279 | +target_link_libraries(custom_op_lib PUBLIC executorch) |
| 280 | +
|
| 281 | +# Define a binary target |
| 282 | +add_executable(custom_op_runner PUBLIC main.cpp) |
| 283 | +
|
| 284 | +# Link this library with --whole-archive !! IMPORTANT !! this is to avoid the operators being stripped by linker |
| 285 | +target_link_options_shared_lib(custom_op_lib) |
| 286 | +
|
| 287 | +# Link custom op lib |
| 288 | +target_link_libraries(custom_op_runner PUBLIC custom_op_lib) |
| 289 | +
|
| 290 | +``` |
| 291 | + |
| 292 | +Link it into the PyTorch runtime: We need to package custom_linear.h, custom_linear.cpp and custom_linear_pytorch.cpp into a dynamically loaded library (.so or .dylib) and load it into our python environment. One way of doing this is: |
| 293 | + |
| 294 | +```python |
| 295 | +import torch |
| 296 | +torch.ops.load_library("libcustom_linear.so/dylib") |
| 297 | +
|
| 298 | +# Now we have access to the custom op, backed by kernel implemented in custom_linear.cpp. |
| 299 | +op = torch.ops.myop.custom_linear.default |
| 300 | +``` |
266 | 301 |
|
267 | 302 | ### Custom Ops API Best Practices
|
268 | 303 |
|
|
0 commit comments