From a562b76eba3d2da6fe29d1d6f74ed17ab0a05a46 Mon Sep 17 00:00:00 2001 From: Andrey Panferov Date: Sun, 4 Aug 2024 17:05:02 +0300 Subject: [PATCH 01/15] aqlm support --- CMakeLists.txt | 5 + examples/models/llama2/CMakeLists.txt | 4 + examples/models/llama2/aqlm/CmakeLists.txt | 103 ++++++++++++++++ examples/models/llama2/aqlm/__init__.py | 0 examples/models/llama2/aqlm/linear.py | 63 ++++++++++ examples/models/llama2/aqlm/lut_kernel.cpp | 114 ++++++++++++++++++ examples/models/llama2/aqlm/lut_kernel.h | 21 ++++ examples/models/llama2/aqlm/lut_kernel.py | 26 ++++ .../models/llama2/aqlm/lut_kernel_pytorch.cpp | 82 +++++++++++++ examples/models/llama2/aqlm/utils.py | 98 +++++++++++++++ examples/models/llama2/export_llama_lib.py | 2 +- .../llama2/source_transformation/quantize.py | 26 ++++ extension/module/module.cpp | 5 +- pyproject.toml | 6 +- setup.py | 9 +- 15 files changed, 558 insertions(+), 6 deletions(-) create mode 100644 examples/models/llama2/aqlm/CmakeLists.txt create mode 100644 examples/models/llama2/aqlm/__init__.py create mode 100644 examples/models/llama2/aqlm/linear.py create mode 100644 examples/models/llama2/aqlm/lut_kernel.cpp create mode 100644 examples/models/llama2/aqlm/lut_kernel.h create mode 100644 examples/models/llama2/aqlm/lut_kernel.py create mode 100644 examples/models/llama2/aqlm/lut_kernel_pytorch.cpp create mode 100644 examples/models/llama2/aqlm/utils.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 60992909447..a4814f98ba5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -467,6 +467,9 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM) add_subdirectory( ${CMAKE_CURRENT_SOURCE_DIR}/examples/models/llama2/custom_ops ) + add_subdirectory( + ${CMAKE_CURRENT_SOURCE_DIR}/examples/models/llama2/aqlm + ) endif() if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED) @@ -633,6 +636,7 @@ if(EXECUTORCH_BUILD_PYBIND) # TODO(larryliu): Fix macOS 2 dylibs having 2 sets of static variables issue if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT AND NOT APPLE) list(APPEND _dep_libs custom_ops_aot_lib) + list(APPEND _dep_libs aqlm) endif() # TODO(laryliu): Fix linux duplicate registation problem. In GH CI worker # libcustom_ops.a doesn't dedup with the one indirectly linked from @@ -640,6 +644,7 @@ if(EXECUTORCH_BUILD_PYBIND) if(EXECUTORCH_BUILD_KERNELS_CUSTOM AND APPLE) target_link_options_shared_lib(custom_ops) list(APPEND _dep_libs custom_ops) + list(APPEND _dep_libs aqlm) endif() # compile options for pybind set(_pybind_compile_options diff --git a/examples/models/llama2/CMakeLists.txt b/examples/models/llama2/CMakeLists.txt index 5044a5ce9bd..44eeed957df 100644 --- a/examples/models/llama2/CMakeLists.txt +++ b/examples/models/llama2/CMakeLists.txt @@ -87,6 +87,7 @@ endif() # custom ops library if(EXECUTORCH_BUILD_KERNELS_CUSTOM) add_subdirectory(custom_ops) + add_subdirectory(aqlm) endif() # llama_runner library @@ -129,6 +130,9 @@ list(APPEND link_libraries quantized_kernels quantized_ops_lib) if(EXECUTORCH_BUILD_KERNELS_CUSTOM) target_link_options_shared_lib(custom_ops) list(APPEND link_libraries custom_ops) + + target_link_options_shared_lib(aqlm) + list(APPEND link_libraries aqlm) endif() set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack) diff --git a/examples/models/llama2/aqlm/CmakeLists.txt b/examples/models/llama2/aqlm/CmakeLists.txt new file mode 100644 index 00000000000..1af90c0b757 --- /dev/null +++ b/examples/models/llama2/aqlm/CmakeLists.txt @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17) +endif() + +if(NOT PYTHON_EXECUTABLE) + set(PYTHON_EXECUTABLE python3) +endif() + +# Source root directory for executorch. +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) +endif() + +set(_common_compile_options -Wno-deprecated-declarations -fPIC) + +include(${EXECUTORCH_ROOT}/build/Utils.cmake) +include(${EXECUTORCH_ROOT}/build/Codegen.cmake) + +# +# The `__srcs` lists are defined by including ${EXECUTORCH_SRCS_FILE}. +# +set(EXECUTORCH_SRCS_FILE + "${CMAKE_CURRENT_BINARY_DIR}/../../../../executorch_srcs.cmake" +) + +extract_sources(${EXECUTORCH_SRCS_FILE}) + +include(${EXECUTORCH_SRCS_FILE}) + +# Let files say "include ". +set(_common_include_directories ${EXECUTORCH_ROOT}/..) + +# Custom op libraries +set(aqlm_libs executorch_no_prim_ops) +list(APPEND aqlm_libs pthreadpool) +list(APPEND aqlm_libs cpuinfo) +list(APPEND aqlm_libs cpublas) +list(APPEND aqlm_libs eigen_blas) + +set(_aqlm__srcs examples/models/llama2/aqlm/lut_kernel.h examples/models/llama2/aqlm/lut_kernel.cpp) +list(TRANSFORM _aqlm__srcs PREPEND "${EXECUTORCH_ROOT}/") + + + +message("HERE: AQLM SOURCES: ${_aqlm__srcs}") + +# TODO: Consider moving xnnpack/threadpool in a separate lib since it's now used +# by custom ops too. +if(NOT EXECUTORCH_BUILD_XNNPACK) + list( + APPEND + _aqlm__srcs + "${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool_guard.cpp" + ) +else() + list(APPEND aqlm_libs xnnpack_backend) +endif() + +add_library(aqlm ${_aqlm__srcs}) + +target_include_directories(aqlm PUBLIC "${_common_include_directories}") +target_include_directories( + aqlm PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../../../include" +) +target_link_libraries(aqlm PUBLIC ${aqlm_libs}) + +target_compile_options( + aqlm PUBLIC ${_common_compile_options} -DET_USE_THREADPOOL +) + +install(TARGETS aqlm DESTINATION lib) + +if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT) + # Add a AOT library + find_package(Torch CONFIG REQUIRED) + add_library( + aqlm_aot_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/lut_kernel_pytorch.cpp + ) + target_include_directories( + aqlm_aot_lib PUBLIC "${_common_include_directories}" + ) + target_include_directories( + aqlm_aot_lib + PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../../../include" + ) + target_link_libraries(aqlm_aot_lib PUBLIC aqlm torch) + target_compile_options( + aqlm_aot_lib PUBLIC -Wno-deprecated-declarations -fPIC -frtti + -fexceptions + ) + + install(TARGETS aqlm_aot_lib DESTINATION lib) +endif() diff --git a/examples/models/llama2/aqlm/__init__.py b/examples/models/llama2/aqlm/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/models/llama2/aqlm/linear.py b/examples/models/llama2/aqlm/linear.py new file mode 100644 index 00000000000..25c45f005be --- /dev/null +++ b/examples/models/llama2/aqlm/linear.py @@ -0,0 +1,63 @@ +""" AQLM 2x8 Linear""" +import torch +import torch.nn as nn + +class Aqlm2x8Linear(nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + in_group_size: int, + out_group_size: int, + num_codebooks: int, + nbits_per_codebook: int, + bias=True, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": 
dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + + assert self.in_features % in_group_size == 0 + assert self.out_features % out_group_size == 0 + num_out_groups = out_features // out_group_size + num_in_groups = in_features // in_group_size + self.out_group_size, self.in_group_size = out_group_size, in_group_size + self.num_codebooks = num_codebooks + self.nbits_per_codebook = nbits_per_codebook + self.codebook_size = 2**nbits_per_codebook + + # CODES & CODEBOOKS + self.codebooks = nn.Parameter( + torch.empty((num_codebooks, self.codebook_size, out_group_size, in_group_size), **factory_kwargs) * 2 - 1, + requires_grad=False, + ) # [num_codebooks, codebook_size, out_group_size, in_group_size] + self.codes = nn.Parameter( + torch.empty( + (num_out_groups, num_in_groups, num_codebooks), + device=device, + dtype=torch.int8, + ), + requires_grad=False, + ) # [num_out_groups, num_in_groups, num_codebooks] + + # SCALES + self.scales = nn.Parameter( + torch.empty((num_out_groups, 1, 1, 1), **factory_kwargs), requires_grad=False + ) # [num_out_groups, 1, 1, 1] + + # BIAS + if bias: + self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs), requires_grad=False) + else: + self.register_parameter("bias", None) + + def transpose_codes(self): + self.codes.data = torch.permute(self.codes.data, (1, 0, 2)).contiguous() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.ops.aqlm.code2x8_lut_matmat( + input, self.codes, self.codebooks, self.scales, bias=self.bias + ) diff --git a/examples/models/llama2/aqlm/lut_kernel.cpp b/examples/models/llama2/aqlm/lut_kernel.cpp new file mode 100644 index 00000000000..8429fa60ce8 --- /dev/null +++ b/examples/models/llama2/aqlm/lut_kernel.cpp @@ -0,0 +1,114 @@ +#include "lut_kernel.h" + +#include +#include + +#include +#include +#include + +#include + +constexpr int GROUP_SIZE = 1024; + + +template +void quadruple_for( + int num_inputs, + int num_input_groups, const fp_dtype* lut, + int out_features, const uint8_t* b_alt, + fp_dtype* output_vec +) +{ + for (int input = 0; input < num_inputs; ++input) { + for (int i = 0; i < out_features; ++i) { + output_vec[input * out_features + i] = 0; + } + } + + for (int input = 0; input < num_inputs; ++input) { + for (int j = 0; j < num_input_groups; ++j) { + for (int i = 0; i < out_features; ++i) { + for (int c = 0; c < 2; ++c) { + output_vec[input * out_features + i] += lut[ + input * num_input_groups * 2 * 256 + + j * 2 * 256 + + c * codebook_size + + b_alt[ + j * 2 * out_features + + i * 2 + + c + ] + ]; + } + } + } + } +} + +namespace torch { + namespace executor { + namespace native { + Tensor& code2x8_lut_matmat_out( + RuntimeContext& ctx, + const Tensor& input, + const Tensor& codes, + const Tensor& codebooks, + const Tensor& scales, + const optional& bias, + Tensor& out + ) { + auto input_sizes = input.sizes(); + auto out_features = codes.size(1) * codebooks.size(2); + auto input_vector_size = input.size(input.dim() - 1); + auto num_input_vectors = std::accumulate(input_sizes.begin(), input_sizes.end(), 1, std::multiplies()) / input_vector_size; + + // Allocate LUT + auto lut_data = ctx.allocate_temp( + 4 * num_input_vectors * input_vector_size / 8 * codebooks.size(0) * codebooks.size(1) + ).get(); + + // A @ B.T + ::executorch::cpublas::gemm( + ::executorch::cpublas::TransposeType::Transpose, + ::executorch::cpublas::TransposeType::NoTranspose, + (int64_t)codebooks.size(0) * codebooks.size(1), // B rows + (int64_t)num_input_vectors * 
input_vector_size / 8, // A rows + (int64_t)8, // MatMul dim size + 1.f, + (float*)codebooks.const_data_ptr(), (int64_t)8, + (float*)input.const_data_ptr(), (int64_t)8, + 0.f, + (float*)lut_data, (int64_t)codebooks.size(0) * codebooks.size(1) + ); + + // Do lookup matmul + quadruple_for( + num_input_vectors, + input_vector_size / 8, + (const float*)lut_data, + out_features, + (const uint8_t*)codes.const_data_ptr(), + (float*)out.mutable_data_ptr() + ); + + for (int j = 0; j < out_features; ++j) { + for (int i=0; i < num_input_vectors; ++i) { + out.mutable_data_ptr()[ + i * out_features + j + ] *= scales.const_data_ptr()[j]; + if (bias.has_value()) { + out.mutable_data_ptr()[ + i * out_features + j + ] += bias.value().const_data_ptr()[j]; + } + } + } + + return out; + } + } // namespace native + } // namespace executor +} // namespace torch + +EXECUTORCH_LIBRARY(aqlm, "code2x8_lut_matmat.out", torch::executor::native::code2x8_lut_matmat_out); diff --git a/examples/models/llama2/aqlm/lut_kernel.h b/examples/models/llama2/aqlm/lut_kernel.h new file mode 100644 index 00000000000..3653c10e1e8 --- /dev/null +++ b/examples/models/llama2/aqlm/lut_kernel.h @@ -0,0 +1,21 @@ +#pragma once + +#include + +namespace torch { +namespace executor { + +namespace native { + +Tensor& code2x8_lut_matmat_out( + RuntimeContext& ctx, + const Tensor& input, + const Tensor& codes, + const Tensor& codebooks, + const Tensor& scales, + const optional& bias, + Tensor& out +); +} // namespace native +} // namespace executor +} // namespace torch diff --git a/examples/models/llama2/aqlm/lut_kernel.py b/examples/models/llama2/aqlm/lut_kernel.py new file mode 100644 index 00000000000..1e7d2eb8748 --- /dev/null +++ b/examples/models/llama2/aqlm/lut_kernel.py @@ -0,0 +1,26 @@ +import os +from typing import Optional +import logging +from pathlib import Path + +import torch +from torch.library import impl + +try: + op = torch.ops.aqlm.code2x8_lut_matmat.default + assert op is not None +except: + libs = list(Path(__file__).parent.resolve().glob("libaqlm_aot_lib.*")) + assert len(libs) == 1, f"Expected 1 library but got {len(libs)}" + logging.warn(f"Loading aqlm library: {libs[0]}") + torch.ops.load_library(libs[0]) + op = torch.ops.aqlm.code2x8_lut_matmat.default + assert op is not None + +aqlm_lib = torch.library.Library("aqlm", "IMPL") + +@impl(aqlm_lib, "code2x8_lut_matmat", "Meta") +def code2x8_lut_matmat_meta(input, codes, codebooks, scales, bias=None): + return torch.empty( + input.shape[:-1] + (codes.shape[1],), device=input.device, dtype=input.dtype + ) diff --git a/examples/models/llama2/aqlm/lut_kernel_pytorch.cpp b/examples/models/llama2/aqlm/lut_kernel_pytorch.cpp new file mode 100644 index 00000000000..b384113ecf7 --- /dev/null +++ b/examples/models/llama2/aqlm/lut_kernel_pytorch.cpp @@ -0,0 +1,82 @@ +#include "lut_kernel.h" + +#include +#include + +#include + +namespace torch { + namespace executor { + namespace native { + Tensor& code2x8_lut_matmat_out_no_context( + const Tensor& input, + const Tensor& codes, + const Tensor& codebooks, + const Tensor& scales, + const optional bias, + Tensor& output + ) { + void* memory_pool = malloc(10000000 * sizeof(uint8_t)); + MemoryAllocator allocator(10000000, (uint8_t*)memory_pool); + + exec_aten::RuntimeContext context{nullptr, &allocator}; + return torch::executor::native::code2x8_lut_matmat_out( + context, + input, + codes, + codebooks, + scales, + bias, + output + ); + } + + at::Tensor code2x8_lut_matmat( + const at::Tensor& input, + const at::Tensor& codes, + const 
at::Tensor& codebooks, + const at::Tensor& scales, + const c10::optional bias + ) { + auto sizes = input.sizes().vec(); + sizes[sizes.size() - 1] = codes.size(1) * codebooks.size(2); + auto out = at::empty(sizes, + at::TensorOptions() + .dtype(input.dtype()) + .device(input.device()) + ); + + WRAP_TO_ATEN(code2x8_lut_matmat_out_no_context, 5)( + input, + codes, + codebooks, + scales, + bias, + out + ); + return out; + } + } // namespace native + } // namespace executor +} // namespace torch + +TORCH_LIBRARY(aqlm, m) { + m.def( + "code2x8_lut_matmat(Tensor input, Tensor codes, " + "Tensor codebooks, Tensor scales, *, Tensor? bias=None) -> Tensor" + ); + m.def( + "code2x8_lut_matmat.out(Tensor input, Tensor codes, " + "Tensor codebooks, Tensor scales, *, Tensor? bias=None, Tensor(c!) out) -> Tensor(c!)" + ); +} + +TORCH_LIBRARY_IMPL(aqlm, CompositeExplicitAutograd, m) { + m.impl( + "code2x8_lut_matmat", torch::executor::native::code2x8_lut_matmat + ); + m.impl( + "code2x8_lut_matmat.out", + WRAP_TO_ATEN(torch::executor::native::code2x8_lut_matmat_out_no_context, 5) + ); +} \ No newline at end of file diff --git a/examples/models/llama2/aqlm/utils.py b/examples/models/llama2/aqlm/utils.py new file mode 100644 index 00000000000..9c81e20226c --- /dev/null +++ b/examples/models/llama2/aqlm/utils.py @@ -0,0 +1,98 @@ +import torch +from torch import nn + +from accelerate import init_empty_weights +from executorch.examples.models.llama2.aqlm.linear import Aqlm2x8Linear + + +def replace_with_aqlm_linear( + model, + quantization_config=None, + linear_weights_not_to_quantize=None, + current_key_name=None, + has_been_replaced=False, +): + """ + Public method that recursively replaces the Linear layers of the given model with AQLM quantized layers. + `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the + conversion has been successfull or not. + + Args: + model (`torch.nn.Module`): + The model to convert, can be any `torch.nn.Module` instance. + quantization_config (`AqlmConfig`): + The quantization config object that contains the quantization parameters. + linear_weights_not_to_quantize (`list[str]`, *optional*): + A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be + converted. + current_key_name (`list`, *optional*): + A list that contains the current key name. This is used for recursion and should not be passed by the user. + has_been_replaced (`bool`, *optional*): + A boolean that indicates if the conversion has been successful or not. This is used for recursion and + should not be passed by the user. 
+ """ + if linear_weights_not_to_quantize is None: + linear_weights_not_to_quantize = [] + + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, nn.Linear): + # Check if the current key is not in the `linear_weights_not_to_quantize` + if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize: + # with init_empty_weights(): + in_features = module.in_features + out_features = module.out_features + + model._modules[name] = Aqlm2x8Linear( + in_features, + out_features, + bias=module.bias is not None, + in_group_size=quantization_config.in_group_size, + out_group_size=quantization_config.out_group_size, + num_codebooks=quantization_config.num_codebooks, + nbits_per_codebook=quantization_config.nbits_per_codebook, + ) + has_been_replaced = True + + # Store the module class in case we need to transpose the weight later + model._modules[name].source_cls = type(module) + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + if len(list(module.children())) > 0: + _, has_been_replaced = replace_with_aqlm_linear( + module, + quantization_config=quantization_config, + linear_weights_not_to_quantize=linear_weights_not_to_quantize, + current_key_name=current_key_name, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def transpose_codes( + model, + current_key_name=None, + has_been_transposed=False, +): + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, Aqlm2x8Linear): + model._modules[name].transpose_codes() + has_been_transposed = True + if len(list(module.children())) > 0: + _, has_been_transposed = transpose_codes( + module, + current_key_name=current_key_name, + has_been_transposed=has_been_transposed, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_transposed diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index c1fae0eb77b..78a0c891700 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -119,7 +119,7 @@ def build_args_parser() -> argparse.ArgumentParser: "--quantization_mode", type=str, default=None, - choices=["int8", "8da4w", "8da4w-gptq"], + choices=["int8", "8da4w", "8da4w-gptq", "aqlm-2x8"], help="type of quantization", ) diff --git a/examples/models/llama2/source_transformation/quantize.py b/examples/models/llama2/source_transformation/quantize.py index 4830e18c94c..943ea4ae008 100644 --- a/examples/models/llama2/source_transformation/quantize.py +++ b/examples/models/llama2/source_transformation/quantize.py @@ -130,6 +130,32 @@ def quantize( group_size, ) model = gptq_quantizer.quantize(model, inputs) + return model + elif qmode == "aqlm-2x8": + from executorch.examples.models.llama2.aqlm.lut_kernel import aqlm_lib # noqa + from executorch.examples.models.llama2.aqlm.utils import replace_with_aqlm_linear, transpose_codes + from transformers.utils.quantization_config import AqlmConfig + + model, _ = replace_with_aqlm_linear( + model=model, + quantization_config=AqlmConfig( + num_codebooks=2, + nbits_per_codebook=8, + ), + linear_weights_not_to_quantize=[ + "tok_embeddings.weight", + "output.weight", + ], + ) + model.load_state_dict( + 
torch.load("/Users/blacksamorez/models/Llama-2-7b-AQLM-2Bit-2x8-hf/executorch.pth"), + strict=False, + # assign=True, + ) + model, _ = transpose_codes( + model=model, + ) + return model else: raise Exception(f"Unrecognized quantize mode: {qmode}") diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 2c9733d1dae..71e2a1e0975 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -95,6 +95,9 @@ Result> Module::method_names() { return result; } +static uint8_t temp_allocator_pool[4 * 1024U * 1024U]; // 4 MB +static MemoryAllocator temp_allocator(sizeof(temp_allocator_pool), temp_allocator_pool); + Error Module::load_method(const std::string& method_name) { if (!is_method_loaded(method_name)) { ET_CHECK_OK_OR_RETURN_ERROR(load()); @@ -118,7 +121,7 @@ Error Module::load_method(const std::string& method_name) { method_holder.planned_spans.data(), method_holder.planned_spans.size())); method_holder.memory_manager = std::make_unique( - memory_allocator_.get(), method_holder.planned_memory.get()); + memory_allocator_.get(), method_holder.planned_memory.get(), &temp_allocator); method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method( method_name.c_str(), method_holder.memory_manager.get(), diff --git a/pyproject.toml b/pyproject.toml index e5c2b33bf07..e9bba4dd76a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,9 +62,9 @@ dependencies=[ "ruamel.yaml", "sympy", "tabulate", - "torch==2.4.0", - "torchvision==0.19.0", - "torchaudio==2.4.0", + "torch>=2.4.0", + "torchvision>=0.19.0", + "torchaudio>=2.4.0", ] [project.urls] diff --git a/setup.py b/setup.py index c96bb102c5d..501c89d22d3 100644 --- a/setup.py +++ b/setup.py @@ -489,7 +489,7 @@ def run(self): "-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON", # add llama sdpa ops to pybindings. "-DEXECUTORCH_BUILD_KERNELS_CUSTOM_AOT=ON", ] - build_args += ["--target", "custom_ops_aot_lib"] + build_args += ["--target", "custom_ops_aot_lib", "--target", "aqlm_aot_lib"] # Allow adding extra cmake args through the environment. Used by some # tests and demos to expand the set of targets included in the pip # package. @@ -569,6 +569,13 @@ def get_ext_modules() -> list[Extension]: "executorch/examples/models/llama2/custom_ops", ) ) + ext_modules.append( + # Install the prebuilt library for custom ops used in llama. + BuiltFile( + "examples/models/llama2/aqlm/libaqlm_aot_lib.*", + "executorch/examples/models/llama2/aqlm", + ) + ) # Note that setuptools uses the presence of ext_modules as the main signal # that a wheel is platform-specific. 
If we install any platform-specific From 7c3e191a6fdac8cfcff02eb590922898ba76f9e4 Mon Sep 17 00:00:00 2001 From: Andrey Panferov Date: Sun, 4 Aug 2024 17:51:23 +0300 Subject: [PATCH 02/15] faster lut --- examples/models/llama2/aqlm/CmakeLists.txt | 10 ++++++- examples/models/llama2/aqlm/lut_kernel.cpp | 33 ++++++++-------------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/examples/models/llama2/aqlm/CmakeLists.txt b/examples/models/llama2/aqlm/CmakeLists.txt index 1af90c0b757..3645a70860b 100644 --- a/examples/models/llama2/aqlm/CmakeLists.txt +++ b/examples/models/llama2/aqlm/CmakeLists.txt @@ -66,8 +66,16 @@ else() list(APPEND aqlm_libs xnnpack_backend) endif() +find_package(OpenMP REQUIRED) +# list(APPEND aqlm_libs OpenMP::OpenMP_CXX) +list(APPEND aqlm_libs omp) + add_library(aqlm ${_aqlm__srcs}) +# Enable optimization +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fopenmp") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fopenmp") + target_include_directories(aqlm PUBLIC "${_common_include_directories}") target_include_directories( aqlm PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../../../include" @@ -75,7 +83,7 @@ target_include_directories( target_link_libraries(aqlm PUBLIC ${aqlm_libs}) target_compile_options( - aqlm PUBLIC ${_common_compile_options} -DET_USE_THREADPOOL + aqlm PUBLIC ${_common_compile_options} -DET_USE_THREADPOOL ${OpenMP_CXX_FLAGS} ) install(TARGETS aqlm DESTINATION lib) diff --git a/examples/models/llama2/aqlm/lut_kernel.cpp b/examples/models/llama2/aqlm/lut_kernel.cpp index 8429fa60ce8..a26230e624a 100644 --- a/examples/models/llama2/aqlm/lut_kernel.cpp +++ b/examples/models/llama2/aqlm/lut_kernel.cpp @@ -9,38 +9,29 @@ #include -constexpr int GROUP_SIZE = 1024; template void quadruple_for( int num_inputs, - int num_input_groups, const fp_dtype* lut, - int out_features, const uint8_t* b_alt, - fp_dtype* output_vec + int num_input_groups, const fp_dtype* __restrict__ lut, + int out_features, const uint8_t* __restrict__ b_alt, + fp_dtype* __restrict__ output_vec ) { - for (int input = 0; input < num_inputs; ++input) { - for (int i = 0; i < out_features; ++i) { - output_vec[input * out_features + i] = 0; - } - } + std::memset(output_vec, 0, num_inputs * out_features * sizeof(fp_dtype)); + + const int lut_stride = num_input_groups * 2 * 256; + const int b_alt_stride = 2 * out_features; for (int input = 0; input < num_inputs; ++input) { for (int j = 0; j < num_input_groups; ++j) { + const fp_dtype* lut_ptr = lut + input * lut_stride + j * 2 * 256; + const uint8_t* b_alt_ptr = b_alt + j * b_alt_stride; + for (int i = 0; i < out_features; ++i) { - for (int c = 0; c < 2; ++c) { - output_vec[input * out_features + i] += lut[ - input * num_input_groups * 2 * 256 + - j * 2 * 256 + - c * codebook_size + - b_alt[ - j * 2 * out_features + - i * 2 + - c - ] - ]; - } + output_vec[input * out_features + i] += lut_ptr[b_alt_ptr[i * 2]]; + output_vec[input * out_features + i] += lut_ptr[256 + b_alt_ptr[i * 2 + 1]]; } } } From db94abad36dd09f3f531b1080760bc0d40f35224 Mon Sep 17 00:00:00 2001 From: Andrey Panferov Date: Mon, 5 Aug 2024 19:51:03 +0300 Subject: [PATCH 03/15] Separate scales fn --- examples/models/llama2/aqlm/CmakeLists.txt | 10 ++--- examples/models/llama2/aqlm/lut_kernel.cpp | 46 ++++++++++++++++------ 2 files changed, 38 insertions(+), 18 deletions(-) diff --git a/examples/models/llama2/aqlm/CmakeLists.txt b/examples/models/llama2/aqlm/CmakeLists.txt index 3645a70860b..2da4e01aa5a 100644 --- a/examples/models/llama2/aqlm/CmakeLists.txt +++ 
b/examples/models/llama2/aqlm/CmakeLists.txt @@ -66,15 +66,15 @@ else() list(APPEND aqlm_libs xnnpack_backend) endif() -find_package(OpenMP REQUIRED) +# find_package(OpenMP REQUIRED) # list(APPEND aqlm_libs OpenMP::OpenMP_CXX) -list(APPEND aqlm_libs omp) +# list(APPEND aqlm_libs omp) add_library(aqlm ${_aqlm__srcs}) # Enable optimization -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fopenmp") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fopenmp") +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") # -O3 -fopenmp +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") # -O3 -fopenmp target_include_directories(aqlm PUBLIC "${_common_include_directories}") target_include_directories( @@ -83,7 +83,7 @@ target_include_directories( target_link_libraries(aqlm PUBLIC ${aqlm_libs}) target_compile_options( - aqlm PUBLIC ${_common_compile_options} -DET_USE_THREADPOOL ${OpenMP_CXX_FLAGS} + aqlm PUBLIC ${_common_compile_options} -DET_USE_THREADPOOL # ${OpenMP_CXX_FLAGS} ) install(TARGETS aqlm DESTINATION lib) diff --git a/examples/models/llama2/aqlm/lut_kernel.cpp b/examples/models/llama2/aqlm/lut_kernel.cpp index a26230e624a..6fd747aae43 100644 --- a/examples/models/llama2/aqlm/lut_kernel.cpp +++ b/examples/models/llama2/aqlm/lut_kernel.cpp @@ -10,7 +10,6 @@ #include - template void quadruple_for( int num_inputs, @@ -37,6 +36,27 @@ void quadruple_for( } } + +void row_wise_scaling_and_bias( + float* __restrict__ out, + const float* __restrict__ scales, const float* __restrict__ bias, + int num_input_vectors, int out_features +) { + for (int j = 0; j < out_features; ++j) { + float scale_value = scales[j]; + float bias_value; + if (bias != nullptr){ + bias_value = bias[j]; + } + for (int i=0; i < num_input_vectors; ++i) { + out[i * out_features + j] *= scale_value; + if (bias != nullptr) { + out[i * out_features + j] += bias_value; + } + } + } +} + namespace torch { namespace executor { namespace native { @@ -82,19 +102,19 @@ namespace torch { (const uint8_t*)codes.const_data_ptr(), (float*)out.mutable_data_ptr() ); - - for (int j = 0; j < out_features; ++j) { - for (int i=0; i < num_input_vectors; ++i) { - out.mutable_data_ptr()[ - i * out_features + j - ] *= scales.const_data_ptr()[j]; - if (bias.has_value()) { - out.mutable_data_ptr()[ - i * out_features + j - ] += bias.value().const_data_ptr()[j]; - } - } + + const float* bias_ptr = nullptr; + if (bias.has_value()) { + bias_ptr = bias.value().const_data_ptr(); } + + row_wise_scaling_and_bias( + out.mutable_data_ptr(), + scales.const_data_ptr(), + bias_ptr, + num_input_vectors, + out_features + ); return out; } From 97c86e3ca4c6ee2cc22490406fbfaec754b947de Mon Sep 17 00:00:00 2001 From: Andrey Panferov Date: Wed, 7 Aug 2024 13:53:26 +0300 Subject: [PATCH 04/15] link aqlm to android lib --- examples/models/llama2/source_transformation/quantize.py | 2 +- extension/android/CMakeLists.txt | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/models/llama2/source_transformation/quantize.py b/examples/models/llama2/source_transformation/quantize.py index 943ea4ae008..fc410c914a1 100644 --- a/examples/models/llama2/source_transformation/quantize.py +++ b/examples/models/llama2/source_transformation/quantize.py @@ -148,7 +148,7 @@ def quantize( ], ) model.load_state_dict( - torch.load("/Users/blacksamorez/models/Llama-2-7b-AQLM-2Bit-2x8-hf/executorch.pth"), + torch.load("/Users/blacksamorez/models/Llama-2-7b-AQLM-PV-2Bit-2x8-hf/executorch.pth"), strict=False, # assign=True, ) diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt 
index b8ae82bbe2b..fcc2531b19f 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -86,6 +86,13 @@ if(EXECUTORCH_BUILD_LLAMA_JNI) set_property(TARGET custom_ops PROPERTY IMPORTED_LOCATION ${CUSTOM_OPS_PATH}) target_link_options_shared_lib(custom_ops) + set(AQLM_PATH + ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama2/aqlm/libaqlm.a + ) + add_library(aqlm STATIC IMPORTED) + set_property(TARGET aqlm PROPERTY IMPORTED_LOCATION ${AQLM_PATH}) + target_link_options_shared_lib(aqlm) + target_link_options_shared_lib(quantized_ops_lib) if(TARGET pthreadpool) @@ -117,6 +124,7 @@ if(EXECUTORCH_BUILD_LLAMA_JNI) ${link_libraries} llama_runner custom_ops + aqlm cpublas eigen_blas quantized_kernels From 008f9c33a4efeb2e3e0c32c47737fa3dec187078 Mon Sep 17 00:00:00 2001 From: Andrey Panferov Date: Wed, 7 Aug 2024 13:58:41 +0300 Subject: [PATCH 05/15] row_wise_scaling_and_bias for neon --- examples/models/llama2/aqlm/lut_kernel.cpp | 26 +++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/examples/models/llama2/aqlm/lut_kernel.cpp b/examples/models/llama2/aqlm/lut_kernel.cpp index 6fd747aae43..f928c13c1cf 100644 --- a/examples/models/llama2/aqlm/lut_kernel.cpp +++ b/examples/models/llama2/aqlm/lut_kernel.cpp @@ -36,7 +36,30 @@ void quadruple_for( } } - +#if defined(__aarch64__) && defined(__ARM_NEON) +#include +void row_wise_scaling_and_bias( + float* __restrict__ out, + const float* __restrict__ scales, const float* __restrict__ bias, + int num_input_vectors, int out_features +) { + for (int j = 0; j < out_features; j += 4) { + float32x4_t scales_vec = vld1q_f32(scales + j); + float32x4_t bias_vec; + if (bias != nullptr){ + bias_vec = vld1q_f32(bias + j); + } + for (int i=0; i < num_input_vectors; ++i) { + float32x4_t values_vec = vld1q_f32(out + i * out_features + j); + values_vec = vmulq_f32(values_vec, scales_vec); + if (bias != nullptr) { + values_vec = vaddq_f32(values_vec, bias_vec); + } + vst1q_f32(out + i * out_features + j, values_vec); + } + } +} +#else void row_wise_scaling_and_bias( float* __restrict__ out, const float* __restrict__ scales, const float* __restrict__ bias, @@ -56,6 +79,7 @@ void row_wise_scaling_and_bias( } } } +#endif namespace torch { namespace executor { From 0645fee5f9893671e7460091f17650848e9341ee Mon Sep 17 00:00:00 2001 From: Andrey Panferov Date: Wed, 7 Aug 2024 13:58:54 +0300 Subject: [PATCH 06/15] gradle update --- examples/demo-apps/android/LlamaDemo/build.gradle.kts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/demo-apps/android/LlamaDemo/build.gradle.kts b/examples/demo-apps/android/LlamaDemo/build.gradle.kts index 568efa2815b..bf4572c2ab8 100644 --- a/examples/demo-apps/android/LlamaDemo/build.gradle.kts +++ b/examples/demo-apps/android/LlamaDemo/build.gradle.kts @@ -8,6 +8,6 @@ // Top-level build file where you can add configuration options common to all sub-projects/modules. 
plugins { - id("com.android.application") version "8.1.0" apply false + id("com.android.application") version "8.1.4" apply false id("org.jetbrains.kotlin.android") version "1.8.10" apply false } From 43896016652b49291f808af584ed897b6f7172e2 Mon Sep 17 00:00:00 2001 From: Andrey Panferov Date: Wed, 7 Aug 2024 14:08:37 +0300 Subject: [PATCH 07/15] revert gradle --- examples/demo-apps/android/LlamaDemo/build.gradle.kts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/demo-apps/android/LlamaDemo/build.gradle.kts b/examples/demo-apps/android/LlamaDemo/build.gradle.kts index bf4572c2ab8..568efa2815b 100644 --- a/examples/demo-apps/android/LlamaDemo/build.gradle.kts +++ b/examples/demo-apps/android/LlamaDemo/build.gradle.kts @@ -8,6 +8,6 @@ // Top-level build file where you can add configuration options common to all sub-projects/modules. plugins { - id("com.android.application") version "8.1.4" apply false + id("com.android.application") version "8.1.0" apply false id("org.jetbrains.kotlin.android") version "1.8.10" apply false } From cade4c385d37e77f7f1e5a78514e050e297f5e7a Mon Sep 17 00:00:00 2001 From: Andrey Panferov Date: Wed, 7 Aug 2024 14:18:26 +0300 Subject: [PATCH 08/15] libs linking fix --- CMakeLists.txt | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a4814f98ba5..f1031e74773 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -636,7 +636,7 @@ if(EXECUTORCH_BUILD_PYBIND) # TODO(larryliu): Fix macOS 2 dylibs having 2 sets of static variables issue if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT AND NOT APPLE) list(APPEND _dep_libs custom_ops_aot_lib) - list(APPEND _dep_libs aqlm) + list(APPEND _dep_libs aqlm_aot_lib) endif() # TODO(laryliu): Fix linux duplicate registation problem. In GH CI worker # libcustom_ops.a doesn't dedup with the one indirectly linked from @@ -644,6 +644,7 @@ if(EXECUTORCH_BUILD_PYBIND) if(EXECUTORCH_BUILD_KERNELS_CUSTOM AND APPLE) target_link_options_shared_lib(custom_ops) list(APPEND _dep_libs custom_ops) + target_link_options_shared_lib(aqlm) list(APPEND _dep_libs aqlm) endif() # compile options for pybind @@ -704,7 +705,7 @@ if(EXECUTORCH_BUILD_PYBIND) PROPERTIES # Assume that this library will be installed in # `site-packages/executorch/extension/pybindings`, and that # the custom_ops_aot_lib should be found with relative path. 
- BUILD_RPATH "$ORIGIN:$ORIGIN/../../examples/models/llama2/custom_ops" + BUILD_RPATH "$ORIGIN:$ORIGIN/../../examples/models/llama2/custom_ops:$ORIGIN/../../examples/models/llama2/aqlm" ) endif() From 3ad5904a00564995ad3fd92b36fb7753f687be16 Mon Sep 17 00:00:00 2001 From: Andrey Panferov Date: Wed, 7 Aug 2024 14:42:28 +0300 Subject: [PATCH 09/15] removed unnecessary includes --- examples/models/llama2/aqlm/lut_kernel.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/models/llama2/aqlm/lut_kernel.cpp b/examples/models/llama2/aqlm/lut_kernel.cpp index f928c13c1cf..17af6ebb92c 100644 --- a/examples/models/llama2/aqlm/lut_kernel.cpp +++ b/examples/models/llama2/aqlm/lut_kernel.cpp @@ -4,9 +4,6 @@ #include #include -#include -#include - #include From 94d45447452ef409b53ed46e924fe6d292948d9d Mon Sep 17 00:00:00 2001 From: Andrey Panferov Date: Wed, 7 Aug 2024 15:02:47 +0300 Subject: [PATCH 10/15] more temp buffers --- extension/module/module.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 71e2a1e0975..e972e191fba 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -95,7 +95,7 @@ Result> Module::method_names() { return result; } -static uint8_t temp_allocator_pool[4 * 1024U * 1024U]; // 4 MB +static uint8_t temp_allocator_pool[16 * 1024U * 1024U]; // 16 MB static MemoryAllocator temp_allocator(sizeof(temp_allocator_pool), temp_allocator_pool); Error Module::load_method(const std::string& method_name) { From 3156fc2058860a319504cf60362c925eb0b8c4a9 Mon Sep 17 00:00:00 2001 From: Andrey Panferov Date: Wed, 7 Aug 2024 19:58:08 +0300 Subject: [PATCH 11/15] Cleaner export --- .../models/llama2/aqlm/convert_from_hf.ipynb | 79 +++++++++++++++++++ examples/models/llama2/export_llama_lib.py | 6 ++ .../llama2/source_transformation/quantize.py | 5 +- 3 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 examples/models/llama2/aqlm/convert_from_hf.ipynb diff --git a/examples/models/llama2/aqlm/convert_from_hf.ipynb b/examples/models/llama2/aqlm/convert_from_hf.ipynb new file mode 100644 index 00000000000..00707d637b6 --- /dev/null +++ b/examples/models/llama2/aqlm/convert_from_hf.ipynb @@ -0,0 +1,79 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from safetensors.torch import load_file\n", + "\n", + "LOAD_PATH = \"/Users/blacksamorez/models/Llama-2-7b-AQLM-PV-2Bit-2x8-hf/model.safetensors\"\n", + "SAVE_PATH = \"/Users/blacksamorez/models/Llama-2-7b-AQLM-PV-2Bit-2x8-hf/executorch.pth\"\n", + "\n", + "dict = load_file(LOAD_PATH)\n", + "\n", + "mapping = {\n", + " \"model.\": \"\",\n", + " \n", + " \"self_attn.q_proj\": \"attention.wq\",\n", + " \"self_attn.k_proj\": \"attention.wk\",\n", + " \"self_attn.v_proj\": \"attention.wv\",\n", + " \"self_attn.o_proj\": \"attention.wo\",\n", + " \n", + " \"mlp.up_proj\": \"feed_forward.w3\",\n", + " \"mlp.gate_proj\": \"feed_forward.w1\",\n", + " \"mlp.down_proj\": \"feed_forward.w2\",\n", + " \n", + " \"input_layernorm\": \"attention_norm\",\n", + " \"post_attention_layernorm\": \"ffn_norm\",\n", + " \n", + " \"lm_head\": \"output\",\n", + " \"embed_tokens\": \"tok_embeddings\",\n", + "}\n", + "\n", + "\n", + "new_dict = {}\n", + "\n", + "for key, value in dict.items():\n", + " for old, new in mapping.items():\n", + " key = key.replace(old, new)\n", + " \n", + " if \"attention.wq.codes\" in key or \"attention.wk.codes\" in key:\n", + " 
# [num_out_groups, num_in_groups, num_codebooks]\n", + " print(f\"Transposing codes {key} {value.shape=}\")\n", + " value = (value.reshape(32, 2, 128 // 2, -1, 2)\n", + " .transpose(1, 2)\n", + " .reshape(128 * 32, -1, 2))\n", + " \n", + " if \"attention.wq.scales\" in key or \"attention.wk.scales\" in key:\n", + " # [num_out_groups, 1, 1, 1]\n", + " print(f\"Transposing scales {key} {value.shape=}\")\n", + " value = (value.reshape(32, 2, 128 // 2, 1)\n", + " .transpose(1, 2)\n", + " .reshape(128 * 32, 1, 1, 1))\n", + " \n", + " new_dict[key] = value\n", + " \n", + "del new_dict[\"output.weight\"]\n", + "del new_dict[\"tok_embeddings.weight\"]\n", + "\n", + "torch.save(new_dict, SAVE_PATH)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "executorch", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 78a0c891700..724f7f68610 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -212,6 +212,12 @@ def build_args_parser() -> argparse.ArgumentParser: default=None, help="group_size for weight quantization", ) + parser.add_argument( + "--converted_aqlm_checkpoint_path", + type=str, + default=None, + help="Path to .pt file produced with aqlm/convert_from_hf.ipynb", + ) parser.add_argument( "-d", diff --git a/examples/models/llama2/source_transformation/quantize.py b/examples/models/llama2/source_transformation/quantize.py index fc410c914a1..3561d6e1de9 100644 --- a/examples/models/llama2/source_transformation/quantize.py +++ b/examples/models/llama2/source_transformation/quantize.py @@ -47,6 +47,8 @@ def quantize( blocksize: int = 128, tokenizer_path: Optional[Path] = None, verbose: bool = False, + # following arguments are only used for AQLM + converted_aqlm_checkpoint_path: Optional[Path] = None, ) -> torch.nn.Module: """ Quantizes a model by converting all weights to int8. 
@@ -148,7 +150,7 @@ def quantize( ], ) model.load_state_dict( - torch.load("/Users/blacksamorez/models/Llama-2-7b-AQLM-PV-2Bit-2x8-hf/executorch.pth"), + torch.load(converted_aqlm_checkpoint_path), strict=False, # assign=True, ) @@ -601,6 +603,7 @@ def get_quant_weight_transform(args, dtype_override, verbose): "calibration_tasks", "calibration_limit", "calibration_seq_length", + "converted_aqlm_checkpoint_path", ] arg_dict = vars(args) quant_args = { From 91b13785038ae3b5e5c849767129ce723a0d9351 Mon Sep 17 00:00:00 2001 From: Andrey Panferov Date: Fri, 9 Aug 2024 01:06:16 +0300 Subject: [PATCH 12/15] New export --- .../models/llama2/aqlm/convert_from_hf.ipynb | 819 +++++++++++++++++- examples/models/llama2/aqlm/linear.py | 7 +- examples/models/llama2/aqlm/utils.py | 24 - examples/models/llama2/builder.py | 1 + examples/models/llama2/model.py | 3 + .../llama2/source_transformation/quantize.py | 7 +- 6 files changed, 812 insertions(+), 49 deletions(-) diff --git a/examples/models/llama2/aqlm/convert_from_hf.ipynb b/examples/models/llama2/aqlm/convert_from_hf.ipynb index 00707d637b6..051bf9d8ecc 100644 --- a/examples/models/llama2/aqlm/convert_from_hf.ipynb +++ b/examples/models/llama2/aqlm/convert_from_hf.ipynb @@ -2,17 +2,787 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "output.weight: value.shape=torch.Size([128256, 4096])\n", + "tok_embeddings.weight: value.shape=torch.Size([128256, 4096])\n", + "layers.0.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.0.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.0.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.0.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.0.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.0.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.0.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.0.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.0.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.0.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.0.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.0.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.0.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.0.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.0.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.0.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.0.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.0.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.0.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.0.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.0.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.0.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.0.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.1.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.1.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.1.feed_forward.w2.codes: 
value.shape=torch.Size([1792, 4096, 2])\n",
+ "... (remaining layer-1 entries and layers 10-19 omitted; every layer repeats the layer-0 shape pattern) ...\n",
+ "layers.2.attention.wv.codebooks: 
value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.2.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.2.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.20.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.20.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.20.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.20.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.20.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.20.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.20.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.20.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.20.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.20.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.20.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.20.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.20.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.20.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.20.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.20.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.20.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.20.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.20.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.20.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.20.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.20.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.20.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.21.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.21.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.21.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.21.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.21.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.21.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.21.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.21.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.21.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.21.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.21.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.21.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.21.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.21.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.21.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.21.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.21.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.21.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.21.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.21.attention.wq.scales: 
value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.21.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.21.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.21.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.22.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.22.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.22.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.22.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.22.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.22.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.22.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.22.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.22.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.22.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.22.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.22.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.22.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.22.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.22.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.22.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.22.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.22.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.22.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.22.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.22.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.22.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.22.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.23.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.23.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.23.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.23.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.23.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.23.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.23.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.23.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.23.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.23.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.23.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.23.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.23.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.23.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.23.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.23.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.23.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.23.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.23.attention.wq.codes: 
value.shape=torch.Size([512, 4096, 2])\n", + "layers.23.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.23.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.23.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.23.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.24.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.24.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.24.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.24.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.24.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.24.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.24.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.24.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.24.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.24.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.24.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.24.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.24.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.24.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.24.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.24.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.24.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.24.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.24.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.24.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.24.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.24.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.24.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.25.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.25.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.25.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.25.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.25.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.25.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.25.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.25.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.25.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.25.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.25.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.25.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.25.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.25.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.25.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.25.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.25.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.25.attention.wq.codebooks: 
value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.25.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.25.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.25.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.25.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.25.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.26.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.26.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.26.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.26.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.26.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.26.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.26.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.26.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.26.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.26.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.26.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.26.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.26.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.26.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.26.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.26.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.26.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.26.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.26.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.26.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.26.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.26.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.26.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.27.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.27.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.27.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.27.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.27.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.27.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.27.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.27.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.27.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.27.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.27.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.27.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.27.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.27.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.27.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.27.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.27.attention.wo.scales: 
value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.27.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.27.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.27.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.27.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.27.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.27.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.28.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.28.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.28.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.28.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.28.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.28.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.28.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.28.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.28.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.28.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.28.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.28.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.28.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.28.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.28.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.28.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.28.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.28.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.28.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.28.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.28.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.28.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.28.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.29.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.29.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.29.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.29.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.29.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.29.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.29.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.29.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.29.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.29.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.29.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.29.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.29.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.29.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.29.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.29.attention.wo.codes: 
value.shape=torch.Size([512, 4096, 2])\n", + "layers.29.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.29.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.29.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.29.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.29.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.29.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.29.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.3.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.3.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.3.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.3.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.3.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.3.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.3.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.3.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.3.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.3.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.3.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.3.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.3.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.3.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.3.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.3.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.3.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.3.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.3.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.3.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.3.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.3.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.3.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.30.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.30.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.30.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.30.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.30.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.30.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.30.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.30.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.30.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.30.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.30.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.30.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.30.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.30.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.30.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 
8])\n", + "layers.30.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.30.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.30.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.30.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.30.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.30.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.30.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.30.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.31.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.31.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.31.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.31.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.31.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.31.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.31.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.31.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.31.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.31.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.31.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.31.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.31.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.31.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.31.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.31.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.31.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.31.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.31.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.31.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.31.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.31.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.31.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.4.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.4.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.4.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.4.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.4.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.4.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.4.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.4.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.4.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.4.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.4.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.4.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.4.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.4.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + 
"layers.4.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.4.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.4.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.4.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.4.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.4.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.4.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.4.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.4.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.5.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.5.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.5.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.5.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.5.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.5.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.5.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.5.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.5.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.5.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.5.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.5.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.5.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.5.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.5.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.5.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.5.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.5.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.5.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.5.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.5.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.5.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.5.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.6.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.6.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.6.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.6.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.6.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.6.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.6.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.6.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.6.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.6.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.6.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.6.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.6.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.6.attention.wk.scales: value.shape=torch.Size([1024, 1, 
1, 1])\n", + "layers.6.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.6.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.6.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.6.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.6.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.6.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.6.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.6.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.6.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.7.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.7.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.7.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.7.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.7.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.7.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.7.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.7.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.7.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.7.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.7.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.7.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.7.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.7.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.7.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.7.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.7.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.7.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.7.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.7.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.7.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.7.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.7.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.8.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.8.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.8.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.8.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.8.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.8.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.8.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.8.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.8.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.8.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.8.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.8.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.8.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.8.attention.wk.scales: 
value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.8.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.8.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.8.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.8.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.8.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.8.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.8.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.8.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.8.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.9.attention_norm.weight: value.shape=torch.Size([4096])\n", + "layers.9.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.9.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])\n", + "layers.9.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.9.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.9.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.9.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.9.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.9.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])\n", + "layers.9.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])\n", + "layers.9.ffn_norm.weight: value.shape=torch.Size([4096])\n", + "layers.9.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.9.attention.wk.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.9.attention.wk.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "layers.9.attention.wo.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.9.attention.wo.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.9.attention.wo.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.9.attention.wq.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.9.attention.wq.codes: value.shape=torch.Size([512, 4096, 2])\n", + "layers.9.attention.wq.scales: value.shape=torch.Size([4096, 1, 1, 1])\n", + "layers.9.attention.wv.codebooks: value.shape=torch.Size([2, 256, 1, 8])\n", + "layers.9.attention.wv.codes: value.shape=torch.Size([512, 1024, 2])\n", + "layers.9.attention.wv.scales: value.shape=torch.Size([1024, 1, 1, 1])\n", + "norm.weight: value.shape=torch.Size([4096])\n" + ] + } + ], "source": [ + "import json\n", + "\n", "import torch\n", "from safetensors.torch import load_file\n", "\n", - "LOAD_PATH = \"/Users/blacksamorez/models/Llama-2-7b-AQLM-PV-2Bit-2x8-hf/model.safetensors\"\n", - "SAVE_PATH = \"/Users/blacksamorez/models/Llama-2-7b-AQLM-PV-2Bit-2x8-hf/executorch.pth\"\n", + "LOAD_PATH = \"/Users/blacksamorez/models/Meta-Llama-3.1-8B-AQLM-2Bit-2x8-NO-TUNE/\"\n", + "SAVE_PATH = \"/Users/blacksamorez/models/Meta-Llama-3.1-8B-AQLM-2Bit-2x8-NO-TUNE/\"\n", "\n", - "dict = load_file(LOAD_PATH)\n", + "\n", + "with open(LOAD_PATH + \"config.json\", \"r\") as file:\n", + " hf_config = json.load(file)\n", + " \n", + "with open(LOAD_PATH + \"params.json\", \"w\") as file:\n", + " json.dump(\n", + " {\n", + " \"dim\": hf_config[\"hidden_size\"],\n", + " \"multiple_of\": 256,\n", + " \"n_heads\": hf_config[\"num_attention_heads\"],\n", + " \"n_kv_heads\": hf_config[\"num_key_value_heads\"],\n", + " \"n_layers\": 
hf_config[\"num_hidden_layers\"],\n", + " \"norm_eps\": hf_config[\"rms_norm_eps\"],\n", + " \"vocab_size\": hf_config[\"vocab_size\"],\n", + " \"ffn_dim_multiplier\": hf_config['intermediate_size'] / (8 / 3 * hf_config['hidden_size']),\n", + " },\n", + " file,\n", + " )\n", + "\n", + "hidden_size = hf_config[\"hidden_size\"]\n", + "head_dim = hf_config[\"hidden_size\"] // hf_config[\"num_attention_heads\"]\n", + "\n", + "dict = load_file(LOAD_PATH + \"model.safetensors\")\n", "\n", "mapping = {\n", " \"model.\": \"\",\n", @@ -36,31 +806,42 @@ "\n", "new_dict = {}\n", "\n", + "\n", "for key, value in dict.items():\n", " for old, new in mapping.items():\n", " key = key.replace(old, new)\n", " \n", " if \"attention.wq.codes\" in key or \"attention.wk.codes\" in key:\n", " # [num_out_groups, num_in_groups, num_codebooks]\n", - " print(f\"Transposing codes {key} {value.shape=}\")\n", - " value = (value.reshape(32, 2, 128 // 2, -1, 2)\n", + " value = (value.reshape(-1, 2, head_dim // 2, hidden_size // 8, 2)\n", " .transpose(1, 2)\n", - " .reshape(128 * 32, -1, 2))\n", + " .reshape(-1, hidden_size // 8, 2))\n", + " \n", " \n", " if \"attention.wq.scales\" in key or \"attention.wk.scales\" in key:\n", " # [num_out_groups, 1, 1, 1]\n", - " print(f\"Transposing scales {key} {value.shape=}\")\n", - " value = (value.reshape(32, 2, 128 // 2, 1)\n", + " value = (value.reshape(-1, 2, head_dim // 2, 1)\n", " .transpose(1, 2)\n", - " .reshape(128 * 32, 1, 1, 1))\n", + " .reshape(-1, 1, 1, 1))\n", + " \n", + " if \"codes\" in key:\n", + " value = value.transpose(0, 1) # <- Special memory layout for lut kernels\n", " \n", - " new_dict[key] = value\n", + " if value.dtype == torch.float16:\n", + " value = value.float()\n", " \n", - "del new_dict[\"output.weight\"]\n", - "del new_dict[\"tok_embeddings.weight\"]\n", + " print(f\"{key}: {value.shape=}\")\n", + " new_dict[key] = value\n", "\n", - "torch.save(new_dict, SAVE_PATH)" + "torch.save(new_dict, SAVE_PATH + \"model.pth\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -70,7 +851,15 @@ "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", "version": "3.10.14" } }, diff --git a/examples/models/llama2/aqlm/linear.py b/examples/models/llama2/aqlm/linear.py index 25c45f005be..66f290ea157 100644 --- a/examples/models/llama2/aqlm/linear.py +++ b/examples/models/llama2/aqlm/linear.py @@ -36,12 +36,12 @@ def __init__( ) # [num_codebooks, codebook_size, out_group_size, in_group_size] self.codes = nn.Parameter( torch.empty( - (num_out_groups, num_in_groups, num_codebooks), + (num_in_groups, num_out_groups, num_codebooks), device=device, dtype=torch.int8, ), requires_grad=False, - ) # [num_out_groups, num_in_groups, num_codebooks] + ) # [num_in_groups, num_out_groups, num_codebooks] # SCALES self.scales = nn.Parameter( @@ -53,9 +53,6 @@ def __init__( self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs), requires_grad=False) else: self.register_parameter("bias", None) - - def transpose_codes(self): - self.codes.data = torch.permute(self.codes.data, (1, 0, 2)).contiguous() def forward(self, input: torch.Tensor) -> torch.Tensor: return torch.ops.aqlm.code2x8_lut_matmat( diff --git a/examples/models/llama2/aqlm/utils.py b/examples/models/llama2/aqlm/utils.py index 
9c81e20226c..596001b3dfc 100644 --- a/examples/models/llama2/aqlm/utils.py +++ b/examples/models/llama2/aqlm/utils.py @@ -72,27 +72,3 @@ def replace_with_aqlm_linear( # Remove the last key for recursion current_key_name.pop(-1) return model, has_been_replaced - - -def transpose_codes( - model, - current_key_name=None, - has_been_transposed=False, -): - for name, module in model.named_children(): - if current_key_name is None: - current_key_name = [] - current_key_name.append(name) - - if isinstance(module, Aqlm2x8Linear): - model._modules[name].transpose_codes() - has_been_transposed = True - if len(list(module.children())) > 0: - _, has_been_transposed = transpose_codes( - module, - current_key_name=current_key_name, - has_been_transposed=has_been_transposed, - ) - # Remove the last key for recursion - current_key_name.pop(-1) - return model, has_been_transposed diff --git a/examples/models/llama2/builder.py b/examples/models/llama2/builder.py index c8d949eb6f2..2a80f2c623e 100644 --- a/examples/models/llama2/builder.py +++ b/examples/models/llama2/builder.py @@ -101,6 +101,7 @@ def load_llama_model( use_sdpa_with_kv_cache=use_sdpa_with_kv_cache, fairseq2=weight_type == WeightType.FAIRSEQ2, max_seq_len=max_seq_len, + skip_loading="AQLM" in params_path, ) state_dict = model.state_dict() dtype = state_dict[next(iter(state_dict))].dtype diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py index aa997aa56ea..e2484609e64 100644 --- a/examples/models/llama2/model.py +++ b/examples/models/llama2/model.py @@ -173,6 +173,9 @@ def __init__(self, **kwargs): # They possess all other metadata a tensor carries such as size, stride, requires_grad. with torch.device("meta"): self.model_ = Transformer(model_args) + + if kwargs.get("skip_loading", None): + return if "int8" in str(checkpoint_path): print("Using int8 weight-only quantization!") diff --git a/examples/models/llama2/source_transformation/quantize.py b/examples/models/llama2/source_transformation/quantize.py index 3561d6e1de9..a399e626a87 100644 --- a/examples/models/llama2/source_transformation/quantize.py +++ b/examples/models/llama2/source_transformation/quantize.py @@ -135,7 +135,7 @@ def quantize( return model elif qmode == "aqlm-2x8": from executorch.examples.models.llama2.aqlm.lut_kernel import aqlm_lib # noqa - from executorch.examples.models.llama2.aqlm.utils import replace_with_aqlm_linear, transpose_codes + from executorch.examples.models.llama2.aqlm.utils import replace_with_aqlm_linear from transformers.utils.quantization_config import AqlmConfig model, _ = replace_with_aqlm_linear( @@ -152,10 +152,7 @@ def quantize( model.load_state_dict( torch.load(converted_aqlm_checkpoint_path), strict=False, - # assign=True, - ) - model, _ = transpose_codes( - model=model, + assign=True, ) return model From 8d1053b7a670b0367ff9ea94ea4989ab79ceaf08 Mon Sep 17 00:00:00 2001 From: Andrey Panferov Date: Fri, 9 Aug 2024 16:01:41 +0300 Subject: [PATCH 13/15] =?UTF-8?q?=D1=81=D1=80=D1=84=D0=B5=20=D0=B7=D0=BA?= =?UTF-8?q?=D1=89=D1=8C=D0=B7=D0=B5=20=D1=84=D1=82=D0=B2=20=D1=8B=D0=BB?= =?UTF-8?q?=D1=88=D0=B7=20=D0=B4=D1=89=D1=84=D0=B2=D1=88=D1=82=D0=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../com/example/executorchllamademo/MainActivity.java | 5 ++++- examples/models/llama2/aqlm/convert_from_hf.ipynb | 6 +++--- examples/models/llama2/aqlm/lut_kernel_pytorch.cpp | 8 ++++---- examples/models/llama2/builder.py | 3 ++- examples/models/llama2/export_llama_lib.py | 1 + 
extension/android/jni/jni_layer_llama.cpp | 2 +- 6 files changed, 15 insertions(+), 10 deletions(-) diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java index 2c94c242ed6..aff6d53f7f9 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java @@ -202,6 +202,9 @@ private void onModelRunStopped() { mSendButton.setOnClickListener( view -> { String prompt = mEditTextMessage.getText().toString(); + + String chat_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n" + prompt + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"; + mMessageAdapter.add(new Message(prompt, true)); mMessageAdapter.notifyDataSetChanged(); mEditTextMessage.setText(""); @@ -219,7 +222,7 @@ public void run() { } }); - mModule.generate(prompt, MainActivity.this); + mModule.generate(chat_prompt, MainActivity.this); runOnUiThread( new Runnable() { diff --git a/examples/models/llama2/aqlm/convert_from_hf.ipynb b/examples/models/llama2/aqlm/convert_from_hf.ipynb index 051bf9d8ecc..7a7b754b01c 100644 --- a/examples/models/llama2/aqlm/convert_from_hf.ipynb +++ b/examples/models/llama2/aqlm/convert_from_hf.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -757,8 +757,8 @@ "import torch\n", "from safetensors.torch import load_file\n", "\n", - "LOAD_PATH = \"/Users/blacksamorez/models/Meta-Llama-3.1-8B-AQLM-2Bit-2x8-NO-TUNE/\"\n", - "SAVE_PATH = \"/Users/blacksamorez/models/Meta-Llama-3.1-8B-AQLM-2Bit-2x8-NO-TUNE/\"\n", + "LOAD_PATH = \"/Users/blacksamorez/models/Meta-Llama-3.1-8B-Instruct-AQLM-PV-2Bit-2x8-hf/\"\n", + "SAVE_PATH = \"/Users/blacksamorez/models/Meta-Llama-3.1-8B-Instruct-AQLM-PV-2Bit-2x8-hf/\"\n", "\n", "\n", "with open(LOAD_PATH + \"config.json\", \"r\") as file:\n", diff --git a/examples/models/llama2/aqlm/lut_kernel_pytorch.cpp b/examples/models/llama2/aqlm/lut_kernel_pytorch.cpp index b384113ecf7..2ae23dda885 100644 --- a/examples/models/llama2/aqlm/lut_kernel_pytorch.cpp +++ b/examples/models/llama2/aqlm/lut_kernel_pytorch.cpp @@ -5,6 +5,9 @@ #include +static uint8_t temp_allocator_pool[16 * 1024U * 1024U]; // 16 Mb +static torch::executor::MemoryAllocator temp_allocator(sizeof(temp_allocator_pool), temp_allocator_pool); + namespace torch { namespace executor { namespace native { @@ -16,10 +19,7 @@ namespace torch { const optional bias, Tensor& output ) { - void* memory_pool = malloc(10000000 * sizeof(uint8_t)); - MemoryAllocator allocator(10000000, (uint8_t*)memory_pool); - - exec_aten::RuntimeContext context{nullptr, &allocator}; + exec_aten::RuntimeContext context{nullptr, &temp_allocator}; return torch::executor::native::code2x8_lut_matmat_out( context, input, diff --git a/examples/models/llama2/builder.py b/examples/models/llama2/builder.py index 2a80f2c623e..35ab14a4e41 100644 --- a/examples/models/llama2/builder.py +++ b/examples/models/llama2/builder.py @@ -78,6 +78,7 @@ def load_llama_model( weight_type: WeightType = WeightType.LLAMA, verbose: bool = False, max_seq_len: int = 128, + skip_loading: bool = False, ) -> "LlamaEdgeManager": """ A helper util that builds a 
Llama2 model. It returns a LlamaEdgeManager that @@ -101,7 +102,7 @@ def load_llama_model( use_sdpa_with_kv_cache=use_sdpa_with_kv_cache, fairseq2=weight_type == WeightType.FAIRSEQ2, max_seq_len=max_seq_len, - skip_loading="AQLM" in params_path, + skip_loading=skip_loading, ) state_dict = model.state_dict() dtype = state_dict[next(iter(state_dict))].dtype diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 724f7f68610..08efca49760 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -377,6 +377,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager: weight_type=weight_type, verbose=args.verbose, max_seq_len=args.max_seq_length, + skip_loading=(args.quantization_mode == "aqlm-2x8"), ) .set_output_dir(output_dir_path) .set_metadata(args.metadata) diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index b4fe80f0225..ad540cfc8d1 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -130,7 +130,7 @@ class ExecuTorchLlamaJni facebook::jni::alias_ref callback) { runner_->generate( prompt->toStdString(), - 128, + 1024, [callback](std::string result) { callback->onResult(result); }, [callback](const Runner::Stats& result) { callback->onStats(result); }); return 0; From b478f066e7acc4f2129c233bcc58ff27f8230907 Mon Sep 17 00:00:00 2001 From: Andrey Panferov Date: Fri, 9 Aug 2024 17:19:09 +0300 Subject: [PATCH 14/15] 8da4w for head --- .../models/llama2/source_transformation/quantize.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/models/llama2/source_transformation/quantize.py b/examples/models/llama2/source_transformation/quantize.py index a399e626a87..5c2d2c12232 100644 --- a/examples/models/llama2/source_transformation/quantize.py +++ b/examples/models/llama2/source_transformation/quantize.py @@ -155,6 +155,16 @@ def quantize( assign=True, ) + # Quantize model head with 8da4w + if group_size is None: + raise Exception("For 8da4w quantization, group size must be specified.") + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer + + model = Int8DynActInt4WeightQuantizer( + precision=torch_dtype, groupsize=group_size + ).quantize(model) + if verbose: + print("quantized model:", model) return model else: raise Exception(f"Unrecognized quantize mode: {qmode}") From 2bfd66ff3f5ceeb66ae70c8725f79ce60dfea18a Mon Sep 17 00:00:00 2001 From: Andrey Panferov Date: Mon, 12 Aug 2024 00:49:52 +0300 Subject: [PATCH 15/15] omp --- examples/models/llama2/aqlm/CmakeLists.txt | 12 ++++++------ examples/models/llama2/aqlm/lut_kernel.cpp | 2 ++ extension/android/CMakeLists.txt | 2 ++ 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/models/llama2/aqlm/CmakeLists.txt b/examples/models/llama2/aqlm/CmakeLists.txt index 2da4e01aa5a..3d61678e9db 100644 --- a/examples/models/llama2/aqlm/CmakeLists.txt +++ b/examples/models/llama2/aqlm/CmakeLists.txt @@ -66,24 +66,24 @@ else() list(APPEND aqlm_libs xnnpack_backend) endif() -# find_package(OpenMP REQUIRED) +find_package(OpenMP REQUIRED) # list(APPEND aqlm_libs OpenMP::OpenMP_CXX) -# list(APPEND aqlm_libs omp) +list(APPEND aqlm_libs omp) add_library(aqlm ${_aqlm__srcs}) # Enable optimization -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") # -O3 -fopenmp -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") # -O3 -fopenmp +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") +set(CMAKE_CXX_FLAGS 
"${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") target_include_directories(aqlm PUBLIC "${_common_include_directories}") target_include_directories( aqlm PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../../../include" ) -target_link_libraries(aqlm PUBLIC ${aqlm_libs}) +target_link_libraries(aqlm PUBLIC ${aqlm_libs} -fopenmp -static-openmp) target_compile_options( - aqlm PUBLIC ${_common_compile_options} -DET_USE_THREADPOOL # ${OpenMP_CXX_FLAGS} + aqlm PUBLIC ${_common_compile_options} -DET_USE_THREADPOOL ) install(TARGETS aqlm DESTINATION lib) diff --git a/examples/models/llama2/aqlm/lut_kernel.cpp b/examples/models/llama2/aqlm/lut_kernel.cpp index 17af6ebb92c..4a7d8997e2d 100644 --- a/examples/models/llama2/aqlm/lut_kernel.cpp +++ b/examples/models/llama2/aqlm/lut_kernel.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -21,6 +22,7 @@ void quadruple_for( const int b_alt_stride = 2 * out_features; for (int input = 0; input < num_inputs; ++input) { + #pragma omp parallel for num_threads(4) for (int j = 0; j < num_input_groups; ++j) { const fp_dtype* lut_ptr = lut + input * lut_stride + j * 2 * 256; const uint8_t* b_alt_ptr = b_alt + j * b_alt_stride; diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index fcc2531b19f..c7b464875d8 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -129,6 +129,8 @@ if(EXECUTORCH_BUILD_LLAMA_JNI) eigen_blas quantized_kernels quantized_ops_lib + -fopenmp + -static-openmp ) target_compile_options(executorch_llama_jni PUBLIC ${_common_compile_options}) if(EXECUTORCH_USE_TIKTOKEN)