AQLM custom kernels for Android #1
base: v0.3.0_branch
@@ -467,6 +467,9 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
  add_subdirectory(
    ${CMAKE_CURRENT_SOURCE_DIR}/examples/models/llama2/custom_ops
  )
  add_subdirectory(
    ${CMAKE_CURRENT_SOURCE_DIR}/examples/models/llama2/aqlm
  )
endif()

if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED)
@@ -633,13 +636,16 @@ if(EXECUTORCH_BUILD_PYBIND)
  # TODO(larryliu): Fix macOS 2 dylibs having 2 sets of static variables issue
  if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT AND NOT APPLE)
    list(APPEND _dep_libs custom_ops_aot_lib)
    list(APPEND _dep_libs aqlm_aot_lib)
Review comment: Add torch bindings for AQLM to the …
  endif()
  # TODO(laryliu): Fix linux duplicate registation problem. In GH CI worker
  # libcustom_ops.a doesn't dedup with the one indirectly linked from
  # libcustom_ops_aot_lib.a
  if(EXECUTORCH_BUILD_KERNELS_CUSTOM AND APPLE)
    target_link_options_shared_lib(custom_ops)
    list(APPEND _dep_libs custom_ops)
    target_link_options_shared_lib(aqlm)
Review comment: Force the linkage of core …
    list(APPEND _dep_libs aqlm)
  endif()
  # compile options for pybind
  set(_pybind_compile_options
@@ -699,7 +705,7 @@ if(EXECUTORCH_BUILD_PYBIND)
    PROPERTIES # Assume that this library will be installed in
               # `site-packages/executorch/extension/pybindings`, and that
               # the custom_ops_aot_lib should be found with relative path.
               BUILD_RPATH "$ORIGIN:$ORIGIN/../../examples/models/llama2/custom_ops"
               BUILD_RPATH "$ORIGIN:$ORIGIN/../../examples/models/llama2/custom_ops:$ORIGIN/../../examples/models/llama2/aqlm"
Review comment: IDK
  )
endif()
@@ -87,6 +87,7 @@ endif()
# custom ops library
if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
  add_subdirectory(custom_ops)
  add_subdirectory(aqlm)
Review comment: Traverse the subdirectory containing the AQLM code.
endif()

# llama_runner library
@@ -129,6 +130,9 @@ list(APPEND link_libraries quantized_kernels quantized_ops_lib)
if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
  target_link_options_shared_lib(custom_ops)
  list(APPEND link_libraries custom_ops)

  target_link_options_shared_lib(aqlm)
Review comment: Force the linkage of core …
  list(APPEND link_libraries aqlm)
endif()

set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack)
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Review comment: Mostly copied from …
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.19)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
if(NOT CMAKE_CXX_STANDARD)
  set(CMAKE_CXX_STANDARD 17)
endif()

if(NOT PYTHON_EXECUTABLE)
  set(PYTHON_EXECUTABLE python3)
endif()

# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
endif()

set(_common_compile_options -Wno-deprecated-declarations -fPIC)

include(${EXECUTORCH_ROOT}/build/Utils.cmake)
include(${EXECUTORCH_ROOT}/build/Codegen.cmake)

#
# The `_<target>_srcs` lists are defined by including ${EXECUTORCH_SRCS_FILE}.
#
set(EXECUTORCH_SRCS_FILE
    "${CMAKE_CURRENT_BINARY_DIR}/../../../../executorch_srcs.cmake"
)

extract_sources(${EXECUTORCH_SRCS_FILE})

include(${EXECUTORCH_SRCS_FILE})

# Let files say "include <executorch/path/to/header.h>".
set(_common_include_directories ${EXECUTORCH_ROOT}/..)

# Custom op libraries
set(aqlm_libs executorch_no_prim_ops)
list(APPEND aqlm_libs pthreadpool)
list(APPEND aqlm_libs cpuinfo)
list(APPEND aqlm_libs cpublas)
list(APPEND aqlm_libs eigen_blas)

set(_aqlm__srcs examples/models/llama2/aqlm/lut_kernel.h examples/models/llama2/aqlm/lut_kernel.cpp)
list(TRANSFORM _aqlm__srcs PREPEND "${EXECUTORCH_ROOT}/")

message("HERE: AQLM SOURCES: ${_aqlm__srcs}")

# TODO: Consider moving xnnpack/threadpool in a separate lib since it's now used
# by custom ops too.
if(NOT EXECUTORCH_BUILD_XNNPACK)
  list(
    APPEND
    _aqlm__srcs
    "${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool_guard.cpp"
  )
else()
  list(APPEND aqlm_libs xnnpack_backend)
endif()

# find_package(OpenMP REQUIRED)
# list(APPEND aqlm_libs OpenMP::OpenMP_CXX)
# list(APPEND aqlm_libs omp)

add_library(aqlm ${_aqlm__srcs})
Review comment: A library containing core AQLM ops and a …

# Enable optimization
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") # -O3 -fopenmp
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") # -O3 -fopenmp

target_include_directories(aqlm PUBLIC "${_common_include_directories}")
target_include_directories(
  aqlm PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../../../include"
)
target_link_libraries(aqlm PUBLIC ${aqlm_libs})

target_compile_options(
  aqlm PUBLIC ${_common_compile_options} -DET_USE_THREADPOOL # ${OpenMP_CXX_FLAGS}
)

install(TARGETS aqlm DESTINATION lib)

if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
  # Add a AOT library
  find_package(Torch CONFIG REQUIRED)
  add_library(
    aqlm_aot_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/lut_kernel_pytorch.cpp
Review comment: A library to be loaded into PyTorch with …
  )
  target_include_directories(
    aqlm_aot_lib PUBLIC "${_common_include_directories}"
  )
  target_include_directories(
    aqlm_aot_lib
    PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../../../include"
  )
  target_link_libraries(aqlm_aot_lib PUBLIC aqlm torch)
  target_compile_options(
    aqlm_aot_lib PUBLIC -Wno-deprecated-declarations -fPIC -frtti
                        -fexceptions
  )

  install(TARGETS aqlm_aot_lib DESTINATION lib)
endif()
@@ -0,0 +1,63 @@
""" AQLM 2x8 Linear"""
import torch
import torch.nn as nn


class Aqlm2x8Linear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        in_group_size: int,
        out_group_size: int,
        num_codebooks: int,
        nbits_per_codebook: int,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        assert self.in_features % in_group_size == 0
        assert self.out_features % out_group_size == 0
        num_out_groups = out_features // out_group_size
        num_in_groups = in_features // in_group_size
        self.out_group_size, self.in_group_size = out_group_size, in_group_size
        self.num_codebooks = num_codebooks
        self.nbits_per_codebook = nbits_per_codebook
        self.codebook_size = 2**nbits_per_codebook

        # CODES & CODEBOOKS
        self.codebooks = nn.Parameter(
            torch.empty((num_codebooks, self.codebook_size, out_group_size, in_group_size), **factory_kwargs) * 2 - 1,
            requires_grad=False,
        )  # [num_codebooks, codebook_size, out_group_size, in_group_size]
        self.codes = nn.Parameter(
            torch.empty(
                (num_out_groups, num_in_groups, num_codebooks),
                device=device,
                dtype=torch.int8,
            ),
            requires_grad=False,
        )  # [num_out_groups, num_in_groups, num_codebooks]

        # SCALES
        self.scales = nn.Parameter(
            torch.empty((num_out_groups, 1, 1, 1), **factory_kwargs), requires_grad=False
        )  # [num_out_groups, 1, 1, 1]

        # BIAS
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs), requires_grad=False)
        else:
            self.register_parameter("bias", None)

    def transpose_codes(self):
        self.codes.data = torch.permute(self.codes.data, (1, 0, 2)).contiguous()
Review comment: The codes layout for C++ kernels differs from the CUDA one. We need to do some weights preprocessing.

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.ops.aqlm.code2x8_lut_matmat(
Review comment: Invoke the custom op loaded from …
            input, self.codes, self.codebooks, self.scales, bias=self.bias
        )
@@ -0,0 +1,149 @@
#include "lut_kernel.h"

#include <numeric>
#include <functional>

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
Review comment: Needed for …
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

#include <executorch/kernels/optimized/blas/CPUBlas.h>
Review comment: For …


template<typename fp_dtype>
void quadruple_for(
    int num_inputs,
    int num_input_groups, const fp_dtype* __restrict__ lut,
    int out_features, const uint8_t* __restrict__ b_alt,
    fp_dtype* __restrict__ output_vec
)
{
  std::memset(output_vec, 0, num_inputs * out_features * sizeof(fp_dtype));

  const int lut_stride = num_input_groups * 2 * 256;
  const int b_alt_stride = 2 * out_features;

  for (int input = 0; input < num_inputs; ++input) {
    for (int j = 0; j < num_input_groups; ++j) {
      const fp_dtype* lut_ptr = lut + input * lut_stride + j * 2 * 256;
      const uint8_t* b_alt_ptr = b_alt + j * b_alt_stride;

      for (int i = 0; i < out_features; ++i) {
        output_vec[input * out_features + i] += lut_ptr[b_alt_ptr[i * 2]];
        output_vec[input * out_features + i] += lut_ptr[256 + b_alt_ptr[i * 2 + 1]];
      }
Review comment on lines +30 to +33: I would recommend experimenting with unrolling this loop. With clang, you put … (see the sketch after this function)
    }
  }
}
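As a hedged sketch of the unrolling suggestion above: with clang, an explicit unroll hint can be placed on the inner loop of quadruple_for via a loop pragma. The unroll factor below is an arbitrary starting point, not a tuned value.

// Hypothetical variant of the inner loop of quadruple_for with an explicit
// clang unroll hint (factor 4 chosen arbitrarily; benchmark on the target CPU).
#pragma clang loop unroll_count(4)
for (int i = 0; i < out_features; ++i) {
  output_vec[input * out_features + i] += lut_ptr[b_alt_ptr[i * 2]];
  output_vec[input * out_features + i] += lut_ptr[256 + b_alt_ptr[i * 2 + 1]];
}

Checking the generated assembly, as the reviewer suggests for the NEON path below, applies here as well.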
#if defined(__aarch64__) && defined(__ARM_NEON)
#include <arm_neon.h>
void row_wise_scaling_and_bias(
    float* __restrict__ out,
    const float* __restrict__ scales, const float* __restrict__ bias,
    int num_input_vectors, int out_features
) {
  for (int j = 0; j < out_features; j += 4) {
    float32x4_t scales_vec = vld1q_f32(scales + j);
    float32x4_t bias_vec;
    if (bias != nullptr){
      bias_vec = vld1q_f32(bias + j);
    }
    for (int i=0; i < num_input_vectors; ++i) {
      float32x4_t values_vec = vld1q_f32(out + i * out_features + j);
      values_vec = vmulq_f32(values_vec, scales_vec);
      if (bias != nullptr) {
        values_vec = vaddq_f32(values_vec, bias_vec);
      }
Review comment on lines +53 to +56: you probably want to issue an FMA (https://arm-software.github.io/acle/neon_intrinsics/advsimd.html#fused-multiply-accumulate) here if bias is not nullptr. I would also recommend generating a separate function (e.g., by adding a template parameter to ignore the bias) to make sure that you get a separate kernel generated for with and without bias so you can be sure not to pay the cost of the test and branch on every iteration. some amount of loop unrolling is also probably advisable; hopefully the compiler will do that for you, but I would recommend checking the generated assembly. (See the sketch after the #endif below.)
      vst1q_f32(out + i * out_features + j, values_vec);
    }
  }
}
#else
void row_wise_scaling_and_bias(
    float* __restrict__ out,
    const float* __restrict__ scales, const float* __restrict__ bias,
    int num_input_vectors, int out_features
) {
  for (int j = 0; j < out_features; ++j) {
    float scale_value = scales[j];
    float bias_value;
    if (bias != nullptr){
      bias_value = bias[j];
    }
    for (int i=0; i < num_input_vectors; ++i) {
      out[i * out_features + j] *= scale_value;
      if (bias != nullptr) {
        out[i * out_features + j] += bias_value;
      }
    }
  }
}
#endif
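A hedged sketch of the reviewer's two suggestions combined: a compile-time bias switch so separate kernels are generated with and without bias, plus vfmaq_f32 so scale and bias become a single fused multiply-add. The function names and the dispatcher are illustrative only; the 4-wide stride and data layout are kept from the function above.

#if defined(__aarch64__) && defined(__ARM_NEON)
#include <arm_neon.h>

// One instantiation per bias case: the branch on `bias` disappears from the
// inner loop, and scale+bias becomes vfmaq_f32(bias, values, scales).
template <bool kHasBias>
void row_wise_scaling_and_bias_impl(
    float* __restrict__ out,
    const float* __restrict__ scales, const float* __restrict__ bias,
    int num_input_vectors, int out_features
) {
  for (int j = 0; j < out_features; j += 4) {
    const float32x4_t scales_vec = vld1q_f32(scales + j);
    float32x4_t bias_vec = vdupq_n_f32(0.f);
    if constexpr (kHasBias) {
      bias_vec = vld1q_f32(bias + j);
    }
    for (int i = 0; i < num_input_vectors; ++i) {
      const float32x4_t values_vec = vld1q_f32(out + i * out_features + j);
      const float32x4_t result = kHasBias
          ? vfmaq_f32(bias_vec, values_vec, scales_vec)  // bias + values * scales
          : vmulq_f32(values_vec, scales_vec);
      vst1q_f32(out + i * out_features + j, result);
    }
  }
}

// Thin dispatcher that keeps the original runtime signature.
void row_wise_scaling_and_bias_fused(
    float* __restrict__ out,
    const float* __restrict__ scales, const float* __restrict__ bias,
    int num_input_vectors, int out_features
) {
  if (bias != nullptr) {
    row_wise_scaling_and_bias_impl<true>(out, scales, bias, num_input_vectors, out_features);
  } else {
    row_wise_scaling_and_bias_impl<false>(out, scales, bias, num_input_vectors, out_features);
  }
}
#endif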
namespace torch {
namespace executor {
namespace native {
Tensor& code2x8_lut_matmat_out(
Review comment: Those are …
    RuntimeContext& ctx,
    const Tensor& input,
    const Tensor& codes,
    const Tensor& codebooks,
    const Tensor& scales,
    const optional<Tensor>& bias,
    Tensor& out
) {
  auto input_sizes = input.sizes();
  auto out_features = codes.size(1) * codebooks.size(2);
  auto input_vector_size = input.size(input.dim() - 1);
  auto num_input_vectors = std::accumulate(input_sizes.begin(), input_sizes.end(), 1, std::multiplies<int64_t>()) / input_vector_size;

  // Allocate LUT
  auto lut_data = ctx.allocate_temp(
Review comment: We need to manually allocate all the temporary memory we need. One way to do so is to invoke …
      4 * num_input_vectors * input_vector_size / 8 * codebooks.size(0) * codebooks.size(1)
  ).get();

  // A @ B.T
  ::executorch::cpublas::gemm(
Review comment: No matmul, so we have to use low-level ops.
      ::executorch::cpublas::TransposeType::Transpose,
      ::executorch::cpublas::TransposeType::NoTranspose,
      (int64_t)codebooks.size(0) * codebooks.size(1), // B rows
      (int64_t)num_input_vectors * input_vector_size / 8, // A rows
      (int64_t)8, // MatMul dim size
      1.f,
      (float*)codebooks.const_data_ptr(), (int64_t)8,
      (float*)input.const_data_ptr(), (int64_t)8,
      0.f,
      (float*)lut_data, (int64_t)codebooks.size(0) * codebooks.size(1)
  );

  // Do lookup matmul
  quadruple_for<float>(
      num_input_vectors,
      input_vector_size / 8,
      (const float*)lut_data,
      out_features,
      (const uint8_t*)codes.const_data_ptr(),
      (float*)out.mutable_data_ptr()
  );

  const float* bias_ptr = nullptr;
  if (bias.has_value()) {
    bias_ptr = bias.value().const_data_ptr<float>();
  }

  row_wise_scaling_and_bias(
      out.mutable_data_ptr<float>(),
      scales.const_data_ptr<float>(),
      bias_ptr,
      num_input_vectors,
      out_features
  );

  return out;
}
} // namespace native
} // namespace executor
} // namespace torch
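For reference, a summary of what the kernel above computes, as read from the code (2x8 scheme: two 8-bit codebooks, out_group_size = 1, in_group_size = 8; the notation is introduced here, not taken from the PR):

\mathrm{LUT}[n, g, k, c] = \sum_{d=0}^{7} x[n,\, 8g + d]\; C_k[c, d],
\qquad k \in \{0, 1\},\ c \in \{0, \dots, 255\}

y[n, i] = s_i \sum_{g} \Big( \mathrm{LUT}\big[n, g, 0,\, b_{g,i,0}\big] + \mathrm{LUT}\big[n, g, 1,\, b_{g,i,1}\big] \Big) + \beta_i

Here x is the input, C_k the k-th codebook, b the (transposed) codes, s the scales and \beta the optional bias. The first line is the cpublas GEMM that fills the LUT; the second is quadruple_for followed by row_wise_scaling_and_bias.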
EXECUTORCH_LIBRARY(aqlm, "code2x8_lut_matmat.out", torch::executor::native::code2x8_lut_matmat_out);
Review comment: A macro to register the operation into the …
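lut_kernel.h is listed in the aqlm sources above but is not shown in this diff; presumably it only declares the out-variant kernel, along these lines (signature copied from the definition above; the include path is an assumption):

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

// Out-variant AQLM 2x8 LUT matmul kernel, registered via EXECUTORCH_LIBRARY above.
Tensor& code2x8_lut_matmat_out(
    RuntimeContext& ctx,
    const Tensor& input,
    const Tensor& codes,
    const Tensor& codebooks,
    const Tensor& scales,
    const optional<Tensor>& bias,
    Tensor& out);

} // namespace native
} // namespace executor
} // namespace torch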
Review comment: Instruct CMake to traverse the subdirectory containing the AQLM kernels. Keeping the directory in the same CMake tree makes it easier to link those custom operator libs.