32 changes: 8 additions & 24 deletions src/llmcompressor/modeling/fuse.py
@@ -1,11 +1,7 @@
from typing import Iterable

import torch
from compressed_tensors import (
align_module_device,
get_execution_device,
update_offload_parameter,
)
from compressed_tensors.offload import update_offload_parameter

__all__ = ["center_embeddings", "fuse_norm_linears"]

@@ -22,13 +18,9 @@ def center_embeddings(embedding: torch.nn.Module):
if not hasattr(embedding, "weight"):
raise ValueError(f"Cannot fuse norm of type {type(embedding)}")

with align_module_device(embedding):
weight_dtype = embedding.weight.dtype
weight = embedding.weight.to(PRECISION)
new_weight = weight - weight.mean(dim=-1, keepdim=True)
new_weight = new_weight.to(weight_dtype)

update_offload_parameter(embedding, "weight", new_weight)
weight = embedding.weight.to(PRECISION)
weight = weight - weight.mean(dim=-1, keepdim=True)
update_offload_parameter(embedding, "weight", weight)
Comment on lines +21 to +23
Severity: high

The weight tensor is not cast back to its original data type after the centering operation. It remains in PRECISION (torch.float64), which could increase memory usage and cause dtype mismatches in subsequent operations. It's recommended to restore the cast to the original dtype before updating the parameter.

Suggested change
-    weight = embedding.weight.to(PRECISION)
-    weight = weight - weight.mean(dim=-1, keepdim=True)
-    update_offload_parameter(embedding, "weight", weight)
+    weight_dtype = embedding.weight.dtype
+    weight = embedding.weight.to(PRECISION)
+    weight = weight - weight.mean(dim=-1, keepdim=True)
+    update_offload_parameter(embedding, "weight", weight.to(weight_dtype))

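For illustration, a minimal standalone sketch (plain torch, no offloading involved, names are illustrative) of the dtype behavior this comment describes: arithmetic performed in float64 stays float64 unless it is explicitly cast back.

import torch

w = torch.randn(4, 8, dtype=torch.bfloat16)
# compute the centering in high precision, as the PR does with PRECISION
centered = w.to(torch.float64) - w.to(torch.float64).mean(dim=-1, keepdim=True)
print(centered.dtype)              # torch.float64: the result keeps the high-precision dtype
print(centered.to(w.dtype).dtype)  # torch.bfloat16: an explicit cast restores the original dtype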


def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]):
@@ -46,15 +38,7 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear])

for linear in linears:
# NOTE: spinquant does this op in float64
exec_device = get_execution_device(norm)
with align_module_device(norm, exec_device), align_module_device(
linear, exec_device
):
weight_dtype = linear.weight.dtype
new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
new_weight = new_weight.to(weight_dtype)

update_offload_parameter(linear, "weight", new_weight)

new_norm_weight = torch.ones_like(norm.weight, device="cpu")
update_offload_parameter(norm, "weight", new_norm_weight)
linear_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
update_offload_parameter(linear, "weight", linear_weight)
Comment on lines +41 to +42

Severity: high

As in the center_embeddings function, linear_weight is computed in PRECISION (torch.float64) but is not cast back to the original linear.weight.dtype. This could lead to unintended dtype changes, increased memory usage, and potential errors. The weight should be cast back to its original data type before updating the parameter.

Suggested change
-        linear_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
-        update_offload_parameter(linear, "weight", linear_weight)
+        weight_dtype = linear.weight.dtype
+        linear_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
+        update_offload_parameter(linear, "weight", linear_weight.to(weight_dtype))


update_offload_parameter(norm, "weight", torch.ones_like(norm.weight))
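As context for why the fusion ends by setting the norm weight to all-ones, a hedged sketch (plain torch, hypothetical shapes, no offloading) showing that folding the norm scale into the linear weight preserves the layer output:

import torch

norm_weight = torch.rand(8) + 0.5                        # stand-in for a norm scale parameter
linear = torch.nn.Linear(8, 4, bias=False)
x = torch.randn(2, 8)

before = linear(x * norm_weight)                         # scale applied by the norm, then the linear
linear.weight.data = linear.weight.data * norm_weight    # fold the scale into the linear weight
after = linear(x)                                        # equivalent once the norm weight is all-ones
assert torch.allclose(before, after, atol=1e-5)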
4 changes: 2 additions & 2 deletions src/llmcompressor/modifiers/awq/base.py
@@ -197,13 +197,13 @@ def on_initialize(self, state: State, **kwargs) -> bool:
architecture=state.model.__class__.__name__
)

self._set_resolved_mappings(state.model)

return True

def on_start(self, state: State, event: Event, **kwargs):
self.started_ = True

self._set_resolved_mappings(state.model)

# register quantization calibration hooks
# assume quantization has been initialized by this modifier or one before it
QuantizationMixin.start_calibration(self, state.model)
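This change, like the SparseGPT, GPTQ, and SmoothQuant diffs below, moves model-dependent setup from on_initialize into on_start. A hedged sketch of that shape (class and method bodies are illustrative only, not the real modifier base classes):

class ExampleModifier:
    def on_initialize(self, state, **kwargs) -> bool:
        # validate configuration only; defer anything that inspects the model
        return True

    def on_start(self, state, event, **kwargs):
        self.started_ = True
        # model-dependent resolution now runs when calibration actually starts
        self.resolved_mappings_ = self._resolve_mappings(state.model)

    def _resolve_mappings(self, model):
        # placeholder for the modifier-specific resolution logic
        return {}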
23 changes: 11 additions & 12 deletions src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py
@@ -109,15 +109,19 @@ def on_initialize(self, state: "State", **kwargs) -> bool:

:param state: session state storing input model and calibration data
"""
# infer module and sequential targets
self.sequential_targets = self._infer_sequential_targets(state.model)

return True

def on_start(self, state: State, event: Event, **kwargs):
self.started_ = True

# find target layers
model: torch.nn.Module = state.model
dataloader: torch.utils.data.DataLoader = state.data.calib

# infer module and sequential targets
self.sequential_targets = self._infer_sequential_targets(model)
layers = get_layers(self.sequential_targets, model)
self._target_layers = get_layers(
self.targets, model
) # layers containing targets
self._target_layers = get_layers(self.targets, model)

# infer layer sparsities
if self.sparsity_profile == "owl":
@@ -127,7 +131,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
)
self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader)

# get layers and validate sparsity
# validate sparsity
if isinstance(self.sparsity, (list, dict)) and len(self._target_layers) != len(
self.sparsity
):
@@ -136,11 +140,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
f"sparsities values, but model has {len(layers)} target layers"
)

return True

def on_start(self, state: State, event: Event, **kwargs):
self.started_ = True

# register hooks
for index, (layer_name, layer) in enumerate(self._target_layers.items()):
match self.sparsity:
10 changes: 5 additions & 5 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -163,6 +163,11 @@ def on_initialize(self, state: State, **kwargs) -> bool:
if QuantizationMixin.has_config(self):
QuantizationMixin.initialize_quantization(self, state.model)

return True

def on_start(self, state: State, event: Event, **kwargs):
self.started_ = True

# prepare module names
self._module_names = {
m: name
@@ -171,11 +176,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
)
}

return True

def on_start(self, state: State, event: Event, **kwargs):
self.started_ = True

# register quantization calibration hooks
# assume quantization has been initialized by this modifier or one before it
QuantizationMixin.start_calibration(self, state.model)
2 changes: 1 addition & 1 deletion src/llmcompressor/modifiers/smoothquant/base.py
@@ -131,13 +131,13 @@ def on_initialize(self, state: State, **kwargs) -> bool:
)
self.ignore = [] if not self.ignore else self.ignore
self.mappings = self._infer_mappings_from_model(state.model)
self.resolved_mappings_ = self._resolve_mappings(state.model)
self.scales_ = {}

return True

def on_start(self, state: State, event: Event, **kwargs):
self.started_ = True
self.resolved_mappings_ = self._resolve_mappings(state.model)
self._setup_scale_hooks()

def on_event(self, state: State, event: Event, **kwargs):
57 changes: 7 additions & 50 deletions src/llmcompressor/pipelines/sequential/helpers.py
@@ -6,13 +6,7 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple

import torch
from accelerate.hooks import remove_hook_from_module
from compressed_tensors.utils import (
has_offloaded_params,
offloaded_dispatch,
patch_attr,
remove_dispatch,
)
from compressed_tensors.utils import patch_attr
from compressed_tensors.utils.match import match_targets
from loguru import logger
from torch.fx import Graph, GraphModule, Node
@@ -37,7 +31,6 @@
"trace_subgraphs",
"Subgraph",
"get_sequential_targets",
"dispatch_for_sequential",
]


@@ -104,10 +97,9 @@ def trace_subgraphs(
# find modules
targets = match_modules(model, sequential_targets)
ancestors = get_sequential_ancestors(model, targets)
offloaded = set(m for m in model.modules() if has_offloaded_params(m))

# initialize arguments
tracer = SequentialTracer(ancestors, offloaded)
tracer = SequentialTracer(ancestors)
concrete_args = populate_concrete_args(model, sample_input)

with contextlib.ExitStack() as stack:
@@ -168,32 +160,18 @@ class SequentialTracer(HFTracer):
"""
Get a tracer specialized for the given model. The resulting tracer will not trace
inside of sequential targets, nor any modules which are not call graph ancestors of
sequential targets

Tracing within sequential targets is unnecessary, and tracing within offloaded
modules may result in meta tensors being added to the model graph
sequential targets. Tracing outside of call ancestors of sequential targets will be
skipped

:param ancestors: modules which are ancestors of sequential targets
:param offloaded: modules which have offloaded params and should not be traced
"""

def __init__(self, ancestors: Set[Module], offloaded: Set[Module]):
def __init__(self, ancestors: Set[Module]):
self.ancestors = ancestors
self.offloaded = offloaded

# skip any mask creation functions not already caught by the autowrapper
super().__init__(autowrap_functions=_get_autowrap_functions())

# check unlikely case that ancestors have direct params which are offloaded
offloaded_ancestors = offloaded & ancestors
for ancestor in offloaded_ancestors:
remove_hook_from_module(ancestor, recurse=False)
self.offloaded.remove(ancestor)
logger.warning(
f"Direct parameters attached to {ancestor.__class__.__name__} have "
"been onloaded in order to ensure safe graph capture and execution"
)

def create_arg(self, a: Any) -> Argument:
# special extension allows models which depend on config values to be traced
if isinstance(a, PretrainedConfig):
@@ -204,8 +182,8 @@ def create_arg(self, a: Any) -> Argument:
return super().create_arg(a)

def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:
# do not trace non-ancestors or modules with offloaded params
return module not in self.ancestors or module in self.offloaded
# do not trace non-ancestors
return module not in self.ancestors


def populate_concrete_args(model: Module, sample_input: Dict) -> Dict:
@@ -526,27 +504,6 @@ def is_ancestor(module: Module) -> bool:
return ancestors


def dispatch_for_sequential(model: PreTrainedModel) -> PreTrainedModel:
"""
Dispatch a model for sequential calibration using a sequential pipeline.
The model will be offloaded to the CPU and dispatched to CUDA/XPU device
if available. Removes any existing hooks.

:param model: model to dispatch
:return: dispatched model
"""
remove_dispatch(model)

if torch.cuda.is_available():
offloaded_dispatch(model, execution_device=torch.device("cuda:0"))
elif hasattr(torch, "xpu") and torch.xpu.is_available():
offloaded_dispatch(model, execution_device=torch.device("xpu:0"))
else:
logger.warning("CUDA/XPU is not available! Compressing model on CPU instead")

return model


def _get_autowrap_functions() -> Tuple[Callable[[Any], Any], ...]:
try:
from transformers.masking_utils import LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING
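As a point of orientation, a standalone torch.fx illustration (not the PR's HFTracer subclass; module names are hypothetical) of the leaf rule the simplified SequentialTracer now uses: any module outside the ancestor set is treated as a leaf and is not traced into.

import torch
from torch import fx, nn

class AncestorOnlyTracer(fx.Tracer):
    """Treat every module outside `ancestors` as a leaf, so it appears as one call_module node."""

    def __init__(self, ancestors):
        super().__init__()
        self.ancestors = ancestors

    def is_leaf_module(self, module: nn.Module, module_qualified_name: str) -> bool:
        return module not in self.ancestors

model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
graph = AncestorOnlyTracer(ancestors={model, model[1]}).trace(model)
print(graph)  # model[0] is a single call_module node; model[1] is traced into its children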
12 changes: 6 additions & 6 deletions src/llmcompressor/pipelines/sequential/pipeline.py
@@ -2,7 +2,8 @@
from typing import TYPE_CHECKING

import torch
from compressed_tensors.utils import disable_offloading, get_execution_device
from compressed_tensors.offload import dispatch_model
from compressed_tensors.utils import disable_offloading
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

@@ -11,7 +12,6 @@
from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.pipelines.registry import CalibrationPipeline
from llmcompressor.pipelines.sequential.helpers import (
dispatch_for_sequential,
get_sequential_targets,
trace_subgraphs,
)
@@ -59,10 +59,6 @@ def __call__(
"""
session = active_session()

# prepare model for sequential onloading
dispatch_for_sequential(model)
model_device = get_execution_device(model)

# prepare to trace subgraphs
modifiers = session.lifecycle.recipe.modifiers
sequential_targets = get_sequential_targets(modifiers, model, dataset_args)
@@ -73,6 +69,10 @@
subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)
num_subgraphs = len(subgraphs)

# prepare model for sequential onloading
model_device = "cuda" if torch.cuda.is_available() else "cpu"
dispatch_model(model, model_device)
Comment on lines +73 to +74
Severity: medium

The previous dispatch_for_sequential implementation checked torch.xpu.is_available() to support Intel XPU devices. This change removes that check, which is a functional regression. Reintroducing the XPU fallback would preserve broader hardware compatibility.

Suggested change
-        model_device = "cuda" if torch.cuda.is_available() else "cpu"
-        dispatch_model(model, model_device)
+        if torch.cuda.is_available():
+            model_device = "cuda"
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            model_device = "xpu"
+        else:
+            model_device = "cpu"
+        dispatch_model(model, model_device)

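If the fallback is reintroduced, one option would be a small helper along these lines (hypothetical helper, not part of the PR or compressed-tensors):

import torch

def infer_dispatch_device() -> str:
    # prefer CUDA, then Intel XPU, then fall back to CPU
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"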

LifecycleCallbacks.calibration_epoch_start()

# TODO: remove this to enable quantization aware calibration
12 changes: 3 additions & 9 deletions src/llmcompressor/utils/transformers.py
@@ -1,5 +1,4 @@
import torch
from compressed_tensors import has_offloaded_params, register_offload_parameter
from loguru import logger
from torch.nn import Parameter
from transformers import PreTrainedModel
@@ -28,14 +27,9 @@ def untie_word_embeddings(model: PreTrainedModel):

# clone data to untie
for module in (input_embed, output_embed):
if not has_offloaded_params(module):
data = module.weight.data
else:
data = module._hf_hook.weights_map["weight"]

requires_grad = module.weight.requires_grad
untied_param = Parameter(data.clone(), requires_grad=requires_grad)
register_offload_parameter(module, "weight", untied_param)
weight = module.weight
param = Parameter(weight.data.clone(), requires_grad=weight.requires_grad)
module.register_parameter("weight", param)

# modify model config
if hasattr(model.config, "tie_word_embeddings"):
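To illustrate what the simplified untie_word_embeddings does, a hedged toy sketch (tiny hypothetical modules, no offloading) showing that cloning and re-registering the weight breaks the shared storage between tied embeddings:

import torch
from torch.nn import Parameter

input_embed = torch.nn.Embedding(10, 8)
output_embed = torch.nn.Linear(8, 10, bias=False)
output_embed.weight = input_embed.weight   # tied: both modules share one Parameter

for module in (input_embed, output_embed):
    weight = module.weight
    param = Parameter(weight.data.clone(), requires_grad=weight.requires_grad)
    module.register_parameter("weight", param)

# after cloning, the two weights no longer share storage
assert input_embed.weight.data_ptr() != output_embed.weight.data_ptr()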
16 changes: 7 additions & 9 deletions tests/llmcompressor/utils/test_helpers.py
@@ -1,11 +1,11 @@
import pytest
import torch
from compressed_tensors.offload import dispatch_model
from transformers import (
AutoModelForCausalLM,
MllamaForConditionalGeneration,
)

from llmcompressor.pipelines.sequential.helpers import dispatch_for_sequential
from llmcompressor.utils import (
ALL_TOKEN,
DisableQuantization,
@@ -17,7 +17,7 @@
interpolate,
validate_str_iterable,
)
from llmcompressor.utils.dev import dispatch_for_generation, skip_weights_download
from llmcompressor.utils.dev import skip_weights_download
from tests.testing_utils import requires_gpu


@@ -153,14 +153,12 @@ def test_disable_cache(model_cls, model_stub):


@requires_gpu
@pytest.mark.parametrize("offload", ["sequential", "basic", "none"])
def test_disable_lm_head(offload):
@pytest.mark.parametrize("dispatch", (True, False))
def test_disable_lm_head(dispatch):
model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2")
if offload == "sequential":
dispatch_for_sequential(model)
if offload == "basic":
dispatch_for_generation(model)
if offload == "none":
if dispatch:
dispatch_model(model, "cuda")
else:
model = model.to("cuda")

lm_input_device = None