AQLM custom kernels for Android #1

Open · wants to merge 15 commits into base: v0.3.0_branch

Conversation

BlackSamorez (Owner) commented Aug 7, 2024

This PR contains all the modifications to ExecuTorch v0.3.0 needed to run AQLM models on an Android device.

It is designed to be compatible with the build and deploy process of the ExecuTorch Llama demo app.

@@ -467,6 +467,9 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
add_subdirectory(
${CMAKE_CURRENT_SOURCE_DIR}/examples/models/llama2/custom_ops
)
add_subdirectory(
BlackSamorez (Owner Author):

Instruct CMake to traverse the subdirectory containing the AQLM kernels. Keeping the directory in the same CMake tree makes it easier to link the custom operator libraries.

CMakeLists.txt Outdated
@@ -633,13 +636,15 @@ if(EXECUTORCH_BUILD_PYBIND)
# TODO(larryliu): Fix macOS 2 dylibs having 2 sets of static variables issue
if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT AND NOT APPLE)
list(APPEND _dep_libs custom_ops_aot_lib)
list(APPEND _dep_libs aqlm)
BlackSamorez (Owner Author):

Link the custom kernel to the portable_lib library, but only if custom operators are being compiled.

@@ -633,13 +636,16 @@ if(EXECUTORCH_BUILD_PYBIND)
# TODO(larryliu): Fix macOS 2 dylibs having 2 sets of static variables issue
if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT AND NOT APPLE)
list(APPEND _dep_libs custom_ops_aot_lib)
list(APPEND _dep_libs aqlm_aot_lib)
BlackSamorez (Owner Author), Aug 7, 2024:

Add torch bindings for AQLM to the portable_lib.

endif()
# TODO(laryliu): Fix linux duplicate registation problem. In GH CI worker
# libcustom_ops.a doesn't dedup with the one indirectly linked from
# libcustom_ops_aot_lib.a
if(EXECUTORCH_BUILD_KERNELS_CUSTOM AND APPLE)
target_link_options_shared_lib(custom_ops)
list(APPEND _dep_libs custom_ops)
target_link_options_shared_lib(aqlm)
BlackSamorez (Owner Author), Aug 7, 2024:

Force the linkage of core aqlm ops to portable_lib. This step is NECESSARY for the EXECUTORCH_LIBRARY macro to work. Otherwise, kernels won't be properly loaded during startup.

@@ -699,7 +705,7 @@ if(EXECUTORCH_BUILD_PYBIND)
PROPERTIES # Assume that this library will be installed in
# `site-packages/executorch/extension/pybindings`, and that
# the custom_ops_aot_lib should be found with relative path.
BUILD_RPATH "$ORIGIN:$ORIGIN/../../examples/models/llama2/custom_ops"
BUILD_RPATH "$ORIGIN:$ORIGIN/../../examples/models/llama2/custom_ops:$ORIGIN/../../examples/models/llama2/aqlm"
BlackSamorez (Owner Author):

IDK

@@ -87,6 +87,7 @@ endif()
# custom ops library
if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
add_subdirectory(custom_ops)
add_subdirectory(aqlm)
BlackSamorez (Owner Author):

Traverse the subdirectory containing the AQLM code.

@@ -129,6 +130,9 @@ list(APPEND link_libraries quantized_kernels quantized_ops_lib)
if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
target_link_options_shared_lib(custom_ops)
list(APPEND link_libraries custom_ops)

target_link_options_shared_lib(aqlm)
BlackSamorez (Owner Author):

Force the linkage of core aqlm ops to the llama_runner executable. This step is NECESSARY for the EXECUTORCH_LIBRARY macro to work. Otherwise, kernels won't be properly loaded during startup.

@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
BlackSamorez (Owner Author):

Mostly copied from examples/models/llama2/custom_ops/CMakeLists.txt.

# list(APPEND aqlm_libs OpenMP::OpenMP_CXX)
# list(APPEND aqlm_libs omp)

add_library(aqlm ${_aqlm__srcs})
BlackSamorez (Owner Author), Aug 7, 2024:

A library containing the core AQLM ops and an EXECUTORCH_LIBRARY macro invocation for their automatic registration into the ExecuTorch runtime when linked with target_link_options_shared_lib.
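For reference, the registration presumably looks something like the minimal sketch below; the exact kernel signature is an assumption for illustration, not the PR's actual code.

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

// Out-variant kernel implemented in lut_kernel.cpp; it fills `out` and returns it.
Tensor& code2x8_lut_matmat_out(
    RuntimeContext& ctx,
    const Tensor& input,
    const Tensor& codes,
    const Tensor& codebooks,
    const Tensor& scales,
    Tensor& out);

} // namespace native
} // namespace executor
} // namespace torch

// Registers the kernel under the "aqlm" namespace via a static initializer.
// That initializer is exactly what the linker drops unless the library is
// linked whole-archive, hence the target_link_options_shared_lib calls above.
EXECUTORCH_LIBRARY(
    aqlm,
    "code2x8_lut_matmat.out",
    torch::executor::native::code2x8_lut_matmat_out);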

# Add a AOT library
find_package(Torch CONFIG REQUIRED)
add_library(
aqlm_aot_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/lut_kernel_pytorch.cpp
BlackSamorez (Owner Author):

A library to be loaded into PyTorch with torch.ops.load_library. It contains TORCH_LIBRARY macro invocations to register the AQLM operations and provide implementations for them.
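Such a binding follows the standard PyTorch custom-op pattern; a sketch is below. The op namespace and name match the torch.ops.aqlm.code2x8_lut_matmat call seen later in this PR, but the schema and the stub body are assumptions, not the exact contents of lut_kernel_pytorch.cpp.

#include <ATen/ATen.h>
#include <torch/library.h>

// Stub standing in for the real AOT implementation.
at::Tensor code2x8_lut_matmat(
    const at::Tensor& input,
    const at::Tensor& codes,
    const at::Tensor& codebooks,
    const at::Tensor& scales,
    const c10::optional<at::Tensor>& bias) {
  TORCH_CHECK(false, "illustrative stub only");
  return at::Tensor();
}

// Declares the op schema and binds the implementation. After
// torch.ops.load_library(...) the op becomes callable from Python as
// torch.ops.aqlm.code2x8_lut_matmat(...).
TORCH_LIBRARY(aqlm, m) {
  m.def(
      "code2x8_lut_matmat(Tensor input, Tensor codes, Tensor codebooks, "
      "Tensor scales, Tensor? bias) -> Tensor");
  m.impl("code2x8_lut_matmat", code2x8_lut_matmat);
}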

self.codes.data = torch.permute(self.codes.data, (1, 0, 2)).contiguous()

def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.ops.aqlm.code2x8_lut_matmat(
BlackSamorez (Owner Author), Aug 7, 2024:

Invoke the custom op loaded from lut_kernel.cpp; the op is loaded when lut_kernel.py is imported.

self.register_parameter("bias", None)

def transpose_codes(self):
self.codes.data = torch.permute(self.codes.data, (1, 0, 2)).contiguous()
BlackSamorez (Owner Author):

The codes layout used by the C++ kernels differs from the CUDA one, so the weights need some preprocessing.

#include <numeric>
#include <functional>

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
BlackSamorez (Owner Author):

Needed for the EXECUTORCH_LIBRARY macro.

#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

#include <executorch/kernels/optimized/blas/CPUBlas.h>
BlackSamorez (Owner Author):

For ::executorch::cpublas::gemm

namespace torch {
namespace executor {
namespace native {
Tensor& code2x8_lut_matmat_out(
BlackSamorez (Owner Author):

These are torch::executor::Tensor, which offers far fewer operations than torch::Tensor.
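In practice that means working with sizes and raw data pointers directly; a small illustrative sketch (not the PR's code, and the helper name is made up):

#include <cstddef>
#include <executorch/runtime/core/exec_aten/exec_aten.h>

using exec_aten::Tensor;

// There is no operator library on these tensors (no matmul, no permute, no
// views); you get metadata and raw pointers and build everything else yourself.
void scale_in_place(Tensor& t, float factor) {
  float* data = t.mutable_data_ptr<float>();
  const size_t n = static_cast<size_t>(t.numel());
  for (size_t i = 0; i < n; ++i) {
    data[i] *= factor;
  }
}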

auto num_input_vectors = std::accumulate(input_sizes.begin(), input_sizes.end(), 1, std::multiplies<int64_t>()) / input_vector_size;

// Allocate LUT
auto lut_data = ctx.allocate_temp(
BlackSamorez (Owner Author):

We need to manually allocate all the temporary memory we use. One way to do so is to invoke allocate_temp on the operation's RuntimeContext; just make sure that the context is provided with a temp allocator.
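Roughly, the pattern is the one sketched below (an illustrative helper with a made-up name, not the PR's exact code):

#include <cstddef>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

// allocate_temp returns a Result<void*>; it only succeeds if the runtime
// context was created with a temp allocator, so the result must be checked.
float* allocate_float_scratch(RuntimeContext& ctx, size_t num_floats) {
  auto result = ctx.allocate_temp(num_floats * sizeof(float));
  return result.ok() ? static_cast<float*>(result.get()) : nullptr;
}

} // namespace native
} // namespace executor
} // namespace torch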

out_features
);

return out;
BlackSamorez (Owner Author):

The _out variant of an operation returns its out argument.

@@ -489,7 +489,7 @@ def run(self):
"-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON", # add llama sdpa ops to pybindings.
"-DEXECUTORCH_BUILD_KERNELS_CUSTOM_AOT=ON",
]
build_args += ["--target", "custom_ops_aot_lib"]
BlackSamorez (Owner Author):

Build aqlm_aot_lib for the python library.

@@ -569,6 +569,13 @@ def get_ext_modules() -> list[Extension]:
"executorch/examples/models/llama2/custom_ops",
)
)
ext_modules.append(
BlackSamorez (Owner Author):

Add the compiled dynamic library with AQLM bindings to the pip installation.

dtype=torch.int8,
),
requires_grad=False,
) # [num_in_groups, num_out_groups, num_codebooks]
BlackSamorez (Owner Author):

NOTE: different from the usual AQLM layout.

).get();

// A @ B.T
::executorch::cpublas::gemm(
BlackSamorez (Owner Author):

There is no matmul available on these tensors, so we have to use low-level BLAS ops.
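For reference, mapping a row-major A @ B.T onto the gemm call looks roughly like the sketch below. The transpose flags and leading dimensions follow the standard column-major BLAS convention and are not copied from the PR.

#include <cstdint>
#include <executorch/kernels/optimized/blas/CPUBlas.h>

// C = A @ B^T with row-major A [M x K], B [N x K], C [M x N].
// gemm uses the column-major BLAS convention, so the row-major buffers are
// reinterpreted as transposed column-major matrices.
void matmul_a_bt(
    const float* A, const float* B, float* C,
    int64_t M, int64_t N, int64_t K) {
  using ::executorch::cpublas::TransposeType;
  ::executorch::cpublas::gemm(
      TransposeType::Transpose,    // B's [K x N] column-major view, transposed
      TransposeType::NoTranspose,  // A's [K x M] column-major view
      /*m=*/N, /*n=*/M, /*k=*/K,
      /*alpha=*/1.0f,
      /*a=*/B, /*lda=*/K,
      /*b=*/A, /*ldb=*/K,
      /*beta=*/0.0f,
      /*c=*/C, /*ldc=*/N);
}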

@@ -130,7 +130,7 @@ class ExecuTorchLlamaJni
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
runner_->generate(
prompt->toStdString(),
128,
BlackSamorez (Owner Author):

Increase context length for more meaningful generations.

target_include_directories(
aqlm PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../../../include"
)
target_link_libraries(aqlm PUBLIC ${aqlm_libs} -fopenmp -static-openmp)
BlackSamorez (Owner Author):

Additional flags for OMP to link on Android

cpublas
eigen_blas
quantized_kernels
quantized_ops_lib
-fopenmp
BlackSamorez (Owner Author):

Additional flags for OMP to link on Android

BlackSamorez (Owner Author):

Update: added OpenMP for a 2.5x speedup on 4 cores.

larryliu0820:

@BlackSamorez thanks for working on this; I really appreciate the feedback you gave in the blog post. Looking at this PR, I have some takeaways on how to improve ExecuTorch to make this flow easier. Let me know if they make sense!

  • Documentation:
    • README.md should include instructions on how to write CMake for custom kernels.
    • Maybe provide an overview of the CMake build system?
  • Packaging:
    • Provide better tools for users like you to use ExecuTorch as a library, instead of pulling the source code and modifying it in place.

I'm curious to learn: where did you spend most of your time in order to make this work?

Comment on lines +53 to +56
values_vec = vmulq_f32(values_vec, scales_vec);
if (bias != nullptr) {
values_vec = vaddq_f32(values_vec, bias_vec);
}


You probably want to issue an FMA (https://arm-software.github.io/acle/neon_intrinsics/advsimd.html#fused-multiply-accumulate) here if bias is not nullptr. I would also recommend generating a separate function (e.g., by adding a template parameter to ignore the bias) so that you get a separate kernel for the with-bias and without-bias cases and can be sure not to pay the cost of the test and branch on every iteration.

Some amount of loop unrolling is also probably advisable; hopefully the compiler will do that for you, but I would recommend checking the generated assembly.
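A sketch of that suggestion, as a hypothetical standalone scale-and-bias epilogue rather than the PR's actual kernel:

#include <arm_neon.h>

// Templating on HasBias produces two specialized kernels, so the bias check is
// resolved at compile time rather than branching on every iteration. With a
// bias, vfmaq_f32 fuses the multiply and the add into a single instruction.
// Tail elements (n % 4) are omitted for brevity.
template <bool HasBias>
void scale_and_bias(float* values, const float* scales, const float* bias, int n) {
  for (int i = 0; i + 4 <= n; i += 4) {
    float32x4_t v = vld1q_f32(values + i);
    float32x4_t s = vld1q_f32(scales + i);
    if (HasBias) {
      float32x4_t b = vld1q_f32(bias + i);
      v = vfmaq_f32(b, v, s);  // b + v * s
    } else {
      v = vmulq_f32(v, s);
    }
    vst1q_f32(values + i, v);
  }
}

// Dispatch once, outside the hot loop:
//   bias != nullptr ? scale_and_bias<true>(values, scales, bias, n)
//                   : scale_and_bias<false>(values, scales, nullptr, n);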

const int b_alt_stride = 2 * out_features;

for (int input = 0; input < num_inputs; ++input) {
#pragma omp parallel for num_threads(4)


FWIW, I don't think OMP works on iOS.

Comment on lines +30 to +33
for (int i = 0; i < out_features; ++i) {
output_vec[input * out_features + i] += lut_ptr[b_alt_ptr[i * 2]];
output_vec[input * out_features + i] += lut_ptr[256 + b_alt_ptr[i * 2 + 1]];
}


I would recommend experimenting with unrolling this loop. With clang, you put #pragma unroll 4 (or whatever unroll count) on the line before the for; Google says (https://gcc.gnu.org/onlinedocs/gcc/Loop-Specific-Pragmas.html) the GCC equivalent would be #pragma GCC unroll 4.
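Applied to the loop quoted above, that looks like this (same body, clang pragma syntax):

// Ask the compiler to unroll the accumulation loop by 4
// (#pragma GCC unroll 4 under GCC).
#pragma unroll 4
for (int i = 0; i < out_features; ++i) {
  output_vec[input * out_features + i] += lut_ptr[b_alt_ptr[i * 2]];
  output_vec[input * out_features + i] += lut_ptr[256 + b_alt_ptr[i * 2 + 1]];
}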

swolchok:

one other note: have you compared performance with the AQLM paper's numba kernel? It looks like that setup is using a JIT to specialize over (in_group_size, out_features, in_features, num_codebooks), which should do a better job of exposing optimization opportunities to the compiler; if you find there is a performance gap (e.g., running both on an ARM or x86 server/laptop/whatever) you might want to experiment with templatizing your kernel over a similar suite of parameters.

BlackSamorez (Owner Author):

It looks like that setup is using a JIT to specialize over (in_group_size, out_features, in_features, num_codebooks)

No, in reality, it's also using [num_in_groups, num_out_groups, num_codebooks]. The comment with shapes is wrong.
That layout is much faster than [num_out_groups, num_in_groups, num_codebooks]. Mostly because the LUT memory accesses are into a contiguous memory array in the innermost loop when in_features is ~last dim.

swolchok:

It looks like that setup is using a JIT to specialize over (in_group_size, out_features, in_features, num_codebooks)

No, in reality, it's also using [num_in_groups, num_out_groups, num_codebooks]. The comment with shapes is wrong. That layout is much faster than [num_out_groups, num_in_groups, num_codebooks]. Mostly because the LUT memory accesses are into a contiguous memory array in the innermost loop when in_features is ~last dim.

I'm not talking so much about the layout as I am talking about specializing over specific values as loop trip counts.

BlackSamorez (Owner Author):

specializing over specific values as loop trip counts

Oh, I see. I forgot that I did that for the Numba kernel. Yes, we could templatize the code and instantiate all the Llama shapes there are with an eager fallback. Thanks!
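A sketch of that specialization idea, using a simplified single-input version of the inner loop quoted earlier; the function names and the two hard-coded shapes are illustrative:

#include <cstdint>

// Fallback with runtime trip counts.
void lut_accumulate_generic(
    const float* lut, const uint8_t* codes, float* out, int out_features) {
  for (int i = 0; i < out_features; ++i) {
    out[i] += lut[codes[i * 2]] + lut[256 + codes[i * 2 + 1]];
  }
}

// Same loop with the trip count baked in at compile time, so the compiler can
// fully unroll and vectorize each instantiated shape.
template <int kOutFeatures>
void lut_accumulate_fixed(const float* lut, const uint8_t* codes, float* out) {
  for (int i = 0; i < kOutFeatures; ++i) {
    out[i] += lut[codes[i * 2]] + lut[256 + codes[i * 2 + 1]];
  }
}

// Dispatch on the shapes that actually occur in the model; fall back otherwise.
void lut_accumulate(
    const float* lut, const uint8_t* codes, float* out, int out_features) {
  switch (out_features) {
    case 4096:  lut_accumulate_fixed<4096>(lut, codes, out); break;
    case 11008: lut_accumulate_fixed<11008>(lut, codes, out); break;
    default:    lut_accumulate_generic(lut, codes, out, out_features); break;
  }
}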
