Merged
Changes from 9 commits
98 changes: 65 additions & 33 deletions src/transformers/modeling_utils.py
@@ -67,6 +67,7 @@
translate_to_torch_parallel_style,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
from .safetensors_conversion import auto_conversion
from .utils import (
ACCELERATE_MIN_VERSION,
@@ -536,11 +537,11 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
def load_state_dict(
checkpoint_file: Union[str, os.PathLike],
is_quantized: bool = False,
map_location: Optional[Union[str, torch.device]] = "meta",
map_location: Optional[Union[str, torch.device]] = "cpu",
weights_only: bool = True,
):
"""
Reads a `safetensor` or a `.bin` checkpoint file into `meta` if requested.
Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
"""
if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
with safe_open(checkpoint_file, framework="pt") as f:
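For readers skimming the new default: a minimal sketch of the difference, using a made-up checkpoint file. Loading with map_location="cpu" materializes real storage, while a "meta" copy keeps only shapes and dtypes.

```python
import torch

# Made-up tiny checkpoint, just to show the map_location behaviour.
state = {"linear.weight": torch.randn(4, 4)}
torch.save(state, "tiny_checkpoint.bin")

# New default in load_state_dict: materialize tensors on CPU.
cpu_sd = torch.load("tiny_checkpoint.bin", map_location="cpu", weights_only=True)

# Callers that only need metadata can still move tensors to "meta",
# which keeps shape/dtype but no storage.
meta_sd = {k: v.to("meta") for k, v in cpu_sd.items()}
print(cpu_sd["linear.weight"].device, meta_sd["linear.weight"].is_meta)  # cpu True
```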
@@ -771,6 +772,7 @@ def _load_state_dict_into_meta_model(
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
device_mesh=None,
shard_file=None,
weights_only=True,
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -800,7 +802,15 @@ def _load_state_dict_into_meta_model(
if shard_file.endswith(".safetensors"):
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
else:
bin_state_dict = load_state_dict(shard_file, map_location="cpu")
map_location = "cpu"
if (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
bin_state_dict = load_state_dict(shard_file, map_location=map_location, weights_only=weights_only)
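A hedged sketch of the device-selection expression just above: torchao int4_weight_only / autoquant checkpoints have to be loaded directly onto the accelerator, so the first device-map entry that is neither "cpu" nor "disk" is used as map_location. The device_map below is made up.

```python
import torch

# Hypothetical device_map mixing an accelerator, CPU offload and disk offload.
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "lm_head": "cpu", "model.norm": "disk"}

# Mirror of the list comprehension used above: skip "cpu" and "disk" entries.
accelerator_devices = [d for d in device_map.values() if d not in ["cpu", "disk"]]
map_location = torch.device(accelerator_devices[0]) if accelerator_devices else "cpu"
print(map_location)  # cuda:0, since torch.device(0) resolves to accelerator index 0
```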

error_msgs = []

@@ -822,23 +832,36 @@ def _load_state_dict_into_meta_model(
if shard_file.endswith(".safetensors")
else bin_state_dict[serialized_param_name]
)

# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29

old_param = model
splits = fixed_param_name.split(".")
for split in splits:
# We shouldn't hit the default value except for quant methods like hqq that modify expected_keys.
old_param = getattr(old_param, split, None)
if old_param is None:
break

if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None

# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
param_casting_dtype = None
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn

if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
if (
keep_in_fp32_modules is not None
and keep_in_fp32_modules.search(fixed_param_name)
and dtype == torch.float16
):
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(fixed_param_name):
param_casting_dtype = torch.float32
else:
elif dtype is not None:
param_casting_dtype = dtype
elif old_param is not None:
param_casting_dtype = old_param.dtype

if device_mesh is not None: # In this case, the param is already on the correct device!
module_to_tp, param_type = find_submodule_and_param_name(model, fixed_param_name)
module_to_tp, param_type = get_module_from_name(model, fixed_param_name)
current_module_plan = None
full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
if plan := re.search(full_tp_plan_, fixed_param_name):
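The dtype handling above changes from "cast only when a dtype was requested" to a three-step fallback: keep-in-fp32 regex, then the requested dtype, then the dtype already present in the model. A standalone sketch of that order, with an illustrative helper (not the actual transformers code):

```python
import re
import torch

def resolve_casting_dtype(param_name, empty_param, dtype, keep_in_fp32_modules, old_param):
    # float8 params and non-floating buffers (int/uint/bool) are never cast.
    is_float8 = hasattr(torch, "float8_e4m3fn") and empty_param.dtype == torch.float8_e4m3fn
    if empty_param.dtype.is_floating_point and not is_float8:
        if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name):
            return torch.float32
        if dtype is not None:
            return dtype
        if old_param is not None:
            return old_param.dtype
    return None

keep_in_fp32 = re.compile("layer_norm|embed_tokens")  # made-up pattern
empty = torch.empty(0, dtype=torch.float32, device="meta")
print(resolve_casting_dtype("model.layer_norm.weight", empty, torch.float16, keep_in_fp32, None))  # torch.float32
print(resolve_casting_dtype("model.q_proj.weight", empty, torch.float16, keep_in_fp32, None))      # torch.float16
```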
@@ -859,8 +882,10 @@ def _load_state_dict_into_meta_model(
else:
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
shard = Shard(0)
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
if param_casting_dtype is not None:
param = param.to(param_casting_dtype)
if old_param.is_contiguous():
param = param.contiguous()
local_parameter = DTensor.from_local(
param,
device_mesh=device_mesh,
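The Shard(0) branch slices rows per rank before wrapping the result as a DTensor. A single-process illustration of that slicing arithmetic (mesh size, rank and tensor shape are made up):

```python
import torch

world_size, rank = 4, 1
param = torch.arange(8 * 3).reshape(8, 3)  # 8 rows to split across 4 ranks
row = param.shape[0]

# Same expression as above, with device_mesh.size() replaced by world_size.
local = param[rank * (row // world_size) : (rank + 1) * (row // world_size), :]
print(local.shape)  # torch.Size([2, 3]); rank 1 owns rows 2 and 3
```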
@@ -873,9 +898,18 @@ def _load_state_dict_into_meta_model(
output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output)
distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
else:
module_to_tp.load_state_dict({param_type: param[:]}, strict=False, assign=True)
param = param[:]
Review comment, fxmarty-amd (Contributor), Mar 10, 2025:

@SunMarc why is `param = param[:]` needed?

edit - ok, this is for safetensors. Unfortunately safetensors `get_slice` does not play well with 0-dim tensors :( huggingface/safetensors#380
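To illustrate the point in the comment: with safetensors, `get_slice` returns a lazy handle, and indexing it with `[:]` is what materializes an actual tensor; the linked issue is about that indexing failing for 0-dim tensors. File and tensor names below are made up.

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

save_file({"w": torch.randn(2, 3)}, "demo.safetensors")

with safe_open("demo.safetensors", framework="pt") as f:
    lazy = f.get_slice("w")  # lazy slice handle, not a tensor yet
    tensor = lazy[:]         # [:] materializes the full tensor
print(type(lazy).__name__, tensor.shape)
```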

if old_param is not None and old_param.is_contiguous():
param = param.contiguous()
module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)

else:
param = param[:]
if param_casting_dtype is not None:
param = param.to(param_casting_dtype)
if old_param is not None and old_param.is_contiguous():
param = param.contiguous()

if device_map is None:
param_device = "cpu"
else:
@@ -887,9 +921,9 @@

if param_device == "disk":
if not is_safetensors:
offload_index = offload_weight(param[:], fixed_param_name, offload_folder, offload_index)
offload_index = offload_weight(param, fixed_param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param[:], fixed_param_name, state_dict_folder, state_dict_index)
state_dict_index = offload_weight(param, fixed_param_name, state_dict_folder, state_dict_index)
elif (
not is_quantized
or (not hf_quantizer.requires_parameters_quantization)
@@ -906,23 +940,25 @@
):
if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta"
module, param_type = find_submodule_and_param_name(model, fixed_param_name)
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
param = param[:].to(param_casting_dtype)
module, param_type = get_module_from_name(model, fixed_param_name)
print(model)
print(fixed_param_name)
print(param)
print(module)
module.load_state_dict(
{param_type: param[:].to(param_device)},
{param_type: param.to(param_device)},
strict=False,
assign=True,
)
else:
hf_quantizer.create_quantized_param(
model, param[:], fixed_param_name, param_device, state_dict, unexpected_keys
model, param, fixed_param_name, param_device, state_dict, unexpected_keys
)
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
# and then cast it to CPU to avoid excessive memory usage on each GPU
# in comparison to the sharded model across GPUs.
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
module, param_type = find_submodule_and_param_name(model, fixed_param_name)
module, param_type = get_module_from_name(model, fixed_param_name)
value = getattr(module, param_type)
param_to = "cpu"
if is_fsdp_enabled() and not is_local_dist_rank_0():
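For context on the per-parameter `load_state_dict(..., strict=False, assign=True)` calls above: with `assign=True` the loaded tensor replaces the meta-allocated parameter instead of being copied into it. A minimal sketch, assuming a PyTorch recent enough to support the `assign` argument and the `torch.device` context manager:

```python
import torch
import torch.nn as nn

with torch.device("meta"):
    lin = nn.Linear(4, 4)  # weight and bias start on the meta device

# Load only the weight; assign=True swaps in the real tensor.
lin.load_state_dict({"weight": torch.randn(4, 4)}, strict=False, assign=True)
print(lin.weight.device, lin.bias.is_meta)  # cpu True (bias deliberately left unloaded)
```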
@@ -4203,7 +4239,9 @@ def from_pretrained(
elif not is_sharded:
torch_dtype = get_state_dict_dtype(state_dict)
else:
one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only)
one_state_dict = load_state_dict(
resolved_archive_file[0], map_location="meta", weights_only=weights_only
)
torch_dtype = get_state_dict_dtype(one_state_dict)
del one_state_dict # free CPU memory
logger.info(
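This call site only needs dtypes, so loading the first shard on "meta" avoids allocating real storage. A simplified stand-in for the dtype probe (the real get_state_dict_dtype is more involved; this sketch just returns the first floating dtype it finds):

```python
import torch

# Made-up state dict as it would look when loaded on "meta": shapes and dtypes, no data.
meta_state_dict = {
    "embed.weight": torch.empty(10, 8, dtype=torch.bfloat16, device="meta"),
    "mask": torch.empty(10, dtype=torch.bool, device="meta"),
}

def first_floating_dtype(state_dict):
    for tensor in state_dict.values():
        if tensor.is_floating_point():
            return tensor.dtype
    return next(iter(state_dict.values())).dtype

print(first_floating_dtype(meta_state_dict))  # torch.bfloat16
```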
@@ -4848,7 +4886,7 @@ def _load_pretrained_model(
else:
folder = None

model.expected_keys = expected_keys
model_to_load.expected_keys = expected_keys
Review comment, PR author (Member):

since we are calling `_fix_state_dict_keys_on_load` on `model_to_load`

if device_map is not None:
expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
if hf_quantizer is None:
@@ -4907,6 +4945,7 @@ def _load_pretrained_model(
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
resolved_archive_file=resolved_archive_file,
weights_only=weights_only,
)
else:
# We need to read the state dict as it is meta otherwise
@@ -4957,16 +4996,8 @@
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
if shard_file in disk_only_shard_files:
continue
map_location = None
if (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only
)

# Mismatched keys contain tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
@@ -5006,6 +5037,7 @@
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
shard_file=shard_file,
weights_only=weights_only,
)
error_msgs += new_error_msgs
else: