Merged · Changes from 3 commits
**src/accelerate/utils/memory.py** (5 changes: 2 additions & 3 deletions)
```diff
@@ -138,11 +138,10 @@ def find_executable_batch_size(


     >>> @find_executable_batch_size(starting_batch_size=128)
-    ... def train(batch_size, model, optimizer):
```
> **Contributor Author:** This change is unrelated. It seems to come from ruff, and I'm not sure why; I'm using the ruff version from setup.py.
```diff
-    ...     ...
+    ... def train(batch_size, model, optimizer): ...


-    >>> train(model, optimizer)
+    ... train(model, optimizer)
     ```
     """
     if function is None:
```
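For readers skimming the diff, here is a hedged, self-contained sketch of how the decorator in this docstring is used; the toy model, optimizer, and CUDA device are illustrative assumptions, not part of the PR:

```python
import torch

from accelerate.utils import find_executable_batch_size

# Toy setup for illustration only.
model = torch.nn.Linear(512, 512).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)


@find_executable_batch_size(starting_batch_size=128)
def train(batch_size, model, optimizer):
    # On a CUDA out-of-memory error, the decorator frees the cache and
    # retries this function with batch_size halved (128 -> 64 -> 32 -> ...).
    batch = torch.randn(batch_size, 512, device="cuda")
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


# batch_size is injected by the decorator, so it is not passed explicitly.
train(model, optimizer)
```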
**src/accelerate/utils/modeling.py** (36 changes: 23 additions & 13 deletions)
```diff
@@ -222,6 +222,8 @@ def set_module_tensor_to_device(
     dtype: Optional[Union[str, torch.dtype]] = None,
     fp16_statistics: Optional[torch.HalfTensor] = None,
     tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
+    non_blocking: bool = False,
+    _empty_cache: bool = True,
```
> **Member (@SunMarc):** Can you update the docstring as well? Also, maybe `empty_cache` or `clear_cache` instead?
>
> **Contributor Author:** @SunMarc Thanks, updated! Could you take a look again?
```diff
 ):
     """
     A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
```
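The two new arguments are threaded through the `.to(...)` calls in the hunks below. As a hedged sketch of what `non_blocking=True` enables for a caller (the toy module and pinned-memory state dict are assumptions for illustration, not code from this PR):

```python
import torch

from accelerate.utils import set_module_tensor_to_device

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))
# Pin the source tensors so host-to-device copies can actually run async.
state_dict = {k: v.pin_memory() for k, v in model.state_dict().items()}

for name, tensor in state_dict.items():
    # Each call can return before its copy finishes; the copies queue on the
    # current CUDA stream instead of blocking the host one at a time.
    set_module_tensor_to_device(model, name, "cuda:0", value=tensor, non_blocking=True)

torch.cuda.synchronize()  # wait for all queued copies before using the model
```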
```diff
@@ -295,9 +297,9 @@ def set_module_tensor_to_device(

         if dtype is None:
             # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
-            value = value.to(old_value.dtype)
+            value = value.to(old_value.dtype, non_blocking=non_blocking)
         elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
-            value = value.to(dtype)
+            value = value.to(dtype, non_blocking=non_blocking)

     device_quantization = None
     with torch.no_grad():
```
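One caveat worth keeping in mind here (general PyTorch behavior, not something this PR changes): `non_blocking=True` only makes a host-to-device copy asynchronous when the source tensor sits in pinned (page-locked) memory; from ordinary pageable memory the copy still blocks the host. A minimal sketch:

```python
import torch

pageable = torch.randn(1024, 1024)  # ordinary memory: copy blocks the host
pinned = pageable.pin_memory()      # page-locked copy of the same data

d1 = pageable.to("cuda", non_blocking=True)  # effectively synchronous
d2 = pinned.to("cuda", non_blocking=True)    # returns immediately, can overlap compute

torch.cuda.synchronize()  # settle outstanding copies before reading d1/d2
```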
```diff
@@ -326,15 +328,15 @@ def set_module_tensor_to_device(
         if "xpu" in str(device) and not is_xpu_available():
             raise ValueError(f'{device} is not available, you should use device="cpu" instead')
         if value is None:
-            new_value = old_value.to(device)
+            new_value = old_value.to(device, non_blocking=non_blocking)
             if dtype is not None and device in ["meta", torch.device("meta")]:
                 if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
-                    new_value = new_value.to(dtype)
+                    new_value = new_value.to(dtype, non_blocking=non_blocking)

                 if not is_buffer:
                     module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad)
         elif isinstance(value, torch.Tensor):
-            new_value = value.to(device)
+            new_value = value.to(device, non_blocking=non_blocking)
         else:
             new_value = torch.tensor(value, device=device)
         if device_quantization is not None:
```
```diff
@@ -347,24 +349,30 @@ def set_module_tensor_to_device(
             if param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]:
                 if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32:
                     # downcast to fp16 if any - needed for 8bit serialization
-                    new_value = new_value.to(torch.float16)
+                    new_value = new_value.to(torch.float16, non_blocking=non_blocking)
                 # quantize module that are going to stay on the cpu so that we offload quantized weights
                 if device == "cpu" and param_cls.__name__ == "Int8Params":
                     new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(0).to("cpu")
                     new_value.CB = new_value.CB.to("cpu")
                     new_value.SCB = new_value.SCB.to("cpu")
                 else:
-                    new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device)
+                    new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(
+                        device, non_blocking=non_blocking
+                    )
             elif param_cls.__name__ in ["QTensor", "QBitsTensor"]:
-                new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(device)
+                new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(
+                    device, non_blocking=non_blocking
+                )
             elif param_cls.__name__ in ["AffineQuantizedTensor"]:
-                new_value = new_value.to(device)
+                new_value = new_value.to(device, non_blocking=non_blocking)
             else:
-                new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
+                new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(
+                    device, non_blocking=non_blocking
+                )

             module._parameters[tensor_name] = new_value
             if fp16_statistics is not None:
-                module._parameters[tensor_name].SCB = fp16_statistics.to(device)
+                module._parameters[tensor_name].SCB = fp16_statistics.to(device, non_blocking=non_blocking)
                 del fp16_statistics
             # as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight
             if (
```
```diff
@@ -390,9 +398,11 @@ def set_module_tensor_to_device(
                 device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
                 if not getattr(module.weight, "quant_state", None) and device_index is not None:
                     module.weight = module.weight.cuda(device_index)
+
             # clean pre and post forward hook
-            if device != "cpu":
-                clear_device_cache()
+            if _empty_cache:
+                if device != "cpu":
+                    clear_device_cache()

     # When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in
     # order to avoid duplicating memory, see above.
```
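The underscore-prefixed `_empty_cache` reads as an internal knob: callers that place many tensors in a loop can skip the per-call flush and clear the cache once at the end. A hedged sketch of that pattern, assuming `clear_device_cache` is importable from `accelerate.utils.memory` as `modeling.py` itself does (the loop is illustrative, and `_empty_cache` is private, so it may change without notice):

```python
import torch

from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.memory import clear_device_cache

model = torch.nn.Linear(8, 8)

for name, tensor in model.state_dict().items():
    # Skip the per-call device-cache flush while placing each tensor...
    set_module_tensor_to_device(model, name, "cuda:0", value=tensor, _empty_cache=False)

# ...then flush once after everything is on the device.
clear_device_cache()
```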