295 changes: 157 additions & 138 deletions src/transformers/modeling_utils.py
@@ -41,7 +41,6 @@
from huggingface_hub import split_torch_state_dict_into_shards
from packaging import version
from torch import Tensor, nn
from torch.distributed.tensor import DTensor, Shard
from torch.distributions import constraints
from torch.nn import CrossEntropyLoss, Identity
from torch.utils.checkpoint import checkpoint
@@ -67,7 +66,6 @@
translate_to_torch_parallel_style,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
from .safetensors_conversion import auto_conversion
from .utils import (
ACCELERATE_MIN_VERSION,
@@ -181,6 +179,9 @@ def is_local_dist_rank_0():
if is_peft_available():
from .utils import find_adapter_config_file

if is_torch_greater_or_equal("2.5"):
from torch.distributed.tensor import DTensor, Shard
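
The `DTensor`/`Shard` import is now gated on `is_torch_greater_or_equal("2.5")` so that older torch builds can still import `modeling_utils`. A minimal standalone sketch of the same gate, using `packaging.version` directly; the `None` fallback is an assumption for illustration (the diff simply skips the import):

```python
# Sketch of a version-gated import, assuming torch may predate 2.5.
from packaging import version

import torch

if version.parse(torch.__version__) >= version.parse("2.5"):
    from torch.distributed.tensor import DTensor, Shard
else:
    DTensor, Shard = None, None  # tensor-parallel loading is unavailable on older torch
```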

SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")

TORCH_INIT_FUNCTIONS = {
@@ -702,7 +703,7 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
return shared_tensors, identical


def find_submodule_and_param_name(model, long_key, start_prefix):
def find_submodule_and_param_name(model, long_key, start_prefix=""):
"""
A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed
from the start of the key
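
`start_prefix` now defaults to `""`, so the meta-loading path below can call the helper with just a model and a dotted key. The hunk does not show the function body; the following is only an illustrative sketch of what such a helper typically does (walk down child modules until only the param/buffer name remains), not the actual implementation:

```python
# Illustrative sketch, assuming `long_key` looks like "model.layers.0.mlp.weight".
def _find_submodule_and_param_name_sketch(model, long_key, start_prefix=""):
    if start_prefix and long_key.startswith(start_prefix):
        long_key = long_key[len(start_prefix):]
    split_key = long_key.split(".")
    submodule = model
    # descend through child modules until only the param/buffer name is left
    while len(split_key) > 1:
        if hasattr(submodule, split_key[0]):
            submodule = getattr(submodule, split_key[0])
            del split_key[0]
        else:
            submodule = None
            break
    return submodule, split_key[0]
```
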
@@ -767,7 +768,6 @@ def _load_state_dict_into_meta_model(
is_safetensors=False,
keep_in_fp32_modules=None,
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys
device_mesh=None,
shard_file=None,
):
@@ -786,145 +786,152 @@
if device_map is not None and device_map.get("", None) is not None:
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]

with safe_open(shard_file, framework="pt", device=tensor_device) as file_pointer:
error_msgs = []
# we need this later to initialize tensor parallelism
if device_mesh is not None:
full_tp_plan = model.config.base_model_tp_plan
for submodule in model.modules():
full_tp_plan.update(getattr(submodule, "_tp_plan", {}))

is_quantized = hf_quantizer is not None
file_pointer = None
bin_state_dict = None
if shard_file.endswith(".safetensors"):
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
else:
bin_state_dict = load_state_dict(shard_file, map_location="cpu")

is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
error_msgs = []

# we need this later to initialize tensor parallelism
if device_mesh is not None:
full_tp_plan = model.config.base_model_tp_plan
for submodule in model.modules():
full_tp_plan.update(getattr(submodule, "_tp_plan", {}))

for serialized_param_name, empty_param in state_dict.items():
# param_name is the raw, serialized name
# new_param_name is the model's equivalent
module_name, _ = model.rename_key(serialized_param_name)
if module_name not in expected_keys:
continue
layer, param_type = module_name.rsplit(".", 1)

# param name needs to stay untouched as it's in the file
param = file_pointer.get_slice(serialized_param_name)
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
param_casting_dtype = None
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
if (
keep_in_fp32_modules is not None
and keep_in_fp32_modules.search(module_name)
and dtype == torch.float16
):
param_casting_dtype = torch.float32
else:
param_casting_dtype = dtype
is_quantized = hf_quantizer is not None

if device_mesh is not None: # In this case, the param is already on the correct device!
try:
module_to_tp: torch.nn.Module = model.get_submodule(layer)
except Exception:
raise ValueError(
"The config tp plan is wrong because the layer is not a liner layer, nor an embedding"
)
current_module_plan = None
full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
if plan := re.search(full_tp_plan_, module_name):
match = re.sub("[0-9]+", "*", plan[0])
current_module_plan = full_tp_plan[match]

if current_module_plan is not None:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
rank = tensor_device
row, col = empty_param.shape
if "rowwise" == current_module_plan:
param = param[:, rank * (col // device_mesh.size()) : (rank + 1) * (col // device_mesh.size())]
shard = Shard(1)
tp_layer.desired_input_layouts = (Shard(-1),)
elif "colwise" == current_module_plan:
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
shard = Shard(0)
else:
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
shard = Shard(0)
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
param = param.to(param_casting_dtype)
local_parameter = DTensor.from_local(
param,
device_mesh=device_mesh,
placements=[shard] * device_mesh.ndim,
)
if isinstance(module_to_tp.weight, nn.Parameter):
local_parameter = torch.nn.Parameter(local_parameter)
module_to_tp.weight = local_parameter
input_fn = partial(
tp_layer._prepare_input_fn, tp_layer.input_layouts, tp_layer.desired_input_layouts
)
output_fn = partial(
tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output
)
distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
else:
module_to_tp.load_state_dict({param_type: param[:]}, False, True)
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")

for serialized_param_name, empty_param in state_dict.items():
# serialized_param_name is the raw, serialized name
# fixed_param_name is the model's equivalent
fixed_param_name, _ = model.rename_key(serialized_param_name)

if fixed_param_name not in expected_keys:
continue

# we need to use serialized_param_name as the file pointer still holds the raw, serialized keys
param = (
file_pointer.get_slice(serialized_param_name)
if shard_file.endswith(".safetensors")
else bin_state_dict[serialized_param_name]
)
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
param_casting_dtype = None
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn

if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
if (
keep_in_fp32_modules is not None
and keep_in_fp32_modules.search(fixed_param_name)
and dtype == torch.float16
):
param_casting_dtype = torch.float32
else:
if device_map is None:
param_device = "cpu"
param_casting_dtype = dtype
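
Floating-point parameters are cast to the requested `dtype` unless they are float8_e4m3fn or matched by the `keep_in_fp32_modules` pattern under float16 (in which case they are pinned to float32); integer/uint/bool buffers are never cast. A compact, standalone version of that decision, with an invented pattern and parameter name for illustration:

```python
import re
import torch

def pick_casting_dtype(param_name, param_dtype, requested_dtype, keep_in_fp32_modules=None):
    # Hypothetical standalone restatement of the decision made in the loop above.
    is_float8 = hasattr(torch, "float8_e4m3fn") and param_dtype == torch.float8_e4m3fn
    if requested_dtype is None or not param_dtype.is_floating_point or is_float8:
        return None  # leave int/uint/bool buffers and float8 weights untouched
    if (
        keep_in_fp32_modules is not None
        and keep_in_fp32_modules.search(param_name)
        and requested_dtype == torch.float16
    ):
        return torch.float32
    return requested_dtype

keep_in_fp32 = re.compile(r"layer_norm|rmsnorm", re.IGNORECASE)
print(pick_casting_dtype("model.layer_norm.weight", torch.float32, torch.float16, keep_in_fp32))  # torch.float32
```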

if device_mesh is not None: # In this case, the param is already on the correct device!
module_to_tp, param_type = find_submodule_and_param_name(model, fixed_param_name)
current_module_plan = None
full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
if plan := re.search(full_tp_plan_, fixed_param_name):
match = re.sub("[0-9]+", "*", plan[0])
current_module_plan = full_tp_plan[match]
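
The plan lookup turns wildcard keys such as `"layers.*.self_attn.q_proj"` into a regex, searches the concrete parameter name, then maps the match back to the wildcard key to fetch its parallel style. A small self-contained example of that round trip; the plan keys and parameter name below are invented:

```python
import re

# Hypothetical TP plan with wildcard layer indices, as in base_model_tp_plan.
full_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}

fixed_param_name = "model.layers.11.mlp.down_proj.weight"

# "*" becomes "[0-9]+" so the pattern matches a concrete layer index.
pattern = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
if plan := re.search(pattern, fixed_param_name):
    # map the concrete match ("layers.11.mlp.down_proj") back to its wildcard key
    wildcard_key = re.sub("[0-9]+", "*", plan[0])
    print(wildcard_key, "->", full_tp_plan[wildcard_key])  # layers.*.mlp.down_proj -> rowwise
```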

if current_module_plan is not None:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
rank = tensor_device
row, col = empty_param.shape
if "rowwise" == current_module_plan:
param = param[:, rank * (col // device_mesh.size()) : (rank + 1) * (col // device_mesh.size())]
shard = Shard(1)
tp_layer.desired_input_layouts = (Shard(-1),)
elif "colwise" == current_module_plan:
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
shard = Shard(0)
else:
module_name = module_name.rsplit(".", 1)[0]
device_map_regex = "|".join(device_map.keys())
module_layer = re.search(device_map_regex, module_name)
if module_name == "" or device_map_regex is None:
raise ValueError(
f"`device_map` is used, but {module_name} doesn't have any device set. {device_map}"
)
else:
param_device = device_map[module_layer.group()]

if param_device == "disk" and not is_safetensors:
offload_index = offload_weight(param[:], module_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param[:], module_name, state_dict_folder, state_dict_index)
elif (
not is_quantized
or (not hf_quantizer.requires_parameters_quantization)
or (
not hf_quantizer.check_quantized_param(
model, param, module_name, state_dict, param_device=param_device, device_map=device_map
)
)
):
if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta"
module = model.get_submodule(layer)
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
param = param[:].to(param_casting_dtype)
module.load_state_dict(
{param_type: param[:].to(param_device)},
False,
True,
)
else:
hf_quantizer.create_quantized_param(
model, param[:], module_name, param_device, state_dict, unexpected_keys
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
shard = Shard(0)
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
param = param.to(param_casting_dtype)
local_parameter = DTensor.from_local(
param,
device_mesh=device_mesh,
placements=[shard] * device_mesh.ndim,
)
if isinstance(module_to_tp.weight, nn.Parameter):
local_parameter = torch.nn.Parameter(local_parameter)
module_to_tp.weight = local_parameter
input_fn = partial(tp_layer._prepare_input_fn, tp_layer.input_layouts, tp_layer.desired_input_layouts)
output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output)
distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
else:
module_to_tp.load_state_dict({param_type: param[:]}, strict=False, assign=True)

else:
if device_map is None:
param_device = "cpu"
else:
module_name = fixed_param_name
while len(module_name) > 0 and module_name not in device_map:
module_name = ".".join(module_name.split(".")[:-1])
if module_name == "" and "" not in device_map:
raise ValueError(f"{fixed_param_name} doesn't have any device set.")
param_device = device_map[module_name]
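
Without a device mesh, the new code resolves a parameter's target device by walking up its dotted name until some prefix is a key of `device_map` (an empty-string key would act as a catch-all). A small example of that prefix walk; the `device_map` contents are invented:

```python
# Hypothetical device_map; a "" key would act as a catch-all if present.
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": "cpu", "lm_head": 0}

fixed_param_name = "model.layers.1.mlp.up_proj.weight"

module_name = fixed_param_name
# strip trailing components until a prefix of the name is a device_map key
while len(module_name) > 0 and module_name not in device_map:
    module_name = ".".join(module_name.split(".")[:-1])
if module_name == "" and "" not in device_map:
    raise ValueError(f"{fixed_param_name} doesn't have any device set.")
param_device = device_map[module_name]  # -> "cpu"
```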

if param_device == "disk":
if not is_safetensors:
offload_index = offload_weight(param[:], fixed_param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param[:], fixed_param_name, state_dict_folder, state_dict_index)
elif (
not is_quantized
or (not hf_quantizer.requires_parameters_quantization)
or (
not hf_quantizer.check_quantized_param(
model,
param,
fixed_param_name,
state_dict,
param_device=param_device,
device_map=device_map,
)
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
# and then cast it to CPU to avoid excessive memory usage on each GPU
# in comparison to the sharded model across GPUs.
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
module, tensor_name = get_module_from_name(model, module_name)
value = getattr(module, tensor_name)
param_to = "cpu"
if is_fsdp_enabled() and not is_local_dist_rank_0():
param_to = "meta"
val_kwargs = {}
if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
val_kwargs["requires_grad"] = False
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
setattr(module, tensor_name, value)
)
):
if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta"
module, param_type = find_submodule_and_param_name(model, fixed_param_name)
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
param = param[:].to(param_casting_dtype)
module.load_state_dict(
{param_type: param[:].to(param_device)},
strict=False,
assign=True,
)
else:
hf_quantizer.create_quantized_param(
model, param[:], fixed_param_name, param_device, state_dict, unexpected_keys
)
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
# and then cast it to CPU to avoid excessive memory usage on each GPU
# in comparison to the sharded model across GPUs.
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
module, param_type = find_submodule_and_param_name(model, fixed_param_name)
value = getattr(module, param_type)
param_to = "cpu"
if is_fsdp_enabled() and not is_local_dist_rank_0():
param_to = "meta"
val_kwargs = {}
if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
val_kwargs["requires_grad"] = False
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
setattr(module, param_type, value)
if file_pointer is not None:
file_pointer.__exit__(None, None, None)

return error_msgs, offload_index, state_dict_index
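
In the tensor-parallel branch above, a "colwise" plan keeps a contiguous block of rows of the weight on each rank (hence `Shard(0)`), while "rowwise" keeps a block of columns (`Shard(1)`), and the local slice is then wrapped in a `DTensor`. A tiny numeric illustration of the slicing arithmetic only, with invented shapes and world size:

```python
import torch

world_size, rank = 4, 1  # hypothetical 4-way tensor parallelism, local rank 1
row, col = 8, 12         # full (out_features, in_features) of the weight
param = torch.arange(row * col, dtype=torch.float32).reshape(row, col)

# "colwise" keeps a block of output rows on each rank: Shard(0)
colwise_slice = param[rank * (row // world_size) : (rank + 1) * (row // world_size), :]
assert colwise_slice.shape == (2, 12)

# "rowwise" keeps a block of input columns on each rank: Shard(1)
rowwise_slice = param[:, rank * (col // world_size) : (rank + 1) * (col // world_size)]
assert rowwise_slice.shape == (8, 3)
```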

Expand Down Expand Up @@ -4966,7 +4973,7 @@ def _load_pretrained_model(
ignore_mismatched_sizes,
prefix,
)
if low_cpu_mem_usage and shard_file.endswith(".safetensors"):
if low_cpu_mem_usage:
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
for key, param in model_to_load.state_dict().items():
if param.device == torch.device("meta"):
@@ -5840,18 +5847,30 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
accelerator_device_map = {
param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"]
}
if not len(accelerator_device_map):
return

parameter_count = defaultdict(lambda: 0)
allocation_factor = 1
if torch.distributed.is_initialized() or len(set(accelerator_device_map.values())) >= 2:
allocation_factor = 2

for param_name, device in accelerator_device_map.items():
try:
param = model.get_parameter(param_name)
except AttributeError:
param = model.get_buffer(param_name)
parameter_count[device] += int(math.prod(param.shape) * 2)
parameter_count[device] += int(math.prod(param.shape) * allocation_factor)

dtype = dtype if dtype is not None else torch.float32
# calling max_memory will create a tensor thus creating a bit of overhead (aten::empty_strided)
# max_memory = get_max_memory()

# This will kick off the caching allocator to avoid having to Malloc afterwards
for device, param_count in parameter_count.items():
_ = torch.empty(int(param_count), dtype=dtype, device=device, requires_grad=False)
# allocate only if we have enough memory
# if max_memory[device.index] > param_count * dtype_byte_size(dtype):
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
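
This hunk makes the warmup bail out early when nothing lands on an accelerator, and doubles the per-device warmup allocation when `torch.distributed` is initialized or weights span more than one device, before priming the caching allocator with one large `torch.empty` per device. A condensed, standalone sketch of the counting; parameter names, shapes, and devices are invented, and it assumes a CUDA device and fp16:

```python
import math
from collections import defaultdict

import torch

# Hypothetical per-parameter device assignment and shapes.
expanded_device_map = {"weight_a": "cuda:0", "weight_b": "cuda:0", "weight_c": "cpu"}
shapes = {"weight_a": (4096, 4096), "weight_b": (4096, 11008), "weight_c": (4096,)}

accelerator_device_map = {
    p: torch.device(d) for p, d in expanded_device_map.items() if d not in ["cpu", "disk"]
}
# the diff also doubles the factor when torch.distributed.is_initialized()
allocation_factor = 2 if len(set(accelerator_device_map.values())) >= 2 else 1

parameter_count = defaultdict(int)
for name, device in accelerator_device_map.items():
    parameter_count[device] += int(math.prod(shapes[name]) * allocation_factor)

# One big empty allocation per device primes the caching allocator,
# so later per-tensor copies reuse cached blocks instead of calling malloc.
for device, count in parameter_count.items():
    _ = torch.empty(count, dtype=torch.float16, device=device, requires_grad=False)
```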


def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):