4 changes: 2 additions & 2 deletions examples/autoround/README.md
@@ -16,7 +16,7 @@ pip install -e .

## Quickstart

The example includes an end-to-end script for applying the AutoRound quantization algorithm.
The [example](./quantization_w4a16/llama3_example.py) includes an end-to-end script for applying the AutoRound quantization algorithm.

```bash
python3 llama3_example.py
@@ -134,7 +134,7 @@ We can see the resulting scores look good!
> Note: quantized model accuracy may vary slightly due to nondeterminism.

### Known Issues
Currently, `llm-compressor` supports applying AutoRound only on the `wNa16` quantization schemes. Support for additional schemes is planned. You can follow progress in the [RFC](https://github.com/vllm-project/llm-compressor/issues/1968).
Currently, `llm-compressor` supports applying AutoRound only on the `wNa16` and `w8a8` quantization schemes. Support for additional schemes is planned. You can follow progress in the [RFC](https://github.com/vllm-project/llm-compressor/issues/1968).
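
For reference, here is a minimal sketch of a w8a8-style run, condensed from the example scripts and tests added in this PR (the TinyLlama checkpoint, the `FP8_DYNAMIC` preset, and the `ignore` entry are taken from those files; other models may need different settings):

```python
from auto_round.calib_dataset import get_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier

# Small chat model used purely for illustration (same checkpoint as the tests).
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Calibration data via auto-round's helper, as in the example scripts.
ds = get_dataset(tokenizer=tokenizer, seqlen=2048, nsamples=128)

# FP8 W8A8 dynamic-activation preset; iters=0 takes the fast RTN-style path.
recipe = AutoRoundModifier(
    targets="Linear",
    scheme="FP8_DYNAMIC",
    ignore=["lm_head"],
    iters=0,
)

oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=128,
    shuffle_calibration_samples=False,
)
```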

### Questions or Feature Request?

@@ -0,0 +1,59 @@
from auto_round.calib_dataset import get_dataset
from transformers import AutoTokenizer, Llama4ForConditionalGeneration, AutoProcessor

from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# Select calibration dataset.
NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 2048
# Load the calibration dataset via auto-round's helper, aligned with what auto-round expects.

ds = get_dataset(
tokenizer=tokenizer,
seqlen=MAX_SEQUENCE_LENGTH,
nsamples=NUM_CALIBRATION_SAMPLES,
)


# Configure the quantization algorithm to run.
recipe = AutoRoundModifier(
targets="Linear",
scheme="FP8_DYNAMIC",
ignore=["re:.*lm_head", "re:.*router", "re:.*self_attn.*", "re:.*shared_expert.*", "re:multi_modal_projector.*", "re:vision_model"],
iters=0,
)


# Apply algorithms.
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
shuffle_calibration_samples=False,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=1)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W8A8-Dynamic-AutoRound"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
processor.save_pretrained(SAVE_DIR)
@@ -0,0 +1,59 @@
from auto_round.calib_dataset import get_dataset
from transformers import AutoTokenizer, Llama4ForConditionalGeneration, AutoProcessor

from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# Select calibration dataset.
NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 2048
# Load the calibration dataset via auto-round's helper, aligned with what auto-round expects.

ds = get_dataset(
tokenizer=tokenizer,
seqlen=MAX_SEQUENCE_LENGTH,
nsamples=NUM_CALIBRATION_SAMPLES,
)


# Configure the quantization algorithm to run.
recipe = AutoRoundModifier(
targets="Linear",
scheme="FP8",
ignore=["re:.*lm_head", "re:.*router", "re:.*self_attn.*", "re:.*shared_expert.*", "re:multi_modal_projector.*", "re:vision_model"],
iters=0,
)


# Apply algorithms.
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
shuffle_calibration_samples=False,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=1)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W8A8-Static-AutoRound"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
processor.save_pretrained(SAVE_DIR)
109 changes: 93 additions & 16 deletions src/llmcompressor/modifiers/autoround/base.py
@@ -3,6 +3,8 @@
import torch
from auto_round import AutoRound
from auto_round.schemes import QuantizationScheme as ARQuantizationScheme
from auto_round.wrapper import WrapperWALayer

from compressed_tensors.quantization import (
QuantizationScheme,
QuantizationStrategy,
@@ -145,8 +147,7 @@ def start_calibration(self, model: torch.nn.Module):
untie_word_embeddings(model)

for _, module in match_named_modules(model, self.targets, self.ignore):
# Note: No need to register observers for auto-round
self._calibration_hooks |= self._initialize_hooks(module)
# Skip registering observers; auto-round does not need them.
apply_calibration_status(module)

model.apply(enable_quantization) # quantize at the same time as calibrate
@@ -214,6 +215,7 @@ def apply_autoround(self, state, subgraph):

wrapped_model = _wrap_decoding_layer(decoding_layer)
wrapped_model.name_or_path = state.model.name_or_path
wrapped_model.config = state.model.config

with torch.enable_grad(), align_module_device(decoding_layer):
ar_quant_scheme = self._mapping_config_to_autoround()
@@ -242,16 +244,35 @@ auto_offload=False,
auto_offload=False,
)
self._q_input = q_input

decoding_layer = self._unwrapper_quantized_layer(decoding_layer)

# Update offload parameters and remove temporary attributes
for _, module in decoding_layer.named_modules():
if hasattr(module, "weight_scale") and hasattr(
module, "weight_zero_point"
for name, module in decoding_layer.named_modules():
if (
hasattr(module, "weight_scale")
and hasattr(module, "weight_zero_point")
and hasattr(module, "scale")
):
# Note: the module's weight has already been quantized-dequantized (q-dq) in place by auto-round.
weight_scale = module.scale
del module.scale
# TODO: update zero_point after supporting asymmetric quantization
update_offload_parameter(module, "weight_scale", weight_scale)

if (
hasattr(module, "act_scale")
and hasattr(module, "input_scale")
):
act_scale = module.act_scale
assert act_scale.numel() == module.input_scale.numel(), (
f"Expected act_scale of size {module.input_scale.numel()}, got {act_scale.numel()}"
)
del module.act_scale

# The activation scale shape may differ from input_scale's, so reshape before updating.
update_offload_parameter(module, "input_scale", act_scale.reshape(module.input_scale.shape))

decoding_layer.eval()

def post_autoround_cleanup(self):
@@ -278,6 +299,19 @@ def on_finalize(self, state: State, **kwargs) -> bool:

return True

def _unwrapper_quantized_layer(self, model: torch.nn.Module):
# auto-round wraps layers in WrapperWALayer when activations are quantized; restore the original layers
for name, module in model.named_modules():
if isinstance(module, WrapperWALayer):
if "." in name:
parent, child = name.rsplit(".", maxsplit=1)
parent = model.get_submodule(parent)
setattr(parent, child, module.orig_layer)
else:
# It's a top-level module
setattr(model, name, module.orig_layer)
return model

def _add_temporary_names(self, model: torch.nn.Module):
for name, mod in model.named_modules():
mod._tmp_name = name
@@ -314,23 +348,66 @@ def _mapping_config_to_autoround(self):
), f"Expected QuantizationScheme, got {type(scheme)}"
quant_scheme = scheme
weight_args = quant_scheme.weights
assert weight_args.strategy == QuantizationStrategy.GROUP, (
"Only group-wise quantization is supported in AutoRoundModifier for now, "
f"got {weight_args.strategy}"
)
assert quant_scheme.input_activations is None, (
"Input activation quantization is not supported in AutoRoundModifier, "
f"got {quant_scheme.input_activations}"
)
activation_args = quant_scheme.input_activations
assert quant_scheme.output_activations is None, (
"Output activation quantization is not supported in AutoRoundModifier, "
f"got {quant_scheme.output_activations}"
)
group_size = weight_args.group_size
data_type = weight_args.type
if group_size is None:
if weight_args.strategy == QuantizationStrategy.CHANNEL:
group_size = -1
elif weight_args.strategy == QuantizationStrategy.TENSOR:
group_size = 0
else:
raise ValueError(
f"Unsupported weight quantization strategy {weight_args.strategy}; AutoRoundModifier supports group-, channel-, and tensor-wise weight quantization"
)

if data_type == "float":
data_type = "fp"

if activation_args is None:
act_bits = 16
act_group_size = None
act_symmetric = None
act_dynamic = None
act_data_type = None
else:
act_dynamic = activation_args.dynamic
act_group_size = activation_args.group_size
act_symmetric = activation_args.symmetric
act_bits = activation_args.num_bits
Reviewer comment: How about using `act_dynamic = getattr(activation_args, "dynamic", None)`?

Author reply: There are default values in `QuantizationArgs` for each parameter; if we used `getattr` here, all similar code in the file would need the same treatment to stay consistent.
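
For context, a minimal sketch contrasting the two access patterns under discussion, reusing the `QuantizationArgs` construction from the tests in this PR (illustrative only, not part of the diff):

```python
from compressed_tensors.quantization import QuantizationArgs

# QuantizationArgs declares a default for each field, so plain attribute access
# is well-defined once an instance exists.
activation_args = QuantizationArgs(num_bits=8, type="float", strategy="token", dynamic=True)

direct = activation_args.dynamic                      # style used throughout this file
fallback = getattr(activation_args, "dynamic", None)  # reviewer's suggestion
assert direct == fallback  # identical whenever the attribute is present
```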


# Activations are quantized dynamically at runtime, so auto-round does not need to collect activation scales.
if act_dynamic:
act_bits = 16

act_data_type = activation_args.type
assert activation_args.strategy != QuantizationStrategy.GROUP, (
"Input activation group-wise quantization is not supported in AutoRoundModifier"
)
if act_group_size is None:
if activation_args.strategy in [QuantizationStrategy.CHANNEL, QuantizationStrategy.TOKEN]:
act_group_size = -1
elif activation_args.strategy == QuantizationStrategy.TENSOR:
act_group_size = 0
else:
raise ValueError(f"Input activation strategy {activation_args.strategy} is not supported in AutoRoundModifier")

if act_data_type == "float":
act_data_type = "fp"

ar_quant_scheme = ARQuantizationScheme(
bits=weight_args.num_bits,
sym=weight_args.symmetric,
group_size=weight_args.group_size,
data_type=weight_args.type,
act_bits=16,
group_size=group_size,
data_type=data_type,
act_bits=act_bits,
act_group_size=act_group_size,
act_sym=act_symmetric,
act_dynamic=act_dynamic,
act_data_type=act_data_type,
)
return ar_quant_scheme
@@ -39,6 +39,29 @@
},
)

w8a8_dynamic_recipe_modifier = AutoRoundModifier(
ignore=["lm_head"],
iters=0,
config_groups={
"group_0": QuantizationScheme(
targets=["Linear"],
weights=QuantizationArgs(num_bits=8, strategy="channel"),
input_activations=QuantizationArgs(num_bits=8, type="float", strategy="token", dynamic=True),
)
},
)

w8a8_static_recipe_modifier = AutoRoundModifier(
ignore=["lm_head"],
iters=0,
config_groups={
"group_0": QuantizationScheme(
targets=["Linear"],
weights=QuantizationArgs(num_bits=8, type="float", strategy="tensor"),
input_activations=QuantizationArgs(num_bits=8, type="float", strategy="tensor"),
)
},
)

@requires_gpu(1)
@pytest.mark.parametrize(
@@ -94,3 +117,57 @@ def test_oneshot_application(recipe, tmp_path):
# Check lm-head is not quantized
not_targetted = model_loaded.lm_head
assert not hasattr(not_targetted, "quantization_scheme")

@requires_gpu(1)
@pytest.mark.parametrize(
"recipe",
[
w8a8_dynamic_recipe_modifier,
w8a8_static_recipe_modifier
],
)
def test_rtn_oneshot(recipe, tmp_path):
output = tmp_path / "oneshot_output"
model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model)
dataset = get_dataset(
tokenizer=tokenizer,
seqlen=1024,
nsamples=32,
)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

oneshot(
model=model,
dataset=dataset,
output_dir=output,
recipe=recipe,
)
model_loaded = AutoModelForCausalLM.from_pretrained(output, device_map=device)

quantization_config = model_loaded.config.quantization_config.quantization_config
assert quantization_config is not None

# check config is set properly
assert "lm_head" in quantization_config.ignore
assert len(quantization_config.config_groups) == 1
quant_scheme = quantization_config.config_groups["group_0"]
assert isinstance(quant_scheme, QuantizationScheme)

weight_args = quantization_config.config_groups["group_0"].weights
act_args = quantization_config.config_groups["group_0"].input_activations
assert isinstance(weight_args, QuantizationArgs)
assert weight_args.num_bits == recipe.config_groups["group_0"].weights.num_bits
assert weight_args.strategy == recipe.config_groups["group_0"].weights.strategy
if act_args is not None:
assert act_args.num_bits == recipe.config_groups["group_0"].input_activations.num_bits
assert act_args.strategy == recipe.config_groups["group_0"].input_activations.strategy

# Check a specific layer is quantized
targetted_linear_layer = model_loaded.model.layers[2].self_attn.q_proj
assert hasattr(targetted_linear_layer, "quantization_scheme")

# Check lm-head is not quantized
not_targetted = model_loaded.lm_head
assert not hasattr(not_targetted, "quantization_scheme")