
Commit 64cf836

larryliu0820 authored and facebook-github-bot committed
Use new API to register custom ExecuTorch kernels into ATen (#2937)
Summary:
Pull Request resolved: #2937

Retry of D55713944. Use `WRAP_TO_ATEN` to register the custom ExecuTorch kernel with PyTorch. This PR adds installation logic for `libcustom_ops_aot_lib.so` to `setup.py`, so that the library is built, installed to the correct location (`<site-packages>/executorch/examples/models/llama2/custom_ops/libcustom_ops_aot_lib.so`), and can then be loaded with `torch.ops.load_library`.

Reviewed By: lucylq

Differential Revision: D55907749
1 parent 971fec7 commit 64cf836
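As a quick illustration of the workflow the summary describes, here is a minimal sketch (not part of the commit; it assumes an executorch Python install that ships the examples package together with the installed libcustom_ops_aot_lib shared library):

    import torch

    # Importing this module triggers the fallback added in sdpa_with_kv_cache.py
    # below: if torch.ops.llama.sdpa_with_kv_cache is not already registered, it
    # globs for libcustom_ops_aot_lib.* next to the module and loads it with
    # torch.ops.load_library.
    from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache  # noqa: F401

    # After loading, the ExecuTorch kernel is visible as an ordinary ATen custom op.
    print(torch.ops.llama.sdpa_with_kv_cache.default)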

File tree

9 files changed: +217 -112 lines changed


CMakeLists.txt (+20 -1)

@@ -144,6 +144,8 @@ option(EXECUTORCH_BUILD_COREML "Build the Core ML backend" OFF)
 
 option(EXECUTORCH_BUILD_CUSTOM "Build the custom kernels" OFF)
 
+option(EXECUTORCH_BUILD_CUSTOM_OPS_AOT "Build the custom ops lib for AOT" OFF)
+
 option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "Build the Data Loader extension"
        OFF)
 
@@ -185,12 +187,19 @@ cmake_dependent_option(
 cmake_dependent_option(EXECUTORCH_BUILD_CPUINFO "Build cpuinfo library." ON
                        "NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF)
 
+if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT)
+  set(EXECUTORCH_BUILD_CUSTOM ON)
+endif()
+
 if(EXECUTORCH_BUILD_CUSTOM)
   set(EXECUTORCH_BUILD_OPTIMIZED ON)
 endif()
 
 if(EXECUTORCH_BUILD_CPUINFO)
   # --- cpuinfo
+  set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG
+      ${CMAKE_POSITION_INDEPENDENT_CODE})
+  set(CMAKE_POSITION_INDEPENDENT_CODE ON)
   set(CPUINFO_SOURCE_DIR "backends/xnnpack/third-party/cpuinfo")
   set(CPUINFO_BUILD_TOOLS
       OFF
@@ -212,10 +221,15 @@ if(EXECUTORCH_BUILD_CPUINFO)
       CACHE STRING "")
   set(CLOG_SOURCE_DIR "${CPUINFO_SOURCE_DIR}/deps/clog")
   add_subdirectory("${CPUINFO_SOURCE_DIR}")
+  set(CMAKE_POSITION_INDEPENDENT_CODE
+      ${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG})
 endif()
 
 if(EXECUTORCH_BUILD_PTHREADPOOL)
   # --- pthreadpool
+  set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG
+      ${CMAKE_POSITION_INDEPENDENT_CODE})
+  set(CMAKE_POSITION_INDEPENDENT_CODE ON)
   set(PTHREADPOOL_SOURCE_DIR "backends/xnnpack/third-party/pthreadpool")
   set(PTHREADPOOL_BUILD_TESTS
       OFF
@@ -235,6 +249,8 @@ if(EXECUTORCH_BUILD_PTHREADPOOL)
        CACHE STRING "")
   endif()
   add_subdirectory("${PTHREADPOOL_SOURCE_DIR}")
+  set(CMAKE_POSITION_INDEPENDENT_CODE
+      ${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG})
 endif()
 
 if(NOT PYTHON_EXECUTABLE)
@@ -546,6 +562,9 @@ if(EXECUTORCH_BUILD_PYBIND)
     list(APPEND _dep_libs custom_ops)
   endif()
 
+  if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT)
+    list(APPEND _dep_libs custom_ops_aot_lib)
+  endif()
   # compile options for pybind
 
   set(_pybind_compile_options -Wno-deprecated-declarations -fPIC -frtti
@@ -559,7 +578,7 @@ if(EXECUTORCH_BUILD_PYBIND)
   target_include_directories(util PUBLIC ${_common_include_directories}
                                          ${TORCH_INCLUDE_DIRS})
   target_compile_options(util PUBLIC ${_pybind_compile_options})
-  target_link_libraries(util PRIVATE torch c10 executorch)
+  target_link_libraries(util PRIVATE torch c10 executorch_no_prim_ops)
 
   # pybind portable_lib
   pybind11_add_module(portable_lib extension/pybindings/pybindings.cpp)

examples/models/llama2/TARGETS (+2 -1)

@@ -18,7 +18,7 @@ runtime.python_library(
     ],
     deps = [
         "//caffe2:torch",
-        "//executorch/examples/models/llama2/custom_ops:llama_custom_ops_aot_lib",
+        "//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
     ],
 )
 
@@ -52,6 +52,7 @@ runtime.python_binary(
     main_module = "executorch.examples.models.llama2.export_llama",
     # visibility = ["//executorch/examples/..."],
     preload_deps = [
+        "//executorch/examples/models/llama2/custom_ops:custom_ops_aot_lib",
         "//executorch/kernels/quantized:aot_lib",
     ],
     deps = [

examples/models/llama2/custom_ops/CMakeLists.txt (+18 -2)

@@ -25,7 +25,7 @@ if(NOT TORCH_ROOT)
   set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
 endif()
 
-set(_common_compile_options -Wno-deprecated-declarations)
+set(_common_compile_options -Wno-deprecated-declarations -fPIC)
 
 include(${EXECUTORCH_ROOT}/build/Utils.cmake)
 include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
@@ -44,7 +44,7 @@ include(${EXECUTORCH_SRCS_FILE})
 set(_common_include_directories ${EXECUTORCH_ROOT}/..)
 
 # Custom op libraries
-set(custom_ops_libs extension_module)
+set(custom_ops_libs executorch_no_prim_ops)
 list(APPEND custom_ops_libs pthreadpool)
 list(APPEND custom_ops_libs cpuinfo)
 list(APPEND custom_ops_libs cpublas)
@@ -76,3 +76,19 @@ target_compile_options(custom_ops PUBLIC ${_common_compile_options}
                                          -DET_USE_THREADPOOL)
 
 install(TARGETS custom_ops DESTINATION lib)
+
+if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT)
+  # Add a AOT library
+  find_package(Torch CONFIG REQUIRED)
+  add_library(custom_ops_aot_lib SHARED
+              ${CMAKE_CURRENT_SOURCE_DIR}/op_sdpa_aot.cpp)
+  target_include_directories(custom_ops_aot_lib
+                             PUBLIC "${_common_include_directories}")
+  target_include_directories(
+    custom_ops_aot_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../../../include")
+  target_link_libraries(custom_ops_aot_lib PUBLIC custom_ops torch)
+  target_compile_options(custom_ops_aot_lib PUBLIC -Wno-deprecated-declarations
+                                                   -fPIC -frtti -fexceptions)
+
+  install(TARGETS custom_ops_aot_lib DESTINATION lib)
+endif()
examples/models/llama2/custom_ops/op_sdpa_aot.cpp (new file, +107)

@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h>
+#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
+#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
+
+#include <torch/library.h>
+
+namespace torch {
+namespace executor {
+
+namespace native {
+
+Tensor& sdpa_with_kv_cache_out_no_context(
+    const Tensor& q_projected,
+    const Tensor& k_projected,
+    const Tensor& v_projected,
+    Tensor& key_cache,
+    Tensor& value_cache,
+    const int64_t start_pos,
+    const int64_t seq_len,
+    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<Tensor> attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  exec_aten::RuntimeContext context{};
+  return torch::executor::native::sdpa_with_kv_cache_out(
+      context,
+      q_projected,
+      k_projected,
+      v_projected,
+      key_cache,
+      value_cache,
+      start_pos,
+      seq_len,
+      attn_mask,
+      dropout_p,
+      is_causal,
+      scale,
+      output);
+}
+
+at::Tensor sdpa_with_kv_cache_aten(
+    const at::Tensor& q_projected,
+    const at::Tensor& k_projected,
+    const at::Tensor& v_projected,
+    at::Tensor& key_cache,
+    at::Tensor& value_cache,
+    const int64_t start_pos,
+    const int64_t seq_len,
+    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const c10::optional<at::Tensor> attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const c10::optional<double> scale) {
+  auto output = at::empty_like(q_projected);
+  WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11)
+  (q_projected,
+   k_projected,
+   v_projected,
+   key_cache,
+   value_cache,
+   start_pos,
+   seq_len,
+   attn_mask,
+   dropout_p,
+   is_causal,
+   scale,
+   output);
+  return output;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
+
+TORCH_LIBRARY(llama, m) {
+  m.def(
+      "sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
+      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
+      "float drpout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor");
+  m.def(
+      "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
+      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
+      "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
+}
+
+TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
+  m.impl(
+      "sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
+  m.impl(
+      "sdpa_with_kv_cache.out",
+      WRAP_TO_ATEN(
+          torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
+}
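A note on the wrapper above: the second argument to `WRAP_TO_ATEN` (11 here) appears to correspond to the zero-based position of the out tensor in the wrapped ExecuTorch signature; `output` is the 12th parameter of `sdpa_with_kv_cache_out_no_context`. Once the shared library is loaded, the two overloads registered by the `TORCH_LIBRARY` block can be inspected from Python. A small sketch, assuming the library was loaded as shown earlier:

    import torch

    functional = torch.ops.llama.sdpa_with_kv_cache.default
    out_variant = torch.ops.llama.sdpa_with_kv_cache.out

    # The printed schemas mirror the m.def() strings in op_sdpa_aot.cpp,
    # including the `drpout_p` spelling carried over from the original schema.
    print(functional._schema)
    print(out_variant._schema)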

examples/models/llama2/custom_ops/sdpa_with_kv_cache.py (+20 -91)

@@ -4,21 +4,29 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# Import custom op defined in op_sdpa_aot.cpp. Those ops are using PyTorch
+# C++ APIs for registration so here we need to import the shared library.
+# This is only needed for OSS.
+
+import logging
+from pathlib import Path
+
 import torch
-from torch.library import impl, impl_abstract
 
-custom_ops_lib = torch.library.Library("llama", "DEF")
-custom_ops_lib.define(
-    "sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
-    "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
-    "float drpout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor"
-)
+from torch.library import impl
 
-custom_ops_lib.define(
-    "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
-    "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
-    "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)"
-)
+try:
+    op = torch.ops.llama.sdpa_with_kv_cache.default
+    assert op is not None
+except:
+    libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*"))
+    assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
+    logging.info(f"Loading custom ops library: {libs[0]}")
+    torch.ops.load_library(libs[0])
+    op = torch.ops.llama.sdpa_with_kv_cache.default
+    assert op is not None
+
+custom_ops_lib = torch.library.Library("llama", "IMPL")
 
 
 def _validate_params(
@@ -118,82 +126,3 @@ def sdpa_with_kv_cache_meta(
     )
 
     return torch.empty_like(query)
-
-
-@impl(custom_ops_lib, "sdpa_with_kv_cache", "CompositeExplicitAutograd")
-def sdpa_with_kv_cache(
-    query,
-    key,
-    value,
-    key_cache,
-    value_cache,
-    start_pos,
-    seq_len,
-    attn_mask=None,
-    drpout_p=0.0,
-    is_causal=False,
-    scale=None,
-):
-    _validate_params(
-        query,
-        key,
-        value,
-        key_cache,
-        value_cache,
-        start_pos,
-        seq_len,
-        attn_mask,
-        drpout_p,
-        is_causal,
-        scale,
-    )
-
-    if attn_mask is not None:
-        attn_mask = attn_mask[start_pos].view((1, -1))
-        attn_mask = attn_mask[:, : start_pos + seq_len]
-    q = query.transpose(1, 2)
-    key_cache[:, start_pos] = key
-    value_cache[:, start_pos] = value
-
-    sliced_k_cache = key_cache
-    sliced_v_cache = value_cache
-    sliced_k_cache = sliced_k_cache[:, : start_pos + seq_len, :, :]
-    sliced_v_cache = sliced_v_cache[:, : start_pos + seq_len, :, :]
-    sliced_k_cache = sliced_k_cache.transpose(1, 2)
-    sliced_v_cache = sliced_v_cache.transpose(1, 2)
-    out = torch.nn.functional.scaled_dot_product_attention(
-        q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
-    )
-    out = out.transpose(1, 2)
-    return out
-
-
-@impl_abstract("llama::sdpa_with_kv_cache.out")
-def sdpa_with_kv_cache_out(
-    query,
-    key,
-    value,
-    key_cache,
-    value_cache,
-    start_pos,
-    seq_len,
-    attn_mask,
-    drpout_p,
-    is_causal,
-    scale,
-    out,
-):
-    out = sdpa_with_kv_cache_meta(
-        query,
-        key,
-        value,
-        key_cache,
-        value_cache,
-        start_pos,
-        seq_len,
-        attn_mask,
-        drpout_p,
-        is_causal,
-        scale,
-    )
-    return out
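The switch from `torch.library.Library("llama", "DEF")` to `torch.library.Library("llama", "IMPL")` means the operator schema now lives solely in the C++ `TORCH_LIBRARY` block, while this Python module only attaches additional kernels to it. A hedged sketch of that pattern (illustrative only; the commit's own registration of `sdpa_with_kv_cache_meta` falls outside the hunks shown here):

    import torch
    from torch.library import impl

    # Attach a Meta (shape-propagation) kernel to the schema that the loaded
    # shared library already defined under the `llama` namespace.
    custom_ops_lib = torch.library.Library("llama", "IMPL")


    @impl(custom_ops_lib, "sdpa_with_kv_cache", "Meta")
    def _sdpa_with_kv_cache_meta(
        query, key, value, key_cache, value_cache, start_pos, seq_len,
        attn_mask=None, drpout_p=0.0, is_causal=False, scale=None,
    ):
        # Matches sdpa_with_kv_cache_meta above: the result has the query's
        # shape and dtype.
        return torch.empty_like(query)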
