TP initialization module-by-module #35996
Changes from 12 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -767,6 +767,7 @@ def _load_state_dict_into_meta_model( | |
| keep_in_fp32_modules=None, | ||
| unexpected_keys=None, # passing `unexpected` for cleanup from quantization items | ||
| pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys | ||
| device_mesh=None, | ||
| ): | ||
| """ | ||
| This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its | ||
|
|
@@ -776,6 +777,8 @@ def _load_state_dict_into_meta_model( | |
| `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in | ||
| `bert.pooler.dense.weight` | ||
|
|
||
| It also initializes tensor parallelism for each module if needed. | ||
|
|
||
| """ | ||
|
|
||
| # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model | ||
|
|
@@ -789,6 +792,12 @@ def _load_state_dict_into_meta_model( | |
|
|
||
| is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") | ||
|
|
||
| # we need this later to initialize tensor parallelism | ||
| if device_mesh is not None: | ||
| full_tp_plan = model.config.base_model_tp_plan | ||
| for submodule in model.modules(): | ||
| full_tp_plan.update(getattr(submodule, "_tp_plan", {})) | ||
|
|
||
| for param_name, param in state_dict.items(): | ||
| if param_name not in expected_keys: | ||
| continue | ||
|
|
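For illustration, here is a minimal, self-contained sketch of how the full plan is assembled from the config-level `base_model_tp_plan` plus any per-submodule `_tp_plan`; the toy modules and plan entries below are made up:

```python
import torch.nn as nn

# Toy stand-ins; real plans come from `model.config.base_model_tp_plan` and
# submodules that define `_tp_plan`.
class Block(nn.Module):
    _tp_plan = {"mlp.up_proj": "colwise"}

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

base_model_tp_plan = {"layers.*.self_attn.q_proj": "colwise"}  # config-level plan
full_tp_plan = dict(base_model_tp_plan)
for submodule in ToyModel().modules():
    full_tp_plan.update(getattr(submodule, "_tp_plan", {}))

print(full_tp_plan)
# {'layers.*.self_attn.q_proj': 'colwise', 'mlp.up_proj': 'colwise'}
```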
@@ -892,6 +901,37 @@ def _load_state_dict_into_meta_model( | |
| setattr(module, tensor_name, value) | ||
| # TODO: consider removing used param_parts from state_dict before return | ||
|
|
||
| # In this case, let's parallelize the modules! | ||
| if device_mesh is not None: | ||
| # Immediate parent | ||
| split_parent_module_name = param_name.split(".")[:-1] | ||
| parent_module_name = ".".join(split_parent_module_name) | ||
| parent_module = model | ||
| for name in split_parent_module_name: | ||
| parent_module = getattr(parent_module, name) | ||
|
|
||
| # Check if we are part of the tp_plan | ||
| current_module_plan = None | ||
| for param, plan in full_tp_plan.items(): | ||
| # "*" are a placeholder for layer indices, so we replace them by "[0-9]+" in the regex pattern | ||
| pattern = param.replace("*", "[0-9]+") | ||
| if re.search(pattern, parent_module_name): | ||
| current_module_plan = plan | ||
| break | ||
|
|
||
|
Comment on lines +915 to +921

Collaborator: I don't think we need to iterate over the full tp_plan, but we should be re-creating the key instead

Member (Author): The

Collaborator: mmm agreed cost-wise, it's a tad of a waste! but no worries
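As a side note, a small sketch of the wildcard matching discussed above (the plan key and module name are invented):

```python
import re

# "*" in a tp_plan key stands for a layer index and is rewritten as "[0-9]+".
tp_plan = {"layers.*.self_attn.q_proj": "colwise"}   # made-up entry
parent_module_name = "model.layers.17.self_attn.q_proj"

current_module_plan = None
for key, plan in tp_plan.items():
    pattern = key.replace("*", "[0-9]+")
    if re.search(pattern, parent_module_name):
        current_module_plan = plan
        break

print(current_module_plan)  # colwise
```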
||
| # We can only apply the tp_plan after all parameters of the current module have been correctly initialized (e.g. | ||
| # if we have bias, we need both `weights` and `bias` of a nn.Linear to be initialized) | ||
| process_device = list(device_map.values())[0] | ||
| all_module_parameters_initialized = all( | ||
| m.device == process_device for m in parent_module.parameters(recurse=False) | ||
| ) and all(m.device == process_device for m in parent_module.buffers(recurse=False)) | ||
|
Comment on lines +924 to +927

Collaborator: similarly this might be a tad bit costly for MoE for example / not necessarily needed.

Member (Author): Unfortunately, the shards that are loaded are not necessarily in order, so we cannot rely on it in general... And we check it only for leaves in the state dict (i.e. the Linear/Embedding/Norm layers), so they have at most 2 or 3 parameters.

Collaborator: it checks that weights were properly loaded normally! Because otherwise it goes through the init loop
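To make the check above concrete, a minimal sketch (CPU stands in for the per-rank device):

```python
import torch
import torch.nn as nn

# A leaf module is only parallelized once every one of its (few) parameters and
# buffers has been materialized on the process device, since shards may arrive
# in any order across checkpoint files.
process_device = torch.device("cpu")          # stand-in for the per-rank GPU
parent_module = nn.Linear(8, 8)               # leaf with `weight` and `bias`

all_module_parameters_initialized = all(
    p.device == process_device for p in parent_module.parameters(recurse=False)
) and all(b.device == process_device for b in parent_module.buffers(recurse=False))

print(all_module_parameters_initialized)      # True once both tensors sit on `process_device`
```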
||
| if current_module_plan is not None and all_module_parameters_initialized: | ||
| torch.distributed.tensor.parallel.parallelize_module( | ||
| parent_module, | ||
| device_mesh=device_mesh, | ||
| parallelize_plan=translate_to_torch_parallel_style(current_module_plan), | ||
| ) | ||
|
|
||
| return error_msgs, offload_index, state_dict_index | ||
|
|
||
|
|
||
|
|
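For reference, a single-process sketch (gloo backend, world_size = 1) of what one "colwise" plan entry turns into via `parallelize_module`; in the real code path the parallel style comes from `translate_to_torch_parallel_style` and the mesh spans all ranks:

```python
import os
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# Minimal single-rank process group so the TP primitives can run on CPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

device_mesh = init_device_mesh("cpu", (1,))
layer = parallelize_module(
    nn.Linear(16, 16), device_mesh=device_mesh, parallelize_plan=ColwiseParallel()
)

print(isinstance(layer.weight, DTensor))  # True: the weight is now a sharded DTensor
dist.destroy_process_group()
```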
@@ -3448,12 +3488,11 @@ def from_pretrained( | |
| ) | ||
|
|
||
| # We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple | ||
| # `device_map` pointing to the correct device. If we don't, torch will use the default device (index 0) for all | ||
| # childs processes at parallelization time, resulting in excessive memory usage on device 0 and OOMs. | ||
| # And temporarily setting the default device to current process rank result in the following error | ||
| # `torch.distributed.DistBackendError: Attempt to perform collective on tensor not on device passed to init_process_group` | ||
| tp_device = None | ||
| # `device_map` pointing to the correct device | ||
| device_mesh = None | ||
| if tp_plan is not None: | ||
| if not is_torch_greater_or_equal("2.5"): | ||
| raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.") | ||
| if not torch.distributed.is_initialized(): | ||
| raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.") | ||
|
|
||
|
|
@@ -3465,6 +3504,10 @@ def from_pretrained( | |
| # This is the easiest way to dispatch to the current process device | ||
| device_map = tp_device | ||
|
|
||
| # Assuming sharding the model onto the world | ||
| world_size = torch.distributed.get_world_size() | ||
| device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,)) | ||
|
|
||
| if is_fsdp_enabled(): | ||
| low_cpu_mem_usage = True | ||
|
|
||
|
|
@@ -3559,7 +3602,7 @@ def from_pretrained( | |
| if low_cpu_mem_usage is None: | ||
| low_cpu_mem_usage = True | ||
| elif not low_cpu_mem_usage: | ||
| raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`") | ||
| raise ValueError("Passing along a `device_map` or a `tp_plan` requires `low_cpu_mem_usage=True`") | ||
|
|
||
| if low_cpu_mem_usage: | ||
| if is_deepspeed_zero3_enabled(): | ||
|
|
@@ -3568,7 +3611,7 @@ def from_pretrained( | |
| ) | ||
| elif not is_accelerate_available(): | ||
| raise ImportError( | ||
| f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" | ||
| f"Using `low_cpu_mem_usage=True`, a `device_map` or a `tp_plan` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" | ||
| ) | ||
|
|
||
| # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation. | ||
|
|
@@ -4141,6 +4184,9 @@ def from_pretrained( | |
| # Let's make sure we don't run the init function of buffer modules | ||
| model = cls(config, *model_args, **model_kwargs) | ||
|
|
||
| if device_mesh is not None and not model.supports_tp_plan: | ||
| raise NotImplementedError("This model does not have a tensor parallel plan.") | ||
|
|
||
| # make sure we use the model's config since the __init__ call might have copied it | ||
| config = model.config | ||
|
|
||
|
|
@@ -4285,6 +4331,7 @@ def from_pretrained( | |
| keep_in_fp32_modules=keep_in_fp32_modules, | ||
| gguf_path=gguf_path, | ||
| weights_only=weights_only, | ||
| device_mesh=device_mesh, | ||
| ) | ||
|
|
||
| # make sure token embedding weights are still tied if needed | ||
|
|
@@ -4319,8 +4366,9 @@ def from_pretrained( | |
| ) | ||
| pass | ||
|
|
||
| # Dispatch model with hooks on all devices if necessary | ||
| if device_map is not None: | ||
| # Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly | ||
| # harms performance) | ||
| if device_map is not None and device_mesh is None: | ||
| device_map_kwargs = { | ||
| "device_map": device_map, | ||
| "offload_dir": offload_folder, | ||
|
|
@@ -4347,6 +4395,13 @@ def from_pretrained( | |
| if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): | ||
| dispatch_model(model, **device_map_kwargs) | ||
|
|
||
| # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is | ||
| # not part of the state_dict (persistent=False) | ||
| if device_mesh is not None: | ||
| for buffer in model.buffers(): | ||
| if buffer.device != tp_device: | ||
| buffer.data = buffer.to(tp_device) | ||
|
Comment on lines +4402 to +4403

Collaborator: Interesting, remember that now we pass the cos and sin as input to all layers so they are passed
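A short sketch of why this pass is needed (the rotary-style buffer below is a made-up stand-in):

```python
import torch
import torch.nn as nn

class WithRotaryCache(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered with persistent=False, so it never appears in the state_dict
        # and is therefore untouched by the checkpoint-loading loop.
        self.register_buffer("inv_freq", torch.arange(4.0), persistent=False)

model = WithRotaryCache()
print("inv_freq" in model.state_dict())   # False

tp_device = torch.device("cpu")           # stand-in for the per-rank device
for buffer in model.buffers():
    if buffer.device != tp_device:
        buffer.data = buffer.to(tp_device)
```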
||
|
|
||
| if hf_quantizer is not None: | ||
| hf_quantizer.postprocess_model(model, config=config) | ||
| model.hf_quantizer = hf_quantizer | ||
|
|
@@ -4369,16 +4424,6 @@ def from_pretrained( | |
| } | ||
| return model, loading_info | ||
|
|
||
| if tp_plan is not None: | ||
| assert tp_device is not None, "tp_device not set!" | ||
| if not model.supports_tp_plan: | ||
| raise NotImplementedError("This model does not have a tensor parallel plan.") | ||
| # Assuming sharding the model onto the world | ||
| world_size = torch.distributed.get_world_size() | ||
| device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,)) | ||
| # Apply Tensor Parallelism | ||
| model.tensor_parallel(device_mesh) | ||
|
|
||
| return model | ||
|
|
||
| @staticmethod | ||
|
|
@@ -4472,6 +4517,7 @@ def _load_pretrained_model( | |
| keep_in_fp32_modules=None, | ||
| gguf_path=None, | ||
| weights_only=True, | ||
| device_mesh=None, | ||
| ): | ||
| is_safetensors = False | ||
| is_quantized = hf_quantizer is not None | ||
|
|
@@ -4771,6 +4817,7 @@ def _find_mismatched_keys( | |
| is_safetensors=is_safetensors, | ||
| keep_in_fp32_modules=keep_in_fp32_modules, | ||
| unexpected_keys=unexpected_keys, | ||
| device_mesh=device_mesh, | ||
| ) | ||
| else: | ||
| # Sharded checkpoint or whole but low_cpu_mem_usage==True | ||
|
|
@@ -4860,6 +4907,7 @@ def _find_mismatched_keys( | |
| is_safetensors=is_safetensors, | ||
| keep_in_fp32_modules=keep_in_fp32_modules, | ||
| unexpected_keys=unexpected_keys, | ||
| device_mesh=device_mesh, | ||
| ) | ||
| error_msgs += new_error_msgs | ||
| else: | ||
|
|
@@ -5137,7 +5185,12 @@ def supports_tp_plan(self): | |
|
|
||
| def tensor_parallel(self, device_mesh): | ||
| """ | ||
| Tensor parallelize the model across the given device mesh. | ||
| Tensor parallelize the model across the given device mesh. This function is a helper to be called after the model | ||
| was already loaded in memory; note however that this means that each process will first initialize the whole model, | ||
| then parallelize it across devices. Thus there is a huge waste of GPU memory, and this can lead to OOM at loading time. | ||
|
|
||
| Calling `from_pretrained(..., tp_plan="auto")` is preferred, and will parallelize module-by-module during initialization, | ||
| so that the expected per-device memory spike at loading time is not larger than the final model size on each device. | ||
|
|
||
| Args: | ||
| device_mesh (`torch.distributed.DeviceMesh`): | ||
|
|
||
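For context, the preferred entry point described above would be used roughly like this (assumed launch: `torchrun --nproc-per-node=4 script.py`; the checkpoint id is only an example):

```python
import torch
from transformers import AutoModelForCausalLM

# torch.distributed must be initialized before calling from_pretrained with a tp_plan;
# from_pretrained then builds a 1-D device mesh over all ranks and parallelizes
# module-by-module while loading.
torch.distributed.init_process_group("nccl")

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",   # example checkpoint
    torch_dtype=torch.float16,
    tp_plan="auto",
)
```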
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -80,17 +80,13 @@ def test_loading_memory_consumption(self): | |
| model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto") | ||
| torch.distributed.barrier() | ||
|
|
||
| # The expected full model memory footprint | ||
| expected_model_memory = 16 | ||
| # The expected model memory footprint. We add 1 as not all the modules are split (e.g. the embeddings) | ||
|
Collaborator: let's add something related to this in the test

Member (Author): Yes, it currently checks that we do not use more than the expected memory divided by world size, i.e., no more than 5 GiB per GPU on my tests on DGX for Llama 8B (expected memory per device = a bit more than 4 GiB)
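As a worked example of the numbers mentioned above, assuming world_size = 4:

```python
# ~16 GiB full model (Llama 8B in fp16), +1 GiB for modules that are not sharded
# (e.g. the embeddings), and a 1.2x allowance for loading overhead.
world_size = 4
expected_model_memory_per_device = 16 / world_size + 1   # 5.0 GiB budget per GPU
print(expected_model_memory_per_device * 1.2)            # 6.0 GiB allowed peak per device
```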
||
| expected_model_memory_per_device = (16 / world_size) + 1 | ||
| overhead_factor = 1.2 | ||
|
|
||
| # Assert we did not use more than the full model expected memory (with some overhead) | ||
| if not torch.cuda.max_memory_allocated(device) / 1024**3 < expected_model_memory * overhead_factor: | ||
| raise ValueError("Loading the model used more than the full model size") | ||
|
|
||
| # Assert we correctly handled the sharding between devices | ||
| if not torch.cuda.memory_allocated(device) / 1024**3 < (expected_model_memory / world_size) * overhead_factor: | ||
| raise ValueError("Each model shard is larger than what is expected.") | ||
| # Check that we do not use more than the expected sharded size during initialization | ||
| if torch.cuda.max_memory_allocated(device) / 1024**3 > expected_model_memory_per_device * overhead_factor: | ||
| raise ValueError("Loading the model used more than the expected fraction of model size per device") | ||
|
|
||
| torch.distributed.barrier() | ||
| torch.distributed.destroy_process_group() | ||
|
|
||
Collaborator: IMO we should only do this for PreTrainedModels no?

Member (Author): I figured that maybe it would be a bit more future-proof to iterate over all modules (it's not costly) -- but can be changed for sure!

Collaborator: let's only do PreTrained for now!