93 changes: 73 additions & 20 deletions src/transformers/modeling_utils.py
@@ -767,6 +767,7 @@ def _load_state_dict_into_meta_model(
keep_in_fp32_modules=None,
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys
device_mesh=None,
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -776,6 +777,8 @@ def _load_state_dict_into_meta_model(
`start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
`bert.pooler.dense.weight`

It also initializes tensor parallelism for each module if needed.

"""

# XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
@@ -789,6 +792,12 @@ def _load_state_dict_into_meta_model(

is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")

# we need this later to initialize tensor parallelism
if device_mesh is not None:
full_tp_plan = model.config.base_model_tp_plan
for submodule in model.modules():
full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
Comment on lines +798 to +799

Collaborator:
IMO we should only do this for PreTrainedModels, no?

Member Author:
I figured that maybe it would be a bit more future-proof to iterate over all modules (it's not costly) -- but it can be changed for sure!

Collaborator:
Let's only do PreTrained for now!
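A minimal sketch of the restriction suggested above; the helper name is hypothetical, and it assumes `PreTrainedModel` is importable (as it is in modeling_utils.py):

from transformers import PreTrainedModel

def collect_tp_plan(model, base_plan):
    # hypothetical helper: gather the config-level plan plus the `_tp_plan` of PreTrainedModel children only
    full_tp_plan = dict(base_plan or {})
    for submodule in model.modules():
        if isinstance(submodule, PreTrainedModel):
            full_tp_plan.update(getattr(submodule, "_tp_plan", {}) or {})
    return full_tp_plan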


for param_name, param in state_dict.items():
if param_name not in expected_keys:
continue
@@ -892,6 +901,37 @@ def _load_state_dict_into_meta_model(
setattr(module, tensor_name, value)
# TODO: consider removing used param_parts from state_dict before return

# In this case, let's parallelize the modules!
if device_mesh is not None:
# Immediate parent
split_parent_module_name = param_name.split(".")[:-1]
parent_module_name = ".".join(split_parent_module_name)
parent_module = model
for name in split_parent_module_name:
parent_module = getattr(parent_module, name)

# Check if we are part of the tp_plan
current_module_plan = None
for param, plan in full_tp_plan.items():
# "*" are a placeholder for layer indices, so we replace them by "[0-9]+" in the regex pattern
pattern = param.replace("*", "[0-9]+")
if re.search(pattern, parent_module_name):
current_module_plan = plan
break

Comment on lines +915 to +921

Collaborator:
I don't think we need to iterate over the full tp_plan; we should be re-creating the key instead.

Member Author (@Cyrilvallez, Feb 3, 2025):
The tp_plan does not contain the full module names (usually it starts with "layers"), so to be general it's much easier to iterate over the keys instead of starting from the module name and trying to reconstruct the tp_plan key (the prefixes of the tp_plan keys may change). Once again, it's not costly at all since the tp_plan is very small.

Collaborator:
Mmm, agreed cost-wise it's a tad of a waste! But no worries.
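To illustrate the wildcard matching discussed above, here is a self-contained sketch; the tp_plan key and module name are made-up examples:

import re

tp_plan_key = "layers.*.self_attn.q_proj"             # "*" stands in for the layer index
parent_module_name = "model.layers.3.self_attn.q_proj"

pattern = tp_plan_key.replace("*", "[0-9]+")          # -> "layers.[0-9]+.self_attn.q_proj"
print(bool(re.search(pattern, parent_module_name)))   # True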

# We can only apply the tp_plan after all parameters of the current module have been correctly initialized (e.g.
# if we have bias, we need both `weights` and `bias` of a nn.Linear to be initialized)
process_device = list(device_map.values())[0]
all_module_parameters_initialized = all(
m.device == process_device for m in parent_module.parameters(recurse=False)
) and all(m.device == process_device for m in parent_module.buffers(recurse=False))
Comment on lines +924 to +927

Collaborator:
Similarly, this might be a tad costly for MoE models, for example, and is not necessarily needed. We could either:

  • load for the previous layer (so layer 1 loads layer 0; this way it always happens after all biases are loaded), or
  • check is_hf_initialized, as I think it should hold info about everything being initialized.

TLDR: let's avoid loops.

Member Author:
Unfortunately, the shards that are loaded are not necessarily in order, so we cannot rely on that in general... And we only check leaves of the state dict (i.e. the Linear/Embedding/Norm layers), so they have at most 2 or 3 parameters(), which is not much of an overhead I think. It does not look like we can use is_hf_initialized here (from what I understand, it checks that the weights were created, not that the correct state_dict was loaded and then dispatched to the correct device). In any case, if tp_plan="auto" was not specified, all of this is completely skipped.

Collaborator:
It checks that weights were properly loaded, normally! Because otherwise it goes through the init loop.

if current_module_plan is not None and all_module_parameters_initialized:
torch.distributed.tensor.parallel.parallelize_module(
parent_module,
device_mesh=device_mesh,
parallelize_plan=translate_to_torch_parallel_style(current_module_plan),
)

return error_msgs, offload_index, state_dict_index
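For reference, a hedged sketch of the kind of mapping `translate_to_torch_parallel_style` performs; the plan strings and return values here are assumptions for illustration, not taken from this diff:

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

def translate_sketch(style: str):
    # map a tp_plan string (e.g. "colwise") to a torch parallelization style object
    if style == "colwise":
        return ColwiseParallel()
    if style == "rowwise":
        return RowwiseParallel()
    raise ValueError(f"Unknown parallel style: {style}")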


@@ -3448,12 +3488,11 @@ def from_pretrained(
)

# We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
# `device_map` pointing to the correct device. If we don't, torch will use the default device (index 0) for all
# child processes at parallelization time, resulting in excessive memory usage on device 0 and OOMs.
# And temporarily setting the default device to the current process rank results in the following error
# `torch.distributed.DistBackendError: Attempt to perform collective on tensor not on device passed to init_process_group`
tp_device = None
# `device_map` pointing to the correct device
device_mesh = None
if tp_plan is not None:
if not is_torch_greater_or_equal("2.5"):
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
if not torch.distributed.is_initialized():
raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")

@@ -3465,6 +3504,10 @@
# This is the easiest way to dispatch to the current process device
device_map = tp_device

# Assuming sharding the model onto the world
world_size = torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))

if is_fsdp_enabled():
low_cpu_mem_usage = True

@@ -3559,7 +3602,7 @@ def from_pretrained(
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
elif not low_cpu_mem_usage:
raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
raise ValueError("Passing along a `device_map` or a `tp_plan` requires `low_cpu_mem_usage=True`")

if low_cpu_mem_usage:
if is_deepspeed_zero3_enabled():
@@ -3568,7 +3611,7 @@
)
elif not is_accelerate_available():
raise ImportError(
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
f"Using `low_cpu_mem_usage=True`, a `device_map` or a `tp_plan` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)

# handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
@@ -4141,6 +4184,9 @@ def from_pretrained(
# Let's make sure we don't run the init function of buffer modules
model = cls(config, *model_args, **model_kwargs)

if device_mesh is not None and not model.supports_tp_plan:
raise NotImplementedError("This model does not have a tensor parallel plan.")

# make sure we use the model's config since the __init__ call might have copied it
config = model.config

@@ -4285,6 +4331,7 @@ def from_pretrained(
keep_in_fp32_modules=keep_in_fp32_modules,
gguf_path=gguf_path,
weights_only=weights_only,
device_mesh=device_mesh,
)

# make sure token embedding weights are still tied if needed
@@ -4319,8 +4366,9 @@
)
pass

# Dispatch model with hooks on all devices if necessary
if device_map is not None:
# Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
# harms performance)
if device_map is not None and device_mesh is None:
device_map_kwargs = {
"device_map": device_map,
"offload_dir": offload_folder,
@@ -4347,6 +4395,13 @@
if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
dispatch_model(model, **device_map_kwargs)

# This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
# not part of the state_dict (persistent=False)
if device_mesh is not None:
for buffer in model.buffers():
if buffer.device != tp_device:
buffer.data = buffer.to(tp_device)
Comment on lines +4402 to +4403

Collaborator:
Interesting. Remember that now we pass the cos and sin as inputs to all layers, so they are passed along too.


if hf_quantizer is not None:
hf_quantizer.postprocess_model(model, config=config)
model.hf_quantizer = hf_quantizer
@@ -4369,16 +4424,6 @@
}
return model, loading_info

if tp_plan is not None:
assert tp_device is not None, "tp_device not set!"
if not model.supports_tp_plan:
raise NotImplementedError("This model does not have a tensor parallel plan.")
# Assuming sharding the model onto the world
world_size = torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
# Apply Tensor Parallelism
model.tensor_parallel(device_mesh)

return model

@staticmethod
@@ -4472,6 +4517,7 @@ def _load_pretrained_model(
keep_in_fp32_modules=None,
gguf_path=None,
weights_only=True,
device_mesh=None,
):
is_safetensors = False
is_quantized = hf_quantizer is not None
@@ -4771,6 +4817,7 @@ def _find_mismatched_keys(
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
)
else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True
@@ -4860,6 +4907,7 @@ def _find_mismatched_keys(
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
)
error_msgs += new_error_msgs
else:
Expand Down Expand Up @@ -5137,7 +5185,12 @@ def supports_tp_plan(self):

def tensor_parallel(self, device_mesh):
"""
Tensor parallelize the model across the given device mesh.
Tensor parallelize the model across the given device mesh. This function is a helper to be called after the model
has already been loaded in memory. Note, however, that this means each process first initializes the whole model and
only then parallelizes it across devices, which wastes a lot of GPU memory and can lead to OOM at loading time.

Calling `from_pretrained(..., tp_plan="auto")` is preferred, as it parallelizes module-by-module during initialization,
so that the expected per-device memory spike at loading time is no larger than the final model size on each device.

Args:
device_mesh (`torch.distributed.DeviceMesh`):
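For context, a minimal end-to-end usage sketch of the loading-time TP path added in this diff. The checkpoint name is illustrative, and the script is assumed to be launched with `torchrun --nproc-per-node <N>`; the diff above requires the process group to be initialized before calling `from_pretrained`:

import torch
from transformers import AutoModelForCausalLM

torch.distributed.init_process_group("nccl")  # from_pretrained checks torch.distributed.is_initialized()

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",  # illustrative checkpoint that ships a base_model_tp_plan
    torch_dtype=torch.float16,
    tp_plan="auto",  # shards supported modules across all ranks while the state dict is loaded
)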
14 changes: 5 additions & 9 deletions tests/tp/test_tp.py
@@ -80,17 +80,13 @@ def test_loading_memory_consumption(self):
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto")
torch.distributed.barrier()

# The expected full model memory footprint
expected_model_memory = 16
# The expected model memory footprint. We add 1 as not all the modules are split (e.g. the embeddings)
Collaborator:
Let's add something related to this in the test.

Member Author:
Yes, it currently checks that we do not use more than the expected memory divided by world size, i.e., no more than 5 GiB per GPU in my tests on a DGX for Llama 8B (expected memory per device = a bit more than 4 GiB).
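For example, with the 16 GiB full-model estimate used below and world_size = 4 (an illustrative GPU count), the per-device budget is 16 / 4 + 1 = 5 GiB, and with the 1.2 overhead factor the check allows at most 6 GiB of peak allocation per GPU.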

expected_model_memory_per_device = (16 / world_size) + 1
overhead_factor = 1.2

# Assert we did not use more than the full model expected memory (with some overhead)
if not torch.cuda.max_memory_allocated(device) / 1024**3 < expected_model_memory * overhead_factor:
raise ValueError("Loading the model used more than the full model size")

# Assert we correctly handled the sharding between devices
if not torch.cuda.memory_allocated(device) / 1024**3 < (expected_model_memory / world_size) * overhead_factor:
raise ValueError("Each model shard is larger than what is expected.")
# Check that we do not use more than the expected sharded size during initialization
if torch.cuda.max_memory_allocated(device) / 1024**3 > expected_model_memory_per_device * overhead_factor:
raise ValueError("Loading the model used more than the expected fraction of model size per device")

torch.distributed.barrier()
torch.distributed.destroy_process_group()
Expand Down