102 changes: 85 additions & 17 deletions src/llmcompressor/modifiers/autoround/base.py
@@ -2,19 +2,24 @@

import torch
from auto_round import AutoRound
from auto_round.schemes import PRESET_SCHEMES as AR_PRESET_SCHEMES
from auto_round.schemes import QuantizationScheme as ARQuantizationScheme
from auto_round.utils import is_mllm_model
from auto_round.wrapper import WrapperWALayer
from compressed_tensors.quantization import (
QuantizationMetadata,
QuantizationScheme,
QuantizationStrategy,
enable_quantization,
)
from compressed_tensors.utils import (
align_module_device,
match_named_modules,
update_offload_parameter,
register_offload_parameter,
)
from loguru import logger
from pydantic import PrivateAttr
from transformers import AutoProcessor, AutoTokenizer

from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier
@@ -54,6 +59,10 @@ def _wrap_decoding_layer(layer: torch.nn.Module) -> _PretrainModelWrapper:
return wrapped_model


def _unwrap_decoding_layer(wrapped_model: _PretrainModelWrapper) -> torch.nn.Module:
return wrapped_model.model.layers


class AutoRoundModifier(Modifier, QuantizationMixin):
"""
Implements the AutoRound algorithm from https://aclanthology.org/2024.findings-emnlp.662.pdf.
@@ -146,8 +155,9 @@ def start_calibration(self, model: torch.nn.Module):

for _, module in match_named_modules(model, self.targets, self.ignore):
# Note: No need to register observers for auto-round
self._calibration_hooks |= self._initialize_hooks(module)
apply_calibration_status(module)
# Remove the registered qparams to avoid naming conflicts with AutoRound.
QuantizationMetadata.clear_all_qparams(module)

model.apply(enable_quantization) # quantize at the same time as calibrate

Expand All @@ -163,7 +173,7 @@ def on_start(self, state: State, event: Event, **kwargs):
# assume quantization has been initialized by this modifier or one before it
self.start_calibration(state.model)
for _, module in state.model.named_modules():
if self._is_decoding_layer(module):
if self._is_decoding_layer(module) and self.iters > 0:
# register input capture hook for decoding layers
self.register_hook(
module, self.input_capture_hook, "forward_pre", with_kwargs=True
@@ -214,12 +224,23 @@ def apply_autoround(self, state, subgraph):

wrapped_model = _wrap_decoding_layer(decoding_layer)
wrapped_model.name_or_path = state.model.name_or_path
wrapped_model.config = state.model.config
tokenizer = AutoTokenizer.from_pretrained(
wrapped_model.name_or_path, trust_remote_code=True
)
if is_mllm_model(wrapped_model):
processor = AutoProcessor.from_pretrained(
wrapped_model.name_or_path, trust_remote_code=True
)
else:
processor = None

with torch.enable_grad(), align_module_device(decoding_layer):
ar_quant_scheme = self._mapping_config_to_autoround()
ar = AutoRound(
model=wrapped_model,
tokenizer="",
tokenizer=tokenizer,
processor=processor,
scheme=ar_quant_scheme,
iters=self.iters,
enable_torch_compile=self.enable_torch_compile,
@@ -230,9 +251,9 @@
ar.batch_dim = 0
first_param = next(decoding_layer.parameters())
device = first_param.device
cur_inputs = self._all_module_input[decoding_layer._tmp_name]
decoding_layer.tuning_device = device

cur_inputs = self._all_module_input[decoding_layer._tmp_name]
q_input, _ = ar.quantize_block(
block=decoding_layer,
inputs=cur_inputs,
@@ -243,15 +264,8 @@
)
self._q_input = q_input
# Update offload parameters and remove temporary attributes
for _, module in decoding_layer.named_modules():
if hasattr(module, "weight_scale") and hasattr(
module, "weight_zero_point"
):
# Note: The model's weight is already q-dq in-place by auto-round.
weight_scale = module.scale
del module.scale
# TODO: update zero_point after supporting asymmetric quantization
update_offload_parameter(module, "weight_scale", weight_scale)
decoding_layer = self._unwrapper_quantized_layer(decoding_layer)
self._mapping_qparams(decoding_layer)
decoding_layer.eval()

def post_autoround_cleanup(self):
@@ -299,7 +313,61 @@ def _infer_sequential_targets(self, model: torch.nn.Module) -> str | list[str]:
case _:
return self.sequential_targets

def _unwrapper_quantized_layer(self, model: torch.nn.Module):
# auto-round wraps a layer in WrapperWALayer when its activations are quantized
for name, module in model.named_modules():
if isinstance(module, WrapperWALayer):
if "." in name:
parent, child = name.rsplit(".", maxsplit=1)
parent = model.get_submodule(parent)
setattr(parent, child, module.orig_layer)
else:
# It's a top-level module
setattr(model, name, module.orig_layer)
return model

def _mapping_qparams(self, decoding_layer):
"""Mapping qparam name from AutoRound to LLMC and register qparams in model."""
qparams_mapping = {
# AutoRound parameter name: LLMCompressor parameter name
"scale": "weight_scale",
"zp": "weight_zero_point",
"input_scale": "input_scale",
"weight_global_scale": "weight_global_scale",
"act_max": "input_global_scale",
}
# Update offload parameters and remove temporary attributes
for name, module in decoding_layer.named_modules():
for ar_param_name, llmc_param_name in qparams_mapping.items():
if hasattr(module, ar_param_name):
ar_value = getattr(module, ar_param_name)
if ar_value is None:
ar_value = torch.empty(1)
Comment on lines +344 to +345
Contributor (critical):

The use of torch.empty(1) to handle None quantization parameters is unsafe. torch.empty() returns a tensor with uninitialized data, which can lead to non-deterministic behavior and incorrect results when used as a quantization parameter.

A None value should be handled based on the specific parameter's meaning. For instance, a None zero-point (zp) typically indicates symmetric quantization and should default to 0.0. For other parameters like scales, if None signifies they are not applicable for the current quantization scheme, they should be skipped.

                    if ar_value is None:
                        if ar_param_name == "zp":
                            # For symmetric quantization, zero point is 0
                            ar_value = 0.0
                        else:
                            # If other parameters are None, it likely means they are not used
                            # for the current quantization scheme, so we can skip them.
                            continue

if not isinstance(ar_value, torch.Tensor):
ar_value = torch.tensor(ar_value)
if ar_param_name == "act_max" and self.scheme == "NVFP4":
from auto_round.data_type.nvfp import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
get_reciprocal,
)

param_value = torch.nn.Parameter(
FLOAT8_E4M3_MAX
* FLOAT4_E2M1_MAX
* get_reciprocal(ar_value),
requires_grad=False,
)
else:
param_value = torch.nn.Parameter(ar_value, requires_grad=False)
delattr(module, ar_param_name)
register_offload_parameter(module, llmc_param_name, param_value)

def _mapping_config_to_autoround(self):
if isinstance(self.scheme, str):
if self.scheme in AR_PRESET_SCHEMES:
return self.scheme

resolved_config = self.resolved_config
quant_scheme = None
# TODO: release below constraint in later PRs
Expand All @@ -309,9 +377,9 @@ def _mapping_config_to_autoround(self):
)

for scheme in resolved_config.config_groups.values():
assert isinstance(
scheme, QuantizationScheme
), f"Expected QuantizationScheme, got {type(scheme)}"
assert isinstance(scheme, QuantizationScheme), (
f"Expected QuantizationScheme, got {type(scheme)}"
)
quant_scheme = scheme
weight_args = quant_scheme.weights
assert weight_args.strategy == QuantizationStrategy.GROUP, (
@@ -38,6 +38,11 @@
)
},
)
recipe_modifier_nvfp4 = AutoRoundModifier(
Comment:

how about adding an example of NVFP4 model quantization, besides adding NVFP4 test?


ignore=["lm_head"],
iters=2,
scheme="NVFP4",
)


@requires_gpu(1)
@@ -46,6 +51,7 @@
[
recipe_str,
recipe_modifier_full,
recipe_modifier_nvfp4,
],
)
def test_oneshot_application(recipe, tmp_path):
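
Following up on the review comment above asking for an NVFP4 quantization example: a minimal usage sketch is given below. It is an illustration only, not part of this PR. It assumes llmcompressor's oneshot entrypoint, that AutoRoundModifier is importable from llmcompressor.modifiers.autoround (the module touched by this diff), and uses a placeholder model id and calibration dataset.

# Hypothetical sketch: one-shot NVFP4 quantization with AutoRoundModifier.
# The oneshot entrypoint, import path, model id, and dataset below are
# assumptions, not taken from this diff.
from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier

# NVFP4 preset: FP4 (E2M1) weights with FP8 (E4M3) scales; keep lm_head unquantized.
recipe = AutoRoundModifier(
    ignore=["lm_head"],
    iters=200,  # AutoRound tuning iterations per decoding layer
    scheme="NVFP4",
)

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model id
    dataset="open_platypus",                     # placeholder calibration dataset
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
    output_dir="TinyLlama-1.1B-NVFP4-AutoRound",
)

The recipe mirrors recipe_modifier_nvfp4 from the test above, with iters raised to a more realistic tuning budget than the test's iters=2.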