Skip to content

Introduce New Lookup-Table(LUT)-Based Matrix Multiplication Method #10181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 237 additions & 9 deletions convert_hf_to_gguf.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions ggml/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ option(GGML_SYCL "ggml: use SYCL"
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
"ggml: sycl target device")
# T-MAC: lookup-table (LUT) based low-bit matrix multiplication (see ggml-tmac.h)
option(GGML_TMAC                "ggml: use TMAC"                          OFF)
option(GGML_TMAC_SYSLIB         "ggml: use TMAC system library"           OFF)
option(GGML_TMAC_TVM_THREADPOOL "ggml: use TVM threadpool for TMAC"       OFF)
# Used by ggml/src/CMakeLists.txt (TMAC_RECHUNK); previously referenced but never declared
option(GGML_TMAC_RECHUNK        "ggml: rechunk T-MAC chunks across threads" OFF)

# extra artifacts
option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE})
Expand Down
42 changes: 42 additions & 0 deletions ggml/include/ggml-tmac.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// T-MAC: lookup-table (LUT) based matrix multiplication for low-bit weights.
// Interface consumed by the CPU backend (ggml-cpu.c) under GGML_USE_TMAC.
#pragma once

#include "ggml.h"
#include "ggml-backend.h"

// On ARM NEON, T-MAC works in fp16 (scalar float16_t); elsewhere in fp32.
// The CPU-side caller converts activations/results when sizeof(tmac_float_type) == 2.
#ifdef __ARM_NEON
#include <arm_neon.h>
typedef float16_t tmac_float_type;
#else
typedef float tmac_float_type;
#endif

#ifdef __cplusplus
extern "C" {
#endif

// Extra data attached to a weight tensor (tensor->extra) after
// ggml_tmac_transform_tensor(); read by the mul_mat path in ggml-cpu.c.
struct tmac_tensor_extra {
    int lut_scales_size;       // per-activation-row stride for lut_scales/lut_biases buffers
    int scales_size;           // total number of weight scales (split evenly across tiles)
    int n_tile_num;            // number of weight tiles; dst ne0 must be divisible by this
    uint8_t * qweights;        // repacked quantized weights
    tmac_float_type * scales;  // weight scales
};

// Global init/teardown of the T-MAC runtime (kernel configs, threadpool).
// NOTE(review): exact responsibilities not visible from this header — confirm in ggml-tmac.cpp.
GGML_API void ggml_tmac_init(void);
GGML_API void ggml_tmac_free(void);
// src0->type == Q4_0/IQ2_XXS/IQ3_XXS
// T-MAC currently only supports BitNet quantization or GPTQ-like quantization (only scales, without zeros)
// If use i-quantization gguf models, the results will be wrong
// TODO: add customized block types Q2_0/Q3_0
GGML_API bool ggml_tmac_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
// Workspace size in bytes needed by the T-MAC mul_mat path for this node.
GGML_API size_t ggml_tmac_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
// Build the quantized LUT (qlut/lut_scales/lut_biases) from one activation row.
// Convention: src0 is n x k (weights), src1 is m x k (activations), bits is the weight bit width.
GGML_API void ggml_tmac_mul_mat_task_init(void * src1, void * qlut, void * lut_scales, void * lut_biases, int n, int k, int m, int bits);
// Compute one output chunk from the precomputed LUT and the repacked weights.
GGML_API void ggml_tmac_mul_mat_task_compute(void * src0, void * scales, void * qlut, void * lut_scales, void * lut_biases, void * dst, int n, int k, int m, int bits);
// Repack/quantize a weight tensor for T-MAC and attach a tmac_tensor_extra to it.
GGML_API void ggml_tmac_transform_tensor(struct ggml_tensor * tensor);
// Bit width used by T-MAC for the given ggml type (passed as `bits` to the task functions).
GGML_API int ggml_tmac_get_type_bits(enum ggml_type type);
GGML_API void ggml_tmac_set_n_threads(int n_threads);
// Size in bytes of the transformed (T-MAC) representation of `tensor`.
GGML_API size_t ggml_tmac_get_nbytes(const struct ggml_tensor * tensor);

#ifdef __cplusplus
}
#endif
4 changes: 4 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ extern "C" {
GGML_TYPE_Q4_0_8_8 = 33,
GGML_TYPE_TQ1_0 = 34,
GGML_TYPE_TQ2_0 = 35,
GGML_TYPE_I1 = 36,
GGML_TYPE_I2 = 37,
GGML_TYPE_I3 = 38,
GGML_TYPE_I4 = 39,
GGML_TYPE_COUNT,
};

Expand Down
61 changes: 60 additions & 1 deletion ggml/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,42 @@ if (GGML_KOMPUTE)
endif()
endif()

if (GGML_TMAC)
    # T-MAC: LUT-based mul_mat for low-bit weights. TMAC_FOUND, TMAC_LIB_DIR and
    # TMAC_KERNELS_SOURCE are expected from a FindTMAC.cmake module.
    find_package(TMAC)

    if (TMAC_FOUND)
        message(STATUS "TMAC found")

        list(APPEND GGML_CDEF_PUBLIC GGML_USE_TMAC)

        set(GGML_HEADERS_TMAC ../include/ggml-tmac.h)
        set(GGML_SOURCES_TMAC ggml-tmac.cpp)

        # TODO: replace directory-scoped link_directories with an imported target
        link_directories(${TMAC_LIB_DIR})
        # kcfg.ini carries the pre-tuned kernel configurations; keep it next to the binaries
        file(COPY ${TMAC_LIB_DIR}/kcfg.ini DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
        # TODO: link t_mac_object when GGML_TMAC_SYSLIB

        if (GGML_TMAC_TVM_THREADPOOL)
            add_compile_definitions(TMAC_USE_TVM_THREADPOOL)
            list(APPEND GGML_EXTRA_LIBS_PRIVATE t_mac)
        else()
            # Test the variables by name: the previous `${CMAKE_C_COMPILER_ID}` form
            # expands to nothing when the language is not enabled and breaks if().
            # MATCHES "Clang" intentionally also accepts AppleClang.
            if (NOT CMAKE_C_COMPILER_ID MATCHES "Clang" OR NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang")
                message(FATAL_ERROR "Clang is required for T-MAC compilation")
            endif()

            list(APPEND GGML_EXTRA_LIBS_PRIVATE t_mac_no_tvm)
            list(APPEND GGML_SOURCES_TMAC ${TMAC_KERNELS_SOURCE})
        endif()

        if (GGML_TMAC_RECHUNK)
            add_compile_definitions(TMAC_RECHUNK)
        endif()
    else()
        message(WARNING "TMAC not found")
    endif()
endif()

if (GGML_CPU_HBM)
find_library(memkind memkind REQUIRED)

Expand Down Expand Up @@ -1170,6 +1206,26 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR
# Raspberry Pi 3, 4, Zero 2 (32-bit)
list(APPEND ARCH_FLAGS -mno-unaligned-access)
endif()
if (CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64" AND GGML_TMAC AND TMAC_FOUND)
    # We need fullfp16 for T-MAC
    # TODO: we need to simplify this logic through check_cxx_source_compiles or Presets?
    # Probe for the i8mm matrix-multiply intrinsic. NOTE: the probe must use
    # vmmlaq_s32 — the previous vmlaq_f32 is a float intrinsic and can never
    # compile with int8x16_t/int32x4_t operands, so the check always failed.
    check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8)
    if (GGML_COMPILER_SUPPORT_MATMUL_INT8)
        # Device with armv8.7a+ cpu, e.g., WSL on Surface Laptop 7
        # based on arm64-windows-llvm.cmake
        list(APPEND ARCH_FLAGS -march=armv8.7-a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only)
        add_compile_definitions(__ARM_FEATURE_MATMUL_INT8)
    else ()
        # Jetson AGX Orin, Raspberry Pi 5
        # "armv8.2-a" (with the dash) is the spelling accepted by both GCC and Clang;
        # GCC rejects the dash-less "armv8.2a".
        list(APPEND ARCH_FLAGS -march=armv8.2-a+fp16)
    endif ()
endif()
if (CMAKE_SYSTEM_PROCESSOR MATCHES "ARM64" AND GGML_TMAC AND TMAC_FOUND)
    # ARM Windows with LLVM clang GNU interface
    # We need fullfp16 for T-MAC
    # TODO: check_cxx_source_compiles
    # Use the canonical "armv8.2-a" spelling (the dash-less form is non-portable).
    list(APPEND ARCH_FLAGS -march=armv8.2-a+fp16)
endif()
if (GGML_SVE)
list(APPEND ARCH_FLAGS -march=armv8.6-a+sve)
endif()
Expand All @@ -1184,7 +1240,9 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
# TODO: improve, should not reference files from the parent folder
include(../cmake/FindSIMD.cmake)
endif ()
if (GGML_AVX512)
# Can't use GGML_AVX512 with Clang for MSVC
# with error: conflicting types for '_m_prefetchw
if (GGML_AVX512 AND (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") AND (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang"))
list(APPEND ARCH_FLAGS /arch:AVX512)
# MSVC has no compile-time flags enabling specific
# AVX512 extensions, neither it defines the
Expand Down Expand Up @@ -1388,6 +1446,7 @@ add_library(ggml
${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
${GGML_SOURCES_AMX} ${GGML_HEADERS_AMX}
${GGML_SOURCES_CANN} ${GGML_HEADERS_CANN}
${GGML_SOURCES_TMAC} ${GGML_HEADERS_TMAC}
ggml-aarch64.c ggml-aarch64.h
)

Expand Down
183 changes: 183 additions & 0 deletions ggml/src/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@
#include <Accelerate/Accelerate.h>
#endif

#if defined(GGML_USE_TMAC)
#include "ggml-tmac.h"
#endif

// floating point type used to accumulate sums
typedef double ggml_float;

Expand Down Expand Up @@ -423,6 +427,26 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
    // T-MAC low-bit weight types (1..4 bits). Their mul_mat is intercepted by the
    // T-MAC path in ggml_compute_forward_mul_mat (GGML_USE_TMAC), so the f32
    // vec_dot registered here is a placeholder.
    // NOTE(review): confirm no other op can reach vec_dot for these types —
    // ggml_vec_dot_f32 would misinterpret the packed data.
    [GGML_TYPE_I1] = {
        .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
        .vec_dot_type = GGML_TYPE_F32,
        .nrows = 1,
    },
    [GGML_TYPE_I2] = {
        .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
        .vec_dot_type = GGML_TYPE_F32,
        .nrows = 1,
    },
    [GGML_TYPE_I3] = {
        .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
        .vec_dot_type = GGML_TYPE_F32,
        .nrows = 1,
    },
    [GGML_TYPE_I4] = {
        .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
        .vec_dot_type = GGML_TYPE_F32,
        .nrows = 1,
    },
};

const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
Expand Down Expand Up @@ -7478,6 +7502,156 @@ static void ggml_compute_forward_mul_mat(
UseGgmlGemm1:;
#endif

// TODO: Refactor t-mac as ggml-backend,
// as ggml-blas.cpp has been moved to backend
#if defined(GGML_USE_TMAC)
if (ggml_tmac_can_mul_mat(src0, src1, dst)) {
const int bits = ggml_tmac_get_type_bits(type);
// src0: weight, ne00 = k, ne01 = n
// src1: activation, ne10 = k, ne11 = m
char * wdata = params->wdata;

struct tmac_tensor_extra * wt = src0->extra;
char * cur_wdata = wdata;
tmac_float_type * tmac_f_ptr = wdata;
if (sizeof(tmac_float_type) == 2) {
cur_wdata = wdata + MAX(ne10, ne01) * ne11 * sizeof(tmac_float_type);
};
int8_t * qlut = cur_wdata;
tmac_float_type * lut_scales = (tmac_float_type *) (qlut + ne10 * ne11 * 4);
tmac_float_type * lut_biases = (tmac_float_type *) (lut_scales + wt->lut_scales_size * ne11);

GGML_ASSERT(src1->type == GGML_TYPE_F32);
tmac_float_type * act_input;
if (sizeof(tmac_float_type) == 2) {
act_input = tmac_f_ptr;
} else {
act_input = src1->data;
}
for (int ine11 = ith; ine11 < ne11; ine11 += nth) {
if (sizeof(tmac_float_type) == 2) {
ggml_fp32_to_fp16_row((const float *) src1->data + ne10 * ine11, act_input + ne10 * ine11, ne10);
}
ggml_tmac_mul_mat_task_init(act_input + ne10 * ine11,
qlut + ne10 * ine11 * 4,
lut_scales + wt->lut_scales_size * ine11,
lut_biases + wt->lut_scales_size * ine11,
ne01, ne00, 1, bits);
}

if (ith == 0) {
            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
atomic_store_explicit(&params->threadpool->current_chunk, nth, memory_order_relaxed);
}

ggml_barrier(params->threadpool);

tmac_float_type * act_output;
if (sizeof(tmac_float_type) == 2) {
act_output = tmac_f_ptr;
} else {
act_output = dst->data;
}
// TODO: remove TVM threadpool if ensuring unused
#if defined(TMAC_USE_TVM_THREADPOOL)
if (ith != 0) {
return;
}
// TODO: schedule ne11(m) in T-MAC
for (int ine11 = 0; ine11 < ne11; ine11++) {
const int qlut_offset = ne10 * ine11 * 4;
const int lut_scales_offset = wt->lut_scales_size * ine11;
const int dst_offset = ne0 * ine11;

ggml_tmac_mul_mat_task_compute(wt->qweights,
wt->scales,
qlut + qlut_offset,
lut_scales + lut_scales_offset,
lut_biases + lut_scales_offset,
act_output + dst_offset,
ne01, ne00, 1, bits);
}
if (sizeof(tmac_float_type) == 2) {
ggml_fp16_to_fp32_row(tmac_f_ptr, dst->data, ne00 * ne01);
}
#else // #if defined(TMAC_USE_TVM_THREADPOOL)
const int n_tile_num = wt->n_tile_num;
        // Currently, T-MAC requires ne0 divisible by n_tile_num
GGML_ASSERT(ne0 % n_tile_num == 0);

const int64_t w_size = ne00 * ne01 * bits / 8;
const int64_t w_chunk_size = w_size / n_tile_num;

const int64_t nr0 = ne0;
const int64_t nr1 = ne1 * ne2 * ne3;

// Adopt the same style with current llama.cpp impl
// But different chunk size for 0/1 dim.
// No scrap.
const int chunk_size0 = ne0 / n_tile_num;
const int chunk_size1 = 8; // TODO: tune in T-MAC

// nchunk0 == n_tile_num
int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
int64_t nchunk1 = (nr1 + chunk_size1 - 1) / chunk_size1;

int64_t dr0 = chunk_size0;
int64_t dr1 = chunk_size1;
#if defined(TMAC_RECHUNK)
// Rechunk
if ((nchunk1 == 1) && (nchunk0 > nth * 4)) {
// dr0 should be divisible by chunk_size0
dr0 = (ne0 / (nth * 4) / chunk_size0) * chunk_size0;
nchunk0 = (nr0 + dr0 - 1) / dr0;
}
#endif

int current_chunk = ith;

while (current_chunk < nchunk0 * nchunk1) {
const int64_t ith0 = current_chunk % nchunk0;
const int64_t ith1 = current_chunk / nchunk0;

const int64_t ir0_start = dr0 * ith0;
const int64_t ir0_end = MIN(ir0_start + dr0, nr0);

const int64_t ir1_start = dr1 * ith1;
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);

// inline ggml_compute_forward_mul_mat_one_chunk here for simplicity
for (int64_t ichunk0 = ir0_start / chunk_size0; ichunk0 < ir0_end / chunk_size0; ichunk0++) {
const int64_t w_offset = ichunk0 * w_chunk_size;
const int64_t scales_offset = ichunk0 * wt->scales_size / n_tile_num;

for (int64_t ine11 = ir1_start; ine11 < ir1_end; ine11++) {
const int64_t qlut_offset = ne10 * ine11 * 4;
const int64_t lut_scales_offset = wt->lut_scales_size * ine11;
const int64_t dst_offset = ne0 * ine11 + ichunk0 * chunk_size0;

ggml_tmac_mul_mat_task_compute(wt->qweights + w_offset,
wt->scales + scales_offset,
qlut + qlut_offset,
lut_scales + lut_scales_offset,
lut_biases + lut_scales_offset,
act_output + dst_offset,
chunk_size0, ne00, 1, bits);
if (sizeof(tmac_float_type) == 2) {
ggml_fp16_to_fp32_row(act_output + dst_offset, (float *) dst->data + dst_offset, chunk_size0);
}
}
}

if (nth >= nchunk0 * nchunk1) {
break;
}

current_chunk = atomic_fetch_add_explicit(&params->threadpool->current_chunk, 1, memory_order_relaxed);
}
#endif // #if defined(TMAC_USE_TVM_THREADPOOL)
return;
} // if (ggml_tmac_can_mul_mat(src0, src1, dst))
#endif // #if defined(GGML_USE_TMAC)

if (src1->type != vec_dot_type) {
char * wdata = params->wdata;

Expand Down Expand Up @@ -9123,6 +9297,10 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_I32:
case GGML_TYPE_I64:
case GGML_TYPE_F64:
case GGML_TYPE_I1:
case GGML_TYPE_I2:
case GGML_TYPE_I3:
case GGML_TYPE_I4:
case GGML_TYPE_COUNT:
{
GGML_ABORT("fatal error");
Expand Down Expand Up @@ -13172,6 +13350,11 @@ struct ggml_cplan ggml_graph_plan(
{
const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;

#if defined(GGML_USE_TMAC)
if (ggml_tmac_can_mul_mat(node->src[0], node->src[1], node)) {
cur = ggml_tmac_mul_mat_get_wsize(node->src[0], node->src[1], node);
} else
#endif
if (node->src[1]->type != vec_dot_type) {
cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
}
Expand Down
6 changes: 6 additions & 0 deletions ggml/src/ggml-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -15741,6 +15741,12 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_I64:
// nothing to validate
break;
case GGML_TYPE_I1:
case GGML_TYPE_I2:
case GGML_TYPE_I3:
case GGML_TYPE_I4:
// nothing to validate
break;
default:
{
fprintf(stderr, "%s: invalid type %d\n", __func__, type);
Expand Down
Loading