-
Notifications
You must be signed in to change notification settings - Fork 346
[TorchOffloader] Code Cleanup #2147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,11 +1,7 @@ | ||||||||||||
| from typing import Iterable | ||||||||||||
|
|
||||||||||||
| import torch | ||||||||||||
| from compressed_tensors import ( | ||||||||||||
| align_module_device, | ||||||||||||
| get_execution_device, | ||||||||||||
| update_offload_parameter, | ||||||||||||
| ) | ||||||||||||
| from compressed_tensors.offload import update_offload_parameter | ||||||||||||
|
|
||||||||||||
| __all__ = ["center_embeddings", "fuse_norm_linears"] | ||||||||||||
|
|
||||||||||||
|
|
@@ -22,13 +18,9 @@ def center_embeddings(embedding: torch.nn.Module): | |||||||||||
| if not hasattr(embedding, "weight"): | ||||||||||||
| raise ValueError(f"Cannot fuse norm of type {type(embedding)}") | ||||||||||||
|
|
||||||||||||
| with align_module_device(embedding): | ||||||||||||
| weight_dtype = embedding.weight.dtype | ||||||||||||
| weight = embedding.weight.to(PRECISION) | ||||||||||||
| new_weight = weight - weight.mean(dim=-1, keepdim=True) | ||||||||||||
| new_weight = new_weight.to(weight_dtype) | ||||||||||||
|
|
||||||||||||
| update_offload_parameter(embedding, "weight", new_weight) | ||||||||||||
| weight = embedding.weight.to(PRECISION) | ||||||||||||
| weight = weight - weight.mean(dim=-1, keepdim=True) | ||||||||||||
| update_offload_parameter(embedding, "weight", weight) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]): | ||||||||||||
|
|
@@ -46,15 +38,7 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]) | |||||||||||
|
|
||||||||||||
| for linear in linears: | ||||||||||||
| # NOTE: spinquant does this op in float64 | ||||||||||||
| exec_device = get_execution_device(norm) | ||||||||||||
| with align_module_device(norm, exec_device), align_module_device( | ||||||||||||
| linear, exec_device | ||||||||||||
| ): | ||||||||||||
| weight_dtype = linear.weight.dtype | ||||||||||||
| new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION) | ||||||||||||
| new_weight = new_weight.to(weight_dtype) | ||||||||||||
|
|
||||||||||||
| update_offload_parameter(linear, "weight", new_weight) | ||||||||||||
|
|
||||||||||||
| new_norm_weight = torch.ones_like(norm.weight, device="cpu") | ||||||||||||
| update_offload_parameter(norm, "weight", new_norm_weight) | ||||||||||||
| linear_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION) | ||||||||||||
| update_offload_parameter(linear, "weight", linear_weight) | ||||||||||||
|
Comment on lines
+41
to
+42
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar to the
Suggested change
|
||||||||||||
|
|
||||||||||||
| update_offload_parameter(norm, "weight", torch.ones_like(norm.weight)) | ||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,7 +2,8 @@ | |||||||||||||||||||
| from typing import TYPE_CHECKING | ||||||||||||||||||||
|
|
||||||||||||||||||||
| import torch | ||||||||||||||||||||
| from compressed_tensors.utils import disable_offloading, get_execution_device | ||||||||||||||||||||
| from compressed_tensors.offload import dispatch_model | ||||||||||||||||||||
| from compressed_tensors.utils import disable_offloading | ||||||||||||||||||||
| from torch.utils.data.dataloader import DataLoader | ||||||||||||||||||||
| from tqdm import tqdm | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -11,7 +12,6 @@ | |||||||||||||||||||
| from llmcompressor.pipelines.cache import IntermediatesCache | ||||||||||||||||||||
| from llmcompressor.pipelines.registry import CalibrationPipeline | ||||||||||||||||||||
| from llmcompressor.pipelines.sequential.helpers import ( | ||||||||||||||||||||
| dispatch_for_sequential, | ||||||||||||||||||||
| get_sequential_targets, | ||||||||||||||||||||
| trace_subgraphs, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
@@ -59,10 +59,6 @@ def __call__( | |||||||||||||||||||
| """ | ||||||||||||||||||||
| session = active_session() | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # prepare model for sequential onloading | ||||||||||||||||||||
| dispatch_for_sequential(model) | ||||||||||||||||||||
| model_device = get_execution_device(model) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # prepare to trace subgraphs | ||||||||||||||||||||
| modifiers = session.lifecycle.recipe.modifiers | ||||||||||||||||||||
| sequential_targets = get_sequential_targets(modifiers, model, dataset_args) | ||||||||||||||||||||
|
|
@@ -73,6 +69,10 @@ def __call__( | |||||||||||||||||||
| subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore) | ||||||||||||||||||||
| num_subgraphs = len(subgraphs) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # prepare model for sequential onloading | ||||||||||||||||||||
| model_device = "cuda" if torch.cuda.is_available() else "cpu" | ||||||||||||||||||||
| dispatch_model(model, model_device) | ||||||||||||||||||||
|
Comment on lines
+73
to
+74
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The previous implementation in
Suggested change
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| LifecycleCallbacks.calibration_epoch_start() | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # TODO: remove this to enable quantization aware calibration | ||||||||||||||||||||
|
|
||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `weight` tensor is not cast back to its original data type after the centering operation. It remains as `PRECISION` (torch.float64), which could lead to increased memory usage and potential dtype mismatches in subsequent operations. It's recommended to restore the casting back to the original dtype before updating the parameter.