Skip to content

Commit 6e43135

Browse files
larryliu0820 authored and facebook-github-bot committed
Use new API to register custom ops for llama model (#2916)
Summary: Pull Request resolved: #2916 Retry of D55713944 Use `EXECUTORCH_LIBRARY` to register custom kernel to ExecuTorch runtime. Reviewed By: lucylq Differential Revision: D55856491 fbshipit-source-id: 0e17ea18a7cd0b0b45a8e56e9d09ab5e2f8eb95e
1 parent e641ffc commit 6e43135

10 files changed

+96
-120
lines changed

CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@ if(EXECUTORCH_BUILD_PYBIND)
543543
endif()
544544

545545
if(EXECUTORCH_BUILD_CUSTOM)
546-
list(APPEND _dep_libs custom_ops_lib)
546+
list(APPEND _dep_libs custom_ops)
547547
endif()
548548

549549
# compile options for pybind

examples/models/llama2/CMakeLists.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ else()
106106
endif()
107107

108108
if(EXECUTORCH_BUILD_CUSTOM)
109-
target_link_options_shared_lib(custom_ops_lib)
110-
list(APPEND link_libraries custom_ops_lib)
109+
target_link_options_shared_lib(custom_ops)
110+
list(APPEND link_libraries custom_ops)
111111
endif()
112112

113113
set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack)

examples/models/llama2/custom_ops/CMakeLists.txt

+3-13
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,6 @@ list(APPEND custom_ops_libs cpuinfo)
5050
list(APPEND custom_ops_libs cpublas)
5151
list(APPEND custom_ops_libs eigen_blas)
5252

53-
# Generate C++ bindings to register kernels into both PyTorch (for AOT) and
54-
# Executorch (for runtime). Here select all ops in optimized.yaml
55-
set(_yaml "${CMAKE_CURRENT_LIST_DIR}/custom_ops.yaml")
56-
gen_selected_ops("${_yaml}" "" "")
57-
58-
generate_bindings_for_kernels(FUNCTIONS_YAML
59-
${CMAKE_CURRENT_SOURCE_DIR}/custom_ops.yaml)
60-
message("Generated files ${gen_command_sources}")
61-
6253
list(TRANSFORM _custom_ops__srcs PREPEND "${EXECUTORCH_ROOT}/")
6354

6455
# TODO: Consider moving xnnpack/threadpool in a separate lib since it's now used
@@ -70,6 +61,8 @@ if(NOT EXECUTORCH_BUILD_XNNPACK)
7061
"${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool.cpp"
7162
"${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool_guard.cpp"
7263
)
64+
else()
65+
list(APPEND custom_ops_libs xnnpack_backend)
7366
endif()
7467

7568
add_library(custom_ops ${_custom_ops__srcs})
@@ -82,7 +75,4 @@ target_link_libraries(custom_ops PUBLIC ${custom_ops_libs})
8275
target_compile_options(custom_ops PUBLIC ${_common_compile_options}
8376
-DET_USE_THREADPOOL)
8477

85-
# Build a library for _custom_ops_srcs
86-
#
87-
# custom_ops_lib: Register optimized ops kernels into Executorch runtime
88-
gen_operators_lib("custom_ops_lib" KERNEL_LIBS custom_ops DEPS executorch)
78+
install(TARGETS custom_ops DESTINATION lib)

examples/models/llama2/custom_ops/__init__.py

Whitespace-only changes.

examples/models/llama2/custom_ops/custom_ops.yaml

-14
This file was deleted.

examples/models/llama2/custom_ops/op_sdpa.cpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/runtime/kernel/kernel_includes.h>
9+
#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h>
1010

1111
#include <executorch/kernels/optimized/blas/CPUBlas.h>
1212
#include <executorch/kernels/optimized/vec/functional.h>
@@ -22,6 +22,7 @@
2222
#include <executorch/backends/xnnpack/threadpool/threadpool.h>
2323
#include <executorch/extension/parallel/thread_parallel.h>
2424
#endif
25+
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
2526

2627
namespace torch {
2728
namespace executor {
@@ -843,3 +844,8 @@ Tensor& sdpa_with_kv_cache_out(
843844
} // namespace native
844845
} // namespace executor
845846
} // namespace torch
847+
848+
EXECUTORCH_LIBRARY(
849+
llama,
850+
"sdpa_with_kv_cache.out",
851+
torch::executor::native::sdpa_with_kv_cache_out);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
13+
namespace torch {
14+
namespace executor {
15+
16+
namespace native {
17+
18+
Tensor& sdpa_with_kv_cache_out(
19+
RuntimeContext& ctx,
20+
const Tensor& q_projected,
21+
const Tensor& k_projected,
22+
const Tensor& v_projected,
23+
Tensor& key_cache,
24+
Tensor& value_cache,
25+
const int64_t start_pos,
26+
const int64_t seq_len,
27+
const optional<Tensor>& attn_mask,
28+
const double dropout_p,
29+
const bool is_causal,
30+
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
31+
const optional<double> scale,
32+
Tensor& output);
33+
34+
Tensor& flash_attention_kernel_out(
35+
RuntimeContext& ctx,
36+
const Tensor& query,
37+
const Tensor& key,
38+
const Tensor& value,
39+
const optional<Tensor>& attn_mask,
40+
const double dropout_p,
41+
const bool is_causal,
42+
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
43+
const optional<double> scale,
44+
Tensor& output);
45+
46+
} // namespace native
47+
} // namespace executor
48+
} // namespace torch

examples/models/llama2/custom_ops/op_sdpa_test.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
#include <limits>
1010

11-
#include <executorch/examples/models/llama2/custom_ops/FunctionHeaderWrapper.h> // Declares the operator
11+
#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h>
12+
1213
#include <executorch/kernels/test/TestUtil.h>
1314
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1415
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
@@ -28,7 +29,7 @@ exec_aten::Tensor op_scaled_dot_product_attention(
2829
exec_aten::optional<double> scale,
2930
exec_aten::Tensor& out) {
3031
exec_aten::RuntimeContext context{};
31-
return torch::executor::llama::sdpa_outf(
32+
return torch::executor::native::flash_attention_kernel_out(
3233
context, query, key, value, attn_mask, dropout_p, is_causal, scale, out);
3334
}
3435

examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
#include <limits>
1010

11-
#include <executorch/examples/models/llama2/custom_ops/FunctionHeaderWrapper.h> // Declares the operator
11+
#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h> // Declares the operator
1212
#include <executorch/kernels/test/TestUtil.h>
1313
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1414
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
@@ -32,7 +32,7 @@ exec_aten::Tensor op_sdpa_with_kv_cache(
3232
exec_aten::optional<double> scale,
3333
exec_aten::Tensor& out) {
3434
exec_aten::RuntimeContext context{};
35-
return torch::executor::llama::sdpa_with_kv_cache_outf(
35+
return torch::executor::native::sdpa_with_kv_cache_out(
3636
context,
3737
query,
3838
key,
+30-85
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,11 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2-
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib")
3-
load("@fbsource//xplat/executorch/kernels/test:util.bzl", "codegen_function_header_wrapper")
4-
5-
def define_tests():
6-
codegen_function_header_wrapper("executorch/examples/models/llama2/custom_ops", "custom_ops")
7-
8-
# In the long run we should really have aten variant available as well
9-
deps = [":function_header_wrapper_custom_ops"]
10-
generated_lib_and_op_deps = [
11-
":custom_ops",
12-
":sdpa",
13-
":custom_ops_headers",
14-
]
15-
runtime.cxx_test(
16-
name = "op_sdpa_test",
17-
srcs = [
18-
"op_sdpa_test.cpp",
19-
],
20-
visibility = ["//executorch/..."],
21-
deps = [
22-
"//executorch/runtime/core/exec_aten:lib",
23-
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
24-
"//executorch/kernels/test:test_util",
25-
] + generated_lib_and_op_deps + deps,
26-
)
27-
runtime.cxx_test(
28-
name = "op_sdpa_with_kv_cache_test",
29-
srcs = [
30-
"op_sdpa_with_kv_cache_test.cpp",
31-
],
32-
visibility = ["//executorch/..."],
33-
deps = [
34-
"//executorch/runtime/core/exec_aten:lib",
35-
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
36-
"//executorch/kernels/test:test_util",
37-
] + generated_lib_and_op_deps + deps,
38-
)
392

403
def define_common_targets():
414
"""Defines targets that should be shared between fbcode and xplat.
425
436
The directory containing this targets.bzl file should also contain both
447
TARGETS and BUCK files that call this function.
458
"""
46-
479
runtime.python_library(
4810
name = "llama_custom_ops_aot_lib",
4911
srcs = [
@@ -58,71 +20,54 @@ def define_common_targets():
5820
],
5921
)
6022

61-
runtime.export_file(
62-
name = "custom_ops.yaml",
63-
visibility = [
64-
"//executorch/...",
65-
"@EXECUTORCH_CLIENTS",
66-
],
67-
)
68-
69-
# ~~~ START of custom ops 1 `my_ops::mul3` library definitions ~~~
70-
et_operator_library(
71-
name = "sdpa_op",
72-
ops = [
73-
"llama::sdpa.out",
74-
],
75-
define_static_targets = True,
76-
visibility = [
77-
"//executorch/codegen/...",
78-
"@EXECUTORCH_CLIENTS",
79-
],
80-
)
81-
82-
et_operator_library(
83-
name = "sdpa_with_kv_cache",
84-
ops = [
85-
"llama::sdpa_with_kv_cache.out",
86-
],
87-
define_static_targets = True,
88-
visibility = [
89-
"//executorch/codegen/...",
90-
"@EXECUTORCH_CLIENTS",
91-
],
92-
)
93-
9423
runtime.cxx_library(
95-
name = "sdpa",
24+
name = "custom_ops",
9625
srcs = ["op_sdpa.cpp"],
97-
deps = [
26+
exported_headers = ["op_sdpa.h"],
27+
exported_deps = [
9828
"//executorch/runtime/kernel:kernel_includes",
9929
"//executorch/kernels/portable/cpu:scalar_utils",
10030
"//executorch/kernels/optimized:libblas",
10131
"//executorch/kernels/optimized:libvec",
32+
"//executorch/extension/kernel_util:kernel_util",
10233
"//executorch/extension/parallel:thread_parallel",
10334
"//executorch/backends/xnnpack/threadpool:threadpool",
10435
],
105-
compiler_flags = ["-Wno-missing-prototypes"],
36+
compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
10637
visibility = [
10738
"//executorch/...",
10839
"//executorch/examples/models/llama2/custom_ops/...",
10940
"@EXECUTORCH_CLIENTS",
11041
],
42+
# @lint-ignore BUCKLINT link_whole
43+
link_whole = True,
11144
force_static = True,
11245
)
11346

114-
executorch_generated_lib(
115-
name = "custom_ops",
47+
runtime.cxx_test(
48+
name = "op_sdpa_test",
49+
srcs = [
50+
"op_sdpa_test.cpp",
51+
],
52+
visibility = ["//executorch/..."],
11653
deps = [
117-
":sdpa_op",
118-
":sdpa_with_kv_cache",
119-
":sdpa",
54+
"//executorch/runtime/core/exec_aten:lib",
55+
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
56+
"//executorch/kernels/test:test_util",
57+
":custom_ops",
12058
],
121-
custom_ops_yaml_target = ":custom_ops.yaml",
122-
visibility = [
123-
"//executorch/...",
124-
"@EXECUTORCH_CLIENTS",
59+
)
60+
61+
runtime.cxx_test(
62+
name = "op_sdpa_with_kv_cache_test",
63+
srcs = [
64+
"op_sdpa_with_kv_cache_test.cpp",
65+
],
66+
visibility = ["//executorch/..."],
67+
deps = [
68+
"//executorch/runtime/core/exec_aten:lib",
69+
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
70+
"//executorch/kernels/test:test_util",
71+
":custom_ops",
12572
],
126-
define_static_targets = True,
12773
)
128-
define_tests()

0 commit comments

Comments (0)