Skip to content

Use new API to register custom ops for llama model #2916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ if(EXECUTORCH_BUILD_PYBIND)
endif()

if(EXECUTORCH_BUILD_CUSTOM)
list(APPEND _dep_libs custom_ops_lib)
list(APPEND _dep_libs custom_ops)
endif()

# compile options for pybind
Expand Down
4 changes: 2 additions & 2 deletions examples/models/llama2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ else()
endif()

if(EXECUTORCH_BUILD_CUSTOM)
target_link_options_shared_lib(custom_ops_lib)
list(APPEND link_libraries custom_ops_lib)
target_link_options_shared_lib(custom_ops)
list(APPEND link_libraries custom_ops)
endif()

set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack)
Expand Down
16 changes: 3 additions & 13 deletions examples/models/llama2/custom_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,6 @@ list(APPEND custom_ops_libs cpuinfo)
list(APPEND custom_ops_libs cpublas)
list(APPEND custom_ops_libs eigen_blas)

# Generate C++ bindings to register kernels into both PyTorch (for AOT) and
# Executorch (for runtime). Here select all ops in optimized.yaml
set(_yaml "${CMAKE_CURRENT_LIST_DIR}/custom_ops.yaml")
gen_selected_ops("${_yaml}" "" "")

generate_bindings_for_kernels(FUNCTIONS_YAML
${CMAKE_CURRENT_SOURCE_DIR}/custom_ops.yaml)
message("Generated files ${gen_command_sources}")

list(TRANSFORM _custom_ops__srcs PREPEND "${EXECUTORCH_ROOT}/")

# TODO: Consider moving xnnpack/threadpool in a separate lib since it's now used
Expand All @@ -70,6 +61,8 @@ if(NOT EXECUTORCH_BUILD_XNNPACK)
"${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool_guard.cpp"
)
else()
list(APPEND custom_ops_libs xnnpack_backend)
endif()

add_library(custom_ops ${_custom_ops__srcs})
Expand All @@ -82,7 +75,4 @@ target_link_libraries(custom_ops PUBLIC ${custom_ops_libs})
target_compile_options(custom_ops PUBLIC ${_common_compile_options}
-DET_USE_THREADPOOL)

# Build a library for _custom_ops_srcs
#
# custom_ops_lib: Register optimized ops kernels into Executorch runtime
gen_operators_lib("custom_ops_lib" KERNEL_LIBS custom_ops DEPS executorch)
install(TARGETS custom_ops DESTINATION lib)
Empty file.
14 changes: 0 additions & 14 deletions examples/models/llama2/custom_ops/custom_ops.yaml

This file was deleted.

8 changes: 7 additions & 1 deletion examples/models/llama2/custom_ops/op_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h>

#include <executorch/kernels/optimized/blas/CPUBlas.h>
#include <executorch/kernels/optimized/vec/functional.h>
Expand All @@ -22,6 +22,7 @@
#include <executorch/backends/xnnpack/threadpool/threadpool.h>
#include <executorch/extension/parallel/thread_parallel.h>
#endif
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>

namespace torch {
namespace executor {
Expand Down Expand Up @@ -843,3 +844,8 @@ Tensor& sdpa_with_kv_cache_out(
} // namespace native
} // namespace executor
} // namespace torch

// Register the sdpa_with_kv_cache.out kernel under the "llama" namespace with
// the ExecuTorch runtime via the EXECUTORCH_LIBRARY macro (the new C++
// registration API; this PR deletes the custom_ops.yaml codegen path it
// replaces). The boxed wrapper comes from
// extension/kernel_util/make_boxed_from_unboxed_functor.h, included above.
EXECUTORCH_LIBRARY(
llama,
"sdpa_with_kv_cache.out",
torch::executor::native::sdpa_with_kv_cache_out);
48 changes: 48 additions & 0 deletions examples/models/llama2/custom_ops/op_sdpa.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

namespace native {

// Out-variant scaled-dot-product-attention kernel that also maintains KV
// caches across decode steps. Registered with the ExecuTorch runtime as
// "llama::sdpa_with_kv_cache.out" (see EXECUTORCH_LIBRARY in op_sdpa.cpp).
//
// q/k/v_projected: projected query/key/value activations for this step.
// key_cache/value_cache: caches updated in place — presumably written at
//   positions [start_pos, start_pos + seq_len); confirm against op_sdpa.cpp.
// start_pos/seq_len: position of this chunk in the sequence and its length.
// attn_mask: optional additive attention mask; ignored content-wise here —
//   semantics live in the .cpp implementation.
// dropout_p/is_causal/scale: standard SDPA knobs mirroring the ATen op.
// output: out tensor; also returned by reference (out-variant convention).
Tensor& sdpa_with_kv_cache_out(
RuntimeContext& ctx,
const Tensor& q_projected,
const Tensor& k_projected,
const Tensor& v_projected,
Tensor& key_cache,
Tensor& value_cache,
const int64_t start_pos,
const int64_t seq_len,
const optional<Tensor>& attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output);

// Cache-free SDPA kernel over full query/key/value tensors. Exercised
// directly by op_sdpa_test.cpp (see its op_scaled_dot_product_attention
// wrapper); same dropout_p/is_causal/scale contract as above. Writes the
// result into `output` and returns it by reference.
Tensor& flash_attention_kernel_out(
RuntimeContext& ctx,
const Tensor& query,
const Tensor& key,
const Tensor& value,
const optional<Tensor>& attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output);

} // namespace native
} // namespace executor
} // namespace torch
5 changes: 3 additions & 2 deletions examples/models/llama2/custom_ops/op_sdpa_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

#include <limits>

#include <executorch/examples/models/llama2/custom_ops/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h>

#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
Expand All @@ -28,7 +29,7 @@ exec_aten::Tensor op_scaled_dot_product_attention(
exec_aten::optional<double> scale,
exec_aten::Tensor& out) {
exec_aten::RuntimeContext context{};
return torch::executor::llama::sdpa_outf(
return torch::executor::native::flash_attention_kernel_out(
context, query, key, value, attn_mask, dropout_p, is_causal, scale, out);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#include <limits>

#include <executorch/examples/models/llama2/custom_ops/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
Expand All @@ -32,7 +32,7 @@ exec_aten::Tensor op_sdpa_with_kv_cache(
exec_aten::optional<double> scale,
exec_aten::Tensor& out) {
exec_aten::RuntimeContext context{};
return torch::executor::llama::sdpa_with_kv_cache_outf(
return torch::executor::native::sdpa_with_kv_cache_out(
context,
query,
key,
Expand Down
115 changes: 30 additions & 85 deletions examples/models/llama2/custom_ops/targets.bzl
Original file line number Diff line number Diff line change
@@ -1,49 +1,11 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib")
load("@fbsource//xplat/executorch/kernels/test:util.bzl", "codegen_function_header_wrapper")

def define_tests():
codegen_function_header_wrapper("executorch/examples/models/llama2/custom_ops", "custom_ops")

# In the long run we should really have aten variant available as well
deps = [":function_header_wrapper_custom_ops"]
generated_lib_and_op_deps = [
":custom_ops",
":sdpa",
":custom_ops_headers",
]
runtime.cxx_test(
name = "op_sdpa_test",
srcs = [
"op_sdpa_test.cpp",
],
visibility = ["//executorch/..."],
deps = [
"//executorch/runtime/core/exec_aten:lib",
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
"//executorch/kernels/test:test_util",
] + generated_lib_and_op_deps + deps,
)
runtime.cxx_test(
name = "op_sdpa_with_kv_cache_test",
srcs = [
"op_sdpa_with_kv_cache_test.cpp",
],
visibility = ["//executorch/..."],
deps = [
"//executorch/runtime/core/exec_aten:lib",
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
"//executorch/kernels/test:test_util",
] + generated_lib_and_op_deps + deps,
)

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.

The directory containing this targets.bzl file should also contain both
TARGETS and BUCK files that call this function.
"""

runtime.python_library(
name = "llama_custom_ops_aot_lib",
srcs = [
Expand All @@ -58,71 +20,54 @@ def define_common_targets():
],
)

runtime.export_file(
name = "custom_ops.yaml",
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
)

# ~~~ START of custom ops 1 `my_ops::mul3` library definitions ~~~
et_operator_library(
name = "sdpa_op",
ops = [
"llama::sdpa.out",
],
define_static_targets = True,
visibility = [
"//executorch/codegen/...",
"@EXECUTORCH_CLIENTS",
],
)

et_operator_library(
name = "sdpa_with_kv_cache",
ops = [
"llama::sdpa_with_kv_cache.out",
],
define_static_targets = True,
visibility = [
"//executorch/codegen/...",
"@EXECUTORCH_CLIENTS",
],
)

runtime.cxx_library(
name = "sdpa",
name = "custom_ops",
srcs = ["op_sdpa.cpp"],
deps = [
exported_headers = ["op_sdpa.h"],
exported_deps = [
"//executorch/runtime/kernel:kernel_includes",
"//executorch/kernels/portable/cpu:scalar_utils",
"//executorch/kernels/optimized:libblas",
"//executorch/kernels/optimized:libvec",
"//executorch/extension/kernel_util:kernel_util",
"//executorch/extension/parallel:thread_parallel",
"//executorch/backends/xnnpack/threadpool:threadpool",
],
compiler_flags = ["-Wno-missing-prototypes"],
compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
visibility = [
"//executorch/...",
"//executorch/examples/models/llama2/custom_ops/...",
"@EXECUTORCH_CLIENTS",
],
# @lint-ignore BUCKLINT link_whole
link_whole = True,
force_static = True,
)

executorch_generated_lib(
name = "custom_ops",
runtime.cxx_test(
name = "op_sdpa_test",
srcs = [
"op_sdpa_test.cpp",
],
visibility = ["//executorch/..."],
deps = [
":sdpa_op",
":sdpa_with_kv_cache",
":sdpa",
"//executorch/runtime/core/exec_aten:lib",
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
"//executorch/kernels/test:test_util",
":custom_ops",
],
custom_ops_yaml_target = ":custom_ops.yaml",
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
)

runtime.cxx_test(
name = "op_sdpa_with_kv_cache_test",
srcs = [
"op_sdpa_with_kv_cache_test.cpp",
],
visibility = ["//executorch/..."],
deps = [
"//executorch/runtime/core/exec_aten:lib",
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
"//executorch/kernels/test:test_util",
":custom_ops",
],
define_static_targets = True,
)
define_tests()
Loading