Skip to content

Reuse GELU implementation from PyTorch core #8322

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .ci/scripts/build_llama_android.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ set -exu
# shellcheck source=/dev/null
source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"

if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
PYTHON_EXECUTABLE=python3
fi
which "${PYTHON_EXECUTABLE}"
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"

install_executorch_and_backend_lib() {
echo "Installing executorch and xnnpack backend"
clean_executorch_install_folders
Expand All @@ -22,6 +28,7 @@ install_executorch_and_backend_lib() {
-DANDROID_ABI="${ANDROID_ABI}" \
-DCMAKE_INSTALL_PREFIX=cmake-android-out \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
Expand All @@ -47,6 +54,7 @@ build_llama_runner() {
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
-Bcmake-android-out/examples/models/llama examples/models/llama

cmake --build cmake-android-out/examples/models/llama -j4 --config Release
Expand Down
1 change: 1 addition & 0 deletions .ci/scripts/test_llama.sh
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ cmake_install_executorch_libraries() {
rm -rf cmake-out
retry cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')" \
-DCMAKE_BUILD_TYPE="$CMAKE_BUILD_TYPE" \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
Expand Down
10 changes: 7 additions & 3 deletions .ci/scripts/test_llava.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ fi
NPROC=8
if hash nproc &> /dev/null; then NPROC=$(nproc); fi

python_lib=$($PYTHON_EXECUTABLE -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
EXECUTORCH_COMMON_CMAKE_ARGS=" \
-DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
-DEXECUTORCH_ENABLE_LOGGING=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
Expand All @@ -46,6 +48,7 @@ EXECUTORCH_COMMON_CMAKE_ARGS=" \
cmake_install_executorch_libraries() {
cmake \
${EXECUTORCH_COMMON_CMAKE_ARGS} \
"-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" \
-B${BUILD_DIR} .

cmake --build ${BUILD_DIR} -j${NPROC} --target install --config ${CMAKE_BUILD_TYPE}
Expand All @@ -56,6 +59,7 @@ cmake_install_executorch_libraries_for_android() {
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
-DANDROID_ABI=arm64-v8a \
${EXECUTORCH_COMMON_CMAKE_ARGS} \
"-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" \
-B${BUILD_DIR} .

cmake --build ${BUILD_DIR} -j${NPROC} --target install --config ${CMAKE_BUILD_TYPE}
Expand All @@ -76,7 +80,7 @@ cmake_build_llava_runner() {

cmake \
${LLAVA_COMMON_CMAKE_ARGS} \
-DCMAKE_PREFIX_PATH="$python_lib" \
-DCMAKE_PREFIX_PATH="$python_lib;${CMAKE_PREFIX_PATH}" \
-B${BUILD_DIR}/${dir} \
${dir}

Expand All @@ -92,7 +96,7 @@ cmake_build_llava_runner_for_android() {
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
-DANDROID_ABI=arm64-v8a \
${LLAVA_COMMON_CMAKE_ARGS} \
-DCMAKE_PREFIX_PATH="$python_lib" \
-DCMAKE_PREFIX_PATH="$python_lib;${CMAKE_PREFIX_PATH}" \
-DLLAVA_RUNNER_NO_TORCH_DUMMY_IMAGE=ON \
-B${BUILD_DIR}/${dir} \
${dir}
Expand Down
5 changes: 3 additions & 2 deletions .ci/scripts/test_model.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,12 @@ prepare_artifacts_upload() {

build_cmake_executor_runner() {
echo "Building executor_runner"
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
rm -rf ${CMAKE_OUTPUT_DIR}
cmake -DCMAKE_BUILD_TYPE=Debug \
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
-DCMAKE_PREFIX_PATH="$CMAKE_PREFIX_PATH" \
-B${CMAKE_OUTPUT_DIR} .

cmake --build ${CMAKE_OUTPUT_DIR} -j4 --config Debug
Expand Down Expand Up @@ -98,8 +100,7 @@ test_model() {

build_cmake_xnn_executor_runner() {
echo "Building xnn_executor_runner"
SITE_PACKAGES="$(${PYTHON_EXECUTABLE} -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
CMAKE_PREFIX_PATH="${SITE_PACKAGES}/torch"
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"

(rm -rf ${CMAKE_OUTPUT_DIR} \
&& mkdir ${CMAKE_OUTPUT_DIR} \
Expand Down
4 changes: 4 additions & 0 deletions .ci/scripts/test_phi_3_mini.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ NPROC=8
if hash nproc &> /dev/null; then NPROC=$(nproc); fi

cmake_install_executorch_libraries() {
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
cmake -DPYTHON_EXECUTABLE=python \
-DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
-DEXECUTORCH_ENABLE_LOGGING=1 \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
Expand All @@ -39,8 +41,10 @@ cmake_install_executorch_libraries() {
}

cmake_build_phi_3_mini() {
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
cmake -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
-DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
Expand Down
1 change: 1 addition & 0 deletions .ci/scripts/utils.sh
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ cmake_install_executorch_lib() {
clean_executorch_install_folders
retry cmake -DBUCK2="$BUCK" \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_PREFIX_PATH="$($PYTHON_EXECUTABLE -c 'import torch as _; print(_.__path__[0])')" \
-DCMAKE_BUILD_TYPE=Release \
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
-Bcmake-out .
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/pull.yml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ jobs:
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
conda activate "${CONDA_ENV}"

source .ci/scripts/utils.sh
install_executorch "use-pt-pinned-commit"
BUILD_TOOL="cmake"
PYTHON_EXECUTABLE=python \
bash .ci/scripts/build_llama_android.sh "${BUILD_TOOL}"
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/trunk.yml
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ jobs:
rm -rf cmake-out
cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_PREFIX_PATH="$(python -c 'import torch as _; print(_.__path__[0])')" \
-DCMAKE_BUILD_TYPE=Release \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
Expand All @@ -411,6 +412,7 @@ jobs:
cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_PREFIX_PATH="$(python -c 'import torch as _; print(_.__path__[0])')" \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
Expand Down
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,8 @@ if(BUILD_EXECUTORCH_PORTABLE_OPS)
endif()

if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED)
# find pytorch lib here to make it available to all sub-directories
find_package_torch_headers()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels/optimized)
endif()

Expand Down
19 changes: 19 additions & 0 deletions build/Utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,22 @@ function(resolve_python_executable)
)
endif()
endfunction()

# find_package(Torch CONFIG REQUIRED) replacement for targets that
# have a header-only Torch dependency. Because find_package sets
# variables in the parent scope, we use a macro to preserve this
# rather than maintaining our own list of those variables.
macro(find_package_torch_headers)
# We cannot simply use CMAKE_FIND_ROOT_PATH_BOTH, because that does
# not propagate into TorchConfig.cmake.
foreach(mode_kind IN ITEMS PACKAGE LIBRARY INCLUDE)
set(OLD_CMAKE_FIND_ROOT_PATH_MODE_${mode_kind} ${CMAKE_FIND_ROOT_PATH_MODE_${mode_kind}})
set(CMAKE_FIND_ROOT_PATH_MODE_${mode_kind} BOTH)
endforeach()
if(NOT TARGET torch)
find_package(Torch CONFIG REQUIRED)
endif()
foreach(mode_kind IN ITEMS PACKAGE LIBRARY INCLUDE)
set(CMAKE_FIND_ROOT_PATH_MODE_${mode_kind} ${OLD_CMAKE_FIND_ROOT_PATH_MODE_${mode_kind}})
endforeach()
endmacro()
8 changes: 8 additions & 0 deletions build/build_android_llm_demo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

set -ex

if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
PYTHON_EXECUTABLE=python3
fi
which "${PYTHON_EXECUTABLE}"
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"

build_jar() {
pushd extension/android
./gradlew build
Expand Down Expand Up @@ -36,6 +42,7 @@ build_android_native_library() {
fi

cmake . -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
-DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \
-DANDROID_ABI="${ANDROID_ABI}" \
-DANDROID_PLATFORM=android-26 \
Expand Down Expand Up @@ -69,6 +76,7 @@ build_android_native_library() {
-DANDROID_ABI="${ANDROID_ABI}" \
-DANDROID_PLATFORM=android-26 \
-DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
-DEXECUTORCH_ENABLE_LOGGING=ON \
-DEXECUTORCH_LOG_LEVEL=Info \
-DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
Expand Down
2 changes: 1 addition & 1 deletion kernels/optimized/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ message("Generated files ${gen_command_sources}")

list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
add_library(optimized_kernels ${_optimized_kernels__srcs})
target_include_directories(optimized_kernels PRIVATE "${EXECUTORCH_ROOT}/third-party/pocketfft")
target_include_directories(optimized_kernels PRIVATE ${TORCH_INCLUDE_DIRS} "${EXECUTORCH_ROOT}/third-party/pocketfft")
target_link_libraries(
optimized_kernels PRIVATE executorch_core cpublas extension_threadpool
)
Expand Down
51 changes: 15 additions & 36 deletions kernels/optimized/cpu/op_gelu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <cmath>

#include <ATen/native/cpu/Gelu.h>
#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
Expand Down Expand Up @@ -47,48 +48,26 @@ void gelu(
CTYPE* out_data = output.mutable_data_ptr<CTYPE>();
size_t lim = input.numel();

// TODO: Add fast path for tanh using sleef's tanh
if (approximate == "tanh") {
// 0.5 * x * (1 + Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))
for (size_t i = 0; i < lim; ++i) {
const CTYPE x = in_data[i];
const CTYPE kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
const CTYPE kKappa = 0.044715;
auto x_cube = x * x * x;
auto inner = kBeta * (x + kKappa * x_cube);
out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::tanh(inner));
using Vec = at::vec::Vectorized<CTYPE>;
int i = 0;
for (; i < lim - (lim % Vec::size()); i += Vec::size()) {
Vec x = Vec::loadu(in_data + i);
at::native::vectorized_gelu_approximated_with_tanh(x).store(out_data + i);
}
} else if (approximate == "none") { // dont appx
// GELU(x) = x * Φ(x) where Φ(x) is the is the Cumulative Distribution
// Function for Gaussian Distribution.

#ifndef __aarch64__
for (size_t i = 0; i < lim; ++i) {
const CTYPE x = in_data[i];
out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
for (; i < lim; ++i) {
out_data[i] = at::native::scalar_gelu_approximated_with_tanh(in_data[i]);
}
#else
size_t i = 0;
if constexpr (std::is_same_v<CTYPE, float>) {
for (; i + 4 < lim; i += 4) {
const float32x4_t in =
vld1q_f32(static_cast<const float*>(&in_data[i]));
const float32x4_t m_sqrt1_2x4 = {
M_SQRT1_2, M_SQRT1_2, M_SQRT1_2, M_SQRT1_2};
const float32x4_t ones = vmovq_n_f32(1.0);
const float32x4_t halves = vmovq_n_f32(0.5);
float32x4_t out = Sleef_erff4_u10(vmulq_f32(in, m_sqrt1_2x4));
vst1q_f32(
static_cast<float*>(&out_data[i]),
vmulq_f32(vmulq_f32(vaddq_f32(out, ones), in), halves));
}
} else if (approximate == "none") {
using Vec = at::vec::Vectorized<CTYPE>;
int i = 0;
for (; i < lim - (lim % Vec::size()); i += Vec::size()) {
Vec x = Vec::loadu(in_data + i);
at::native::vectorized_gelu(x).store(out_data + i);
}
for (; i < lim; ++i) {
const CTYPE x = in_data[i];
out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
out_data[i] = at::native::scalar_gelu(in_data[i]);
}
#endif // __aarch64__

} else {
ET_KERNEL_CHECK_MSG(
context,
Expand Down
15 changes: 9 additions & 6 deletions kernels/optimized/cpu/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,9 @@ _OPTIMIZED_ATEN_OPS = (
op_target(name = "op_sigmoid"),
op_target(
name = "op_gelu",
deps = select({
"DEFAULT": [],
"ovr_config//cpu:arm64": [
"fbsource//third-party/sleef:sleef_arm",
],
}) + [
deps = [
"//executorch/kernels/portable/cpu/util:activation_ops_util",
"//executorch/runtime/core/portable_type/c10:aten_headers_for_executorch",
],
),
op_target(
Expand Down Expand Up @@ -100,6 +96,13 @@ _OPTIMIZED_ATEN_OPS = (
),
)


def get_sleef_preprocessor_flags():
if runtime.is_oss:
return []
return ["-DAT_BUILD_ARM_VEC256_WITH_SLEEF"]


def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.

Expand Down
11 changes: 8 additions & 3 deletions kernels/optimized/op_registration_util.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,14 @@ def define_op_library(name, deps):
"//executorch/kernels/test/...",
"@EXECUTORCH_CLIENTS",
],
# kernels often have helpers with no prototypes just disabling the warning here as the headers
# are codegend and linked in later
compiler_flags = ["-Wno-missing-prototypes"] + get_compiler_optimization_flags(),
compiler_flags = [
# kernels often have helpers with no prototypes just disabling the warning here as the headers
# are codegend and linked in later
"-Wno-missing-prototypes",
# pragma unroll fails with -Os, don't need to warn us and
# fail Werror builds; see https://godbolt.org/z/zvf85vTsr
"-Wno-pass-failed",
] + get_compiler_optimization_flags(),
deps = [
"//executorch/runtime/kernel:kernel_includes",
] + augmented_deps + get_vec_deps(),
Expand Down
9 changes: 7 additions & 2 deletions kernels/optimized/optimized-oss.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This yaml file contains operators that have optimized kernels available.
# Note that this is a copy of optimized.yaml that does not include gelu and
# log_softmax, due to the OSS build not currently including sleef.
# Note that this is a copy of optimized.yaml that does not include log_softmax,
# due to the OSS build not currently including sleef.
# TODO (T183193812)

- op: _fft_r2c.out
Expand Down Expand Up @@ -45,6 +45,11 @@
- arg_meta: null
kernel_name: torch::executor::opt_sigmoid_out

- op: gelu.out
kernels:
- arg_meta: null
kernel_name: torch::executor::opt_gelu_out

- op: le.Scalar_out
kernels:
- arg_meta: null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,5 @@ def define_op_target(name, deps):

def is_op_disabled(name):
# TODO (gjcomer) Enable ops with sleef dependency in OSS
disabled_ops = ["op_gelu", "op_log_softmax"]
disabled_ops = ["op_log_softmax"]
return name in disabled_ops
Loading
Loading