Skip to content

Commit 6199538

Browse files
larryliu0820 authored and facebook-github-bot committed
Use new API to register custom ops for llama model (#2916)
Summary: Retry of D55713944 Use `EXECUTORCH_LIBRARY` to register custom kernel to ExecuTorch runtime. Differential Revision: D55856491
1 parent f80150f commit 6199538

File tree

11 files changed

+98
-126
lines changed

11 files changed

+98
-126
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@ if(EXECUTORCH_BUILD_PYBIND)
543543
endif()
544544

545545
if(EXECUTORCH_BUILD_CUSTOM)
546-
list(APPEND _dep_libs custom_ops_lib)
546+
list(APPEND _dep_libs custom_ops)
547547
endif()
548548

549549
# compile options for pybind

examples/models/llama2/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ else()
9292
endif()
9393

9494
if(EXECUTORCH_BUILD_CUSTOM)
95-
target_link_options_shared_lib(custom_ops_lib)
96-
list(APPEND link_libraries custom_ops_lib)
95+
target_link_options_shared_lib(custom_ops)
96+
list(APPEND link_libraries custom_ops)
9797
endif()
9898

9999
# XNNPACK pthreadpool cpuinfo

examples/models/llama2/custom_ops/CMakeLists.txt

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,6 @@ list(APPEND custom_ops_libs cpuinfo)
5050
list(APPEND custom_ops_libs cpublas)
5151
list(APPEND custom_ops_libs eigen_blas)
5252

53-
# Generate C++ bindings to register kernels into both PyTorch (for AOT) and
54-
# Executorch (for runtime). Here select all ops in optimized.yaml
55-
set(_yaml "${CMAKE_CURRENT_LIST_DIR}/custom_ops.yaml")
56-
gen_selected_ops("${_yaml}" "" "")
57-
58-
generate_bindings_for_kernels(FUNCTIONS_YAML
59-
${CMAKE_CURRENT_SOURCE_DIR}/custom_ops.yaml)
60-
message("Generated files ${gen_command_sources}")
61-
6253
list(TRANSFORM _custom_ops__srcs PREPEND "${EXECUTORCH_ROOT}/")
6354

6455
# TODO: Consider moving xnnpack/threadpool in a separate lib since it's now used
@@ -70,6 +61,8 @@ if(NOT EXECUTORCH_BUILD_XNNPACK)
7061
"${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool.cpp"
7162
"${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool_guard.cpp"
7263
)
64+
else()
65+
list(APPEND custom_ops_libs xnnpack_backend)
7366
endif()
7467

7568
add_library(custom_ops ${_custom_ops__srcs})
@@ -82,7 +75,4 @@ target_link_libraries(custom_ops PUBLIC ${custom_ops_libs})
8275
target_compile_options(custom_ops PUBLIC ${_common_compile_options}
8376
-DET_USE_THREADPOOL)
8477

85-
# Build a library for _custom_ops_srcs
86-
#
87-
# custom_ops_lib: Register optimized ops kernels into Executorch runtime
88-
gen_operators_lib("custom_ops_lib" KERNEL_LIBS custom_ops DEPS executorch)
78+
install(TARGETS custom_ops DESTINATION lib)

examples/models/llama2/custom_ops/__init__.py

Whitespace-only changes.

examples/models/llama2/custom_ops/custom_ops.yaml

Lines changed: 0 additions & 14 deletions
This file was deleted.

examples/models/llama2/custom_ops/op_sdpa.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/runtime/kernel/kernel_includes.h>
9+
#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h>
1010

1111
#include <executorch/kernels/optimized/blas/CPUBlas.h>
1212
#include <executorch/kernels/optimized/vec/functional.h>
@@ -22,6 +22,7 @@
2222
#include <executorch/backends/xnnpack/threadpool/threadpool.h>
2323
#include <executorch/extension/parallel/thread_parallel.h>
2424
#endif
25+
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
2526

2627
namespace torch {
2728
namespace executor {
@@ -843,3 +844,8 @@ Tensor& sdpa_with_kv_cache_out(
843844
} // namespace native
844845
} // namespace executor
845846
} // namespace torch
847+
848+
EXECUTORCH_LIBRARY(
849+
llama,
850+
"sdpa_with_kv_cache.out",
851+
torch::executor::native::sdpa_with_kv_cache_out);
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
13+
namespace torch {
14+
namespace executor {
15+
16+
namespace native {
17+
18+
Tensor& sdpa_with_kv_cache_out(
19+
RuntimeContext& ctx,
20+
const Tensor& q_projected,
21+
const Tensor& k_projected,
22+
const Tensor& v_projected,
23+
Tensor& key_cache,
24+
Tensor& value_cache,
25+
const int64_t start_pos,
26+
const int64_t seq_len,
27+
const optional<Tensor>& attn_mask,
28+
const double dropout_p,
29+
const bool is_causal,
30+
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
31+
const optional<double> scale,
32+
Tensor& output);
33+
34+
Tensor& flash_attention_kernel_out(
35+
RuntimeContext& ctx,
36+
const Tensor& query,
37+
const Tensor& key,
38+
const Tensor& value,
39+
const optional<Tensor>& attn_mask,
40+
const double dropout_p,
41+
const bool is_causal,
42+
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
43+
const optional<double> scale,
44+
Tensor& output);
45+
46+
} // namespace native
47+
} // namespace executor
48+
} // namespace torch

examples/models/llama2/custom_ops/op_sdpa_test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
#include <limits>
1010

11-
#include <executorch/examples/models/llama2/custom_ops/FunctionHeaderWrapper.h> // Declares the operator
11+
#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h>
12+
1213
#include <executorch/kernels/test/TestUtil.h>
1314
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1415
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
@@ -28,7 +29,7 @@ exec_aten::Tensor op_scaled_dot_product_attention(
2829
exec_aten::optional<double> scale,
2930
exec_aten::Tensor& out) {
3031
exec_aten::RuntimeContext context{};
31-
return torch::executor::llama::sdpa_outf(
32+
return torch::executor::native::flash_attention_kernel_out(
3233
context, query, key, value, attn_mask, dropout_p, is_causal, scale, out);
3334
}
3435

examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
#include <limits>
1010

11-
#include <executorch/examples/models/llama2/custom_ops/FunctionHeaderWrapper.h> // Declares the operator
11+
#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h> // Declares the operator
1212
#include <executorch/kernels/test/TestUtil.h>
1313
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1414
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
@@ -32,7 +32,7 @@ exec_aten::Tensor op_sdpa_with_kv_cache(
3232
exec_aten::optional<double> scale,
3333
exec_aten::Tensor& out) {
3434
exec_aten::RuntimeContext context{};
35-
return torch::executor::llama::sdpa_with_kv_cache_outf(
35+
return torch::executor::native::sdpa_with_kv_cache_out(
3636
context,
3737
query,
3838
key,
Lines changed: 30 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,11 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2-
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib")
3-
load("@fbsource//xplat/executorch/kernels/test:util.bzl", "codegen_function_header_wrapper")
4-
5-
def define_tests():
6-
codegen_function_header_wrapper("executorch/examples/models/llama2/custom_ops", "custom_ops")
7-
8-
# In the long run we should really have aten variant available as well
9-
deps = [":function_header_wrapper_custom_ops"]
10-
generated_lib_and_op_deps = [
11-
":custom_ops",
12-
":sdpa",
13-
":custom_ops_headers",
14-
]
15-
runtime.cxx_test(
16-
name = "op_sdpa_test",
17-
srcs = [
18-
"op_sdpa_test.cpp",
19-
],
20-
visibility = ["//executorch/..."],
21-
deps = [
22-
"//executorch/runtime/core/exec_aten:lib",
23-
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
24-
"//executorch/kernels/test:test_util",
25-
] + generated_lib_and_op_deps + deps,
26-
)
27-
runtime.cxx_test(
28-
name = "op_sdpa_with_kv_cache_test",
29-
srcs = [
30-
"op_sdpa_with_kv_cache_test.cpp",
31-
],
32-
visibility = ["//executorch/..."],
33-
deps = [
34-
"//executorch/runtime/core/exec_aten:lib",
35-
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
36-
"//executorch/kernels/test:test_util",
37-
] + generated_lib_and_op_deps + deps,
38-
)
392

403
def define_common_targets():
414
"""Defines targets that should be shared between fbcode and xplat.
425
436
The directory containing this targets.bzl file should also contain both
447
TARGETS and BUCK files that call this function.
458
"""
46-
479
runtime.python_library(
4810
name = "llama_custom_ops_aot_lib",
4911
srcs = [
@@ -58,71 +20,54 @@ def define_common_targets():
5820
],
5921
)
6022

61-
runtime.export_file(
62-
name = "custom_ops.yaml",
63-
visibility = [
64-
"//executorch/...",
65-
"@EXECUTORCH_CLIENTS",
66-
],
67-
)
68-
69-
# ~~~ START of custom ops 1 `my_ops::mul3` library definitions ~~~
70-
et_operator_library(
71-
name = "sdpa_op",
72-
ops = [
73-
"llama::sdpa.out",
74-
],
75-
define_static_targets = True,
76-
visibility = [
77-
"//executorch/codegen/...",
78-
"@EXECUTORCH_CLIENTS",
79-
],
80-
)
81-
82-
et_operator_library(
83-
name = "sdpa_with_kv_cache",
84-
ops = [
85-
"llama::sdpa_with_kv_cache.out",
86-
],
87-
define_static_targets = True,
88-
visibility = [
89-
"//executorch/codegen/...",
90-
"@EXECUTORCH_CLIENTS",
91-
],
92-
)
93-
9423
runtime.cxx_library(
95-
name = "sdpa",
24+
name = "custom_ops",
9625
srcs = ["op_sdpa.cpp"],
97-
deps = [
26+
exported_headers = ["op_sdpa.h"],
27+
exported_deps = [
9828
"//executorch/runtime/kernel:kernel_includes",
9929
"//executorch/kernels/portable/cpu:scalar_utils",
10030
"//executorch/kernels/optimized:libblas",
10131
"//executorch/kernels/optimized:libvec",
32+
"//executorch/extension/kernel_util:kernel_util",
10233
"//executorch/extension/parallel:thread_parallel",
10334
"//executorch/backends/xnnpack/threadpool:threadpool",
10435
],
105-
compiler_flags = ["-Wno-missing-prototypes"],
36+
compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
10637
visibility = [
10738
"//executorch/...",
10839
"//executorch/examples/models/llama2/custom_ops/...",
10940
"@EXECUTORCH_CLIENTS",
11041
],
42+
# @lint-ignore BUCKLINT link_whole
43+
link_whole = True,
11144
force_static = True,
11245
)
11346

114-
executorch_generated_lib(
115-
name = "custom_ops",
47+
runtime.cxx_test(
48+
name = "op_sdpa_test",
49+
srcs = [
50+
"op_sdpa_test.cpp",
51+
],
52+
visibility = ["//executorch/..."],
11653
deps = [
117-
":sdpa_op",
118-
":sdpa_with_kv_cache",
119-
":sdpa",
54+
"//executorch/runtime/core/exec_aten:lib",
55+
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
56+
"//executorch/kernels/test:test_util",
57+
":custom_ops",
12058
],
121-
custom_ops_yaml_target = ":custom_ops.yaml",
122-
visibility = [
123-
"//executorch/...",
124-
"@EXECUTORCH_CLIENTS",
59+
)
60+
61+
runtime.cxx_test(
62+
name = "op_sdpa_with_kv_cache_test",
63+
srcs = [
64+
"op_sdpa_with_kv_cache_test.cpp",
65+
],
66+
visibility = ["//executorch/..."],
67+
deps = [
68+
"//executorch/runtime/core/exec_aten:lib",
69+
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
70+
"//executorch/kernels/test:test_util",
71+
":custom_ops",
12572
],
126-
define_static_targets = True,
12773
)
128-
define_tests()

extension/android/CMakeLists.txt

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,10 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
5858
add_library(llama_runner STATIC IMPORTED)
5959
set_property(TARGET llama_runner PROPERTY IMPORTED_LOCATION ${LLAMA_RUNNER_PATH})
6060

61-
set(CUSTOM_OPS_LIB_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama2/custom_ops/libcustom_ops_lib.a)
62-
add_library(custom_ops_lib STATIC IMPORTED)
63-
set_property(TARGET custom_ops_lib PROPERTY IMPORTED_LOCATION ${CUSTOM_OPS_LIB_PATH})
64-
6561
set(CUSTOM_OPS_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama2/custom_ops/libcustom_ops.a)
6662
add_library(custom_ops STATIC IMPORTED)
6763
set_property(TARGET custom_ops PROPERTY IMPORTED_LOCATION ${CUSTOM_OPS_PATH})
68-
target_link_options_shared_lib(custom_ops_lib)
64+
target_link_options_shared_lib(custom_ops)
6965

7066
if(TARGET pthreadpool)
7167
set(LLAMA_JNI_SRCS jni/jni_layer_llama.cpp ../../backends/xnnpack/threadpool/cpuinfo_utils.cpp)
@@ -82,6 +78,6 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
8278
endif()
8379
target_include_directories(executorch_llama_jni PRIVATE ${_common_include_directories})
8480
target_link_libraries(executorch_llama_jni ${link_libraries} llama_runner
85-
custom_ops custom_ops_lib cpublas eigen_blas)
81+
custom_ops cpublas eigen_blas)
8682
target_compile_options(executorch_llama_jni PUBLIC ${_common_compile_options})
8783
endif()

0 commit comments

Comments (0)