68 changes: 68 additions & 0 deletions examples/autoround/qwen3_example.py
@@ -0,0 +1,68 @@
from auto_round.calib_dataset import get_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
model_id = "Qwen/Qwen3-235B-A22B/"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 2048
ITERS = 200
# Get aligned calibration dataset.

ds = get_dataset(
tokenizer=tokenizer,
seqlen=MAX_SEQUENCE_LENGTH,
nsamples=NUM_CALIBRATION_SAMPLES,
)


# Configure the quantization algorithm to run.
# * quantize the weights to 4 bit with AutoRound with a group size 128
# * For `Qwen/Qwen3-235B-A22B`, it requires about 300 GB of memory to run tuning with default settings.
recipe = AutoRoundModifier(
targets="Linear",
scheme="W4A16",
ignore=[
"lm_head",
"re:.*mlp.gate$",
],
iters=ITERS,
enable_torch_compile=False,
device_map="0,1,2,3", # Use 4 A100 GPUs
Collaborator:

I also think the name here is confusing, isn't device_map usually a different format like a dict that maps layer name to device id? If device_map="0,1,2,3" is valid in transformers, we can leave as is, otherwise device_ids may be a better name

@yiliu30 (Contributor Author), Dec 20, 2025:

Yes, the device_map is used in Transformers, and we follow a similar approach. Please refer to:
https://huggingface.co/docs/accelerate/en/usage_guides/big_modeling#accelerate

Collaborator:

I've only ever seen device_map be a string like "auto" or "sequential", or a dictionary mapping each module name to a device id, like

device_map = {"block1": 0, "block2.linear1": 0, "block2.linear2": 1, "block2.linear3": 1}

What does it mean if device_map="0,1,2,3"? Is that like auto but only with the first 4 devices?

Reference: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#designing-a-device-map
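
For illustration, a minimal sketch of how these two forms typically look in practice; the model id is a placeholder, not the one used in this PR:

from transformers import AutoModelForCausalLM

# form 1: a strategy string, accelerate decides placement across visible devices
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", device_map="auto")

# the resulting placement is exposed as a dict of module name -> device id,
# which is also the format accepted when passing an explicit device_map yourself
print(model.hf_device_map)
# e.g. {"model.embed_tokens": 0, "model.layers.0": 0, ..., "model.layers.27": 1, "lm_head": 1}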

Contributor Author:

Oh, you're right. Updated to device_ids!

)


# Apply algorithms.
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
shuffle_calibration_samples=False,
)


# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128-AutoRound"
print(f"save to {SAVE_DIR}")
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)


# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
66 changes: 60 additions & 6 deletions src/llmcompressor/modifiers/autoround/base.py
@@ -1,6 +1,9 @@
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from accelerate.hooks import add_hook_to_module, remove_hook_from_submodules
from auto_round import AutoRound
from auto_round.schemes import QuantizationScheme as ARQuantizationScheme
from compressed_tensors.quantization import (
@@ -54,6 +57,33 @@ def _wrap_decoding_layer(layer: torch.nn.Module) -> _PretrainModelWrapper:
return wrapped_model


@contextmanager
def suspend_accelerate_hooks(model: nn.Module):
Collaborator:

just fyi, we are refactoring our usage of accelerate hooks for offloading. You can follow some of that in

Contributor Author:

Thanks! I noticed that change as well. We can adapt to it once it’s ready.
And could I ask what motivated you to implement that functionality yourself instead of using the accelerate hooks? I imagine it requires quite a bit of engineering effort.

Collaborator:

I don't think we have anything posted on the decision to move away from accelerate hooks, outside of it being a pain point and our usage of it being limited in scope. cc @kylesayrs in case there is any other information we can provide.

"""
Temporarily suspend Accelerate hooks from a model.

This context manager detaches all Accelerate hooks (used for device offloading,
dtype casting, etc.) from the model, allowing AutoRound to operate without interference.
On exit, the model is restored to its original device and all hooks are re-attached.
"""
saved_hooks = {}
original_device = next(model.parameters()).device
for name, module in model.named_modules():
if hasattr(module, "_hf_hook"):
saved_hooks[name] = module._hf_hook

remove_hook_from_submodules(model)
try:
yield
finally:
remove_hook_from_submodules(model)
model.to(original_device)
for name, module in model.named_modules():
if name in saved_hooks:
logger.info("Restoring Accelerate hook for module: {}", name)
add_hook_to_module(module, saved_hooks[name], append=True)


class AutoRoundModifier(Modifier, QuantizationMixin):
"""
Implements the AutoRound algorithm from https://aclanthology.org/2024.findings-emnlp.662.pdf.
@@ -110,6 +140,10 @@ class AutoRoundModifier(Modifier, QuantizationMixin):
iters: int = 200
enable_torch_compile: bool = True
batch_size: int = 8
# Optional device map for layer dispatch during tuning.
# Examples: "0,1" to use cuda:0 and cuda:1; "auto" to use all available GPUs.
# When None, no dispatching is performed and the model stays on its current device.
device_map: Optional[str] = None

# private variables
_all_module_input: Dict[str, List[Tuple]] = PrivateAttr(default_factory=dict)
@@ -215,15 +249,20 @@ def apply_autoround(self, state, subgraph):
wrapped_model = _wrap_decoding_layer(decoding_layer)
wrapped_model.name_or_path = state.model.name_or_path

with torch.enable_grad(), align_module_device(decoding_layer):
with torch.enable_grad(), align_module_device(
decoding_layer
), suspend_accelerate_hooks(wrapped_model):
ar_quant_scheme = self._mapping_config_to_autoround()
fp_layers = self.get_unquantized_layer_names(decoding_layer)
ar = AutoRound(
model=wrapped_model,
tokenizer="",
scheme=ar_quant_scheme,
iters=self.iters,
enable_torch_compile=self.enable_torch_compile,
batch_size=self.batch_size,
device_map=self.device_map,
Collaborator:

just fyi, we are looking into better parallelization support and will create an RFC in the new year to gather feedback on best approaches. See PR and comment here

@yiliu30 (Contributor Author), Dec 20, 2025:

Thanks! I’ve looked through some of the multi‑card discussions in LLMC, and they’re quite insightful.
In AutoRound, we currently use the accelerate hooks because they're general enough to work across most models without requiring explicit cross-card communication ops or modeling changes. The downside, of course, is some communication overhead and limited overlap, which can affect performance.

We’re also exploring more efficient ways to fully squeeze out GPU performance. Looking forward to the RFC from you all, hope it covers the tuning case if possible!
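
For context, a rough sketch of the hook-based dispatch pattern being described, assuming a model already loaded on CPU and two visible GPUs; the model id and memory budgets are illustrative only:

from accelerate import dispatch_model, infer_auto_device_map
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype="auto")

# split the layers across two GPUs; accelerate attaches an AlignDevicesHook to each
# dispatched submodule so inputs and weights are moved to the right device on forward,
# with no changes to the modeling code and no explicit cross-card communication ops
device_map = infer_auto_device_map(model, max_memory={0: "20GiB", 1: "20GiB"})
model = dispatch_model(model, device_map=device_map)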

@brian-dellabetta (Collaborator), Dec 23, 2025:

we will be sure to share it out in the new year.

Can you elaborate on what you mean by the tuning case? Is this specific to the tuning stage mentioned in the SignRoundv2 paper?

@yiliu30 (Contributor Author), Dec 24, 2025:

Thanks! The tuning here refers to fine-tuning the quantization parameters by minimizing the block-wise reconstruction error. In this process, we compute the loss between the output of the original floating-point block and that of the Q-DQ block, then run a backward pass to compute gradients and update the quantization parameters accordingly. This approach was introduced in SignRound v1. cc @wenhuach

For implementation details, please refer to the code here: https://github.com/intel/auto-round/blob/440288fd6b92509e84da337437a30997ac544735/auto_round/compressors/base.py#L2984
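
For readers unfamiliar with the scheme, a heavily simplified, self-contained sketch of that loss construction; a single Linear stands in for a decoder block, and the per-row int4 fake-quantizer and all names below are illustrative rather than taken from auto-round:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
block = nn.Linear(64, 64)                      # stands in for one decoder block
calib = torch.randn(32, 64)                    # cached calibration inputs for this block
w = block.weight.detach()
scale = w.abs().amax(dim=1, keepdim=True) / 7  # per-row scale for signed int4
v = torch.zeros_like(w, requires_grad=True)    # learnable rounding offsets
opt = torch.optim.Adam([v], lr=5e-3)

def qdq(weight, v, scale, num_bits=4):
    qmax = 2 ** (num_bits - 1) - 1
    x = weight / scale + v
    x_q = (torch.round(x) - x).detach() + x    # straight-through estimator for round()
    return torch.clamp(x_q, -qmax - 1, qmax) * scale

with torch.no_grad():
    fp_out = block(calib)                      # reference output of the float block

for _ in range(200):
    q_out = F.linear(calib, qdq(w, v, scale), block.bias.detach())
    loss = F.mse_loss(q_out, fp_out)           # block-wise reconstruction error
    opt.zero_grad()
    loss.backward()
    opt.step()

The real implementation differs in important ways (signed-gradient updates, bounded v, tuning of clipping ranges, per-block batching over cached inputs), so this is only an illustration of the loss construction described above.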

fp_layers=",".join(fp_layers) if fp_layers else "",
)
# TODO: configure layer-wise config based on self.resolved_config
ar.configure_layer_config(enable_gguf_official_mixed=False)
@@ -232,21 +271,25 @@
device = first_param.device
cur_inputs = self._all_module_input[decoding_layer._tmp_name]
decoding_layer.tuning_device = device
# Leave offload for LLMC to handle if `device_map` is not set
auto_offload = False
if self.device_map is not None:
# When device_map is set, we move decoding layer to CPU first,
# then the submodules will be re-dispatched by AutoRound.
decoding_layer.to("cpu")
auto_offload = True

q_input, _ = ar.quantize_block(
block=decoding_layer,
inputs=cur_inputs,
q_input=self._q_input,
device=str(device),
# Leave offload for LLMC
auto_offload=False,
auto_offload=auto_offload,
)
self._q_input = q_input
# Update offload parameters and remove temporary attributes
for _, module in decoding_layer.named_modules():
if hasattr(module, "weight_scale") and hasattr(
module, "weight_zero_point"
):
if hasattr(module, "scale") and hasattr(module, "weight_zero_point"):
# Note: The model's weight is already q-dq in-place by auto-round.
weight_scale = module.scale
del module.scale
@@ -278,6 +321,17 @@ def on_finalize(self, state: State, **kwargs) -> bool:

return True

def get_unquantized_layer_names(self, wrapped_model: torch.nn.Module) -> List[str]:
unquantized_layers = []

for name, module in wrapped_model.named_modules():
if (
module.__class__.__name__ in self.resolved_targets
and getattr(module, "quantization_scheme", None) is None
):
unquantized_layers.append(name)
return unquantized_layers

def _add_temporary_names(self, model: torch.nn.Module):
for name, mod in model.named_modules():
mod._tmp_name = name
@@ -94,3 +94,64 @@ def test_oneshot_application(recipe, tmp_path):
# Check lm-head is not quantized
not_targetted = model_loaded.lm_head
assert not hasattr(not_targetted, "quantization_scheme")


@requires_gpu(2)
def test_oneshot_with_device_map(tmp_path):
output = tmp_path / "oneshot_output"
model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model)
dataset = get_dataset(
tokenizer=tokenizer,
seqlen=512,
nsamples=4,
)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

recipe = AutoRoundModifier(
ignore=["lm_head"],
iters=10,
config_groups={
"group_0": QuantizationScheme(
targets=["Linear"],
weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
)
},
device_map="0,1",
)

oneshot(
model=model,
dataset=dataset,
output_dir=output,
recipe=recipe,
)
model_loaded = AutoModelForCausalLM.from_pretrained(output, device_map=device)

# Check that the model is quantized
# for compression_config - decompress() will attach a quantization_config
# to the model as we decompress right away
# for quantization_config - we have CompressedLinear which will only
# decompress on the forward pass and does not call decompress(). Results
# in a slightly different parameter tree to access the quant config
quantization_config = model_loaded.config.quantization_config.quantization_config
assert quantization_config is not None

# check config is set properly
assert "lm_head" in quantization_config.ignore
assert len(quantization_config.config_groups) == 1
quant_scheme = quantization_config.config_groups["group_0"]
assert isinstance(quant_scheme, QuantizationScheme)

weight_args = quantization_config.config_groups["group_0"].weights
assert isinstance(weight_args, QuantizationArgs)
assert weight_args.num_bits == 4

# Check a specific layer is quantized
targetted_linear_layer = model_loaded.model.layers[2].self_attn.q_proj
assert hasattr(targetted_linear_layer, "quantization_scheme")

# Check lm-head is not quantized
not_targetted = model_loaded.lm_head
assert not hasattr(not_targetted, "quantization_scheme")