Simplify Tensor Parallel implementation with PyTorch TP #34184
@@ -55,6 +55,7 @@
    prune_conv1d_layer,
    prune_layer,
    prune_linear_layer,
    translate_to_torch_parallel_style,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
@@ -1398,6 +1399,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
    # Has support for a `QuantoQuantizedCache` instance as `past_key_values`
    _supports_quantized_cache = False

    # A tensor parallel plan to be applied to the model when TP is enabled. For
    # top-level models, this attribute is currently defined in respective model
    # code. For base models, this attribute comes from
    # `config.base_model_tp_plan` during `post_init`.
    _tp_plan = None

    @property
    def dummy_inputs(self) -> Dict[str, torch.Tensor]:
        """
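Note: a `base_model_tp_plan` is just a dict mapping submodule-name patterns to neutral string parallel styles. The keys below are a minimal sketch of what a Llama-style decoder config might declare; they are illustrative, not copied verbatim from this diff.

```python
# Illustrative only: a config-level TP plan for a Llama-style decoder.
# Keys are module-name patterns, values are neutral string styles that are
# translated into torch TP styles later (see tensor_parallel below).
base_model_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.k_proj": "colwise",
    "layers.*.self_attn.v_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}
```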
@@ -1442,6 +1449,9 @@ def post_init(self):
        """
        self.init_weights()
        self._backward_compatibility_gradient_checkpointing()
        # If current model is a base model, attach `base_model_tp_plan` from config
        if self.base_model is self:
Collaborator
Suggested change
Collaborator
feels simpler
Contributor (Author)
Oh, the reason for
Collaborator
Yeah, but
Contributor (Author)
That's not always the case. For example,
Collaborator
Yep, indeed. I thought we should enforce TP plan definition for all classes to avoid user errors, but it's fine like this!
            self._tp_plan = self.config.base_model_tp_plan

    def dequantize(self):
        """
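As a hypothetical illustration of the `base_model is self` check above (class and attribute names are assumed, not taken from this hunk): for a task model the check fails on the outer wrapper and succeeds on the inner base model, so the config-level plan ends up attached to the latter.

```python
# Hypothetical sketch, assuming a Llama-style config that defines
# `base_model_tp_plan` (as this PR does elsewhere).
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(num_hidden_layers=2, hidden_size=64, intermediate_size=128, num_attention_heads=4)
model = LlamaForCausalLM(config)

assert model.base_model is model.model                     # the wrapper's base model is the inner LlamaModel
assert model.model._tp_plan == config.base_model_tp_plan   # attached to the base model during post_init()
```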
@@ -3472,6 +3482,11 @@ def from_pretrained(
        # Cache path to the GGUF file
        gguf_path = None

        tp_plan = kwargs.pop("tp_plan", None)
        if tp_plan is not None and tp_plan != "auto":
            # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
            raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")

        if is_fsdp_enabled():
            low_cpu_mem_usage = True
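The new `tp_plan="auto"` kwarg is the user-facing entry point. A minimal end-to-end usage sketch (the checkpoint name, prompt, and launch command are illustrative and not part of this diff):

```python
# run with e.g.: torchrun --nproc-per-node=4 tp_demo.py
import os

import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer

# Tensor Parallel requires the process group to be initialized first.
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, tp_plan="auto")  # shards weights across ranks

inputs = tokenizer("Can I help", return_tensors="pt").input_ids.to("cuda")
outputs = model(inputs)
```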
@@ -4073,6 +4088,7 @@
        # Instantiate model.
        init_contexts = [no_init_weights(_enable=_fast_init)]
        tp_device = None

        if is_deepspeed_zero3_enabled() and not is_quantized:
            import deepspeed
@@ -4085,6 +4101,17 @@
                    f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
                )
            init_contexts.append(init_empty_weights())
        elif tp_plan is not None:
            if not torch.distributed.is_initialized():
                raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")

            # Get device type (e.g. "cuda")
            device_type = torch.distributed.distributed_c10d._device_capability()[0]
            # Get torch device module (e.g. torch.cuda) based on device type
            device_module = torch.get_device_module(device_type)
            # Get device with index assuming equal number of devices per host
            tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
            init_contexts.append(tp_device)

        config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
        if not getattr(config, "_attn_implementation_autoset", False):
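The device selection above assumes one process per accelerator and an equal device count on every host. A minimal standalone sketch of the same rank-to-device arithmetic (backend and device type assumed to be NCCL/CUDA):

```python
# Minimal sketch, assuming a CUDA setup launched with one process per GPU
# (e.g. via torchrun).
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
local_index = dist.get_rank() % torch.cuda.device_count()
tp_device = torch.device("cuda", local_index)  # e.g. global rank 11 on 8-GPU hosts -> cuda:3
```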
@@ -4215,32 +4242,38 @@
        if dtype_orig is not None:
            torch.set_default_dtype(dtype_orig)

-        (
-            model,
-            missing_keys,
-            unexpected_keys,
-            mismatched_keys,
-            offload_index,
-            error_msgs,
-        ) = cls._load_pretrained_model(
-            model,
-            state_dict,
-            loaded_state_dict_keys,  # XXX: rename?
-            resolved_archive_file,
-            pretrained_model_name_or_path,
-            ignore_mismatched_sizes=ignore_mismatched_sizes,
-            sharded_metadata=sharded_metadata,
-            _fast_init=_fast_init,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            device_map=device_map,
-            offload_folder=offload_folder,
-            offload_state_dict=offload_state_dict,
-            dtype=torch_dtype,
-            hf_quantizer=hf_quantizer,
-            keep_in_fp32_modules=keep_in_fp32_modules,
-            gguf_path=gguf_path,
-            weights_only=weights_only,
-        )
+        load_contexts = []
+        # Make sure we load onto targeted device
+        if tp_device is not None:
+            load_contexts.append(tp_device)
+
+        with ContextManagers(load_contexts):
+            (
+                model,
+                missing_keys,
+                unexpected_keys,
+                mismatched_keys,
+                offload_index,
+                error_msgs,
+            ) = cls._load_pretrained_model(
+                model,
+                state_dict,
+                loaded_state_dict_keys,  # XXX: rename?
+                resolved_archive_file,
+                pretrained_model_name_or_path,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+                sharded_metadata=sharded_metadata,
+                _fast_init=_fast_init,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                device_map=device_map,
+                offload_folder=offload_folder,
+                offload_state_dict=offload_state_dict,
+                dtype=torch_dtype,
+                hf_quantizer=hf_quantizer,
+                keep_in_fp32_modules=keep_in_fp32_modules,
+                gguf_path=gguf_path,
+                weights_only=weights_only,
+            )

        # make sure token embedding weights are still tied if needed
        model.tie_weights()
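The only functional change in this hunk is that weight loading can now run inside a `torch.device` context (entered via the `ContextManagers` helper), so tensors materialized during loading land directly on the TP device. A standalone sketch of that PyTorch 2.x behaviour, independent of this diff:

```python
import torch

# torch.device objects are context managers in PyTorch 2.x: tensors created
# inside the block are allocated on that device by default.
with torch.device("cuda:0"):
    w = torch.empty(1024, 1024)
print(w.device)  # cuda:0
```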
@@ -4324,6 +4357,14 @@
                }
            return model, loading_info

        if tp_plan is not None:
            assert tp_device is not None, "tp_device not set!"
            # Assuming sharding the model onto the world
            world_size = torch.distributed.get_world_size()
            device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
            # Apply Tensor Parallelism
            model.tensor_parallel(device_mesh)

        return model

    @classmethod
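For reference, a 1-D device mesh simply lays every rank out along a single tensor-parallel dimension. A standalone sketch (world size of 4 assumed, e.g. one 4-GPU host):

```python
import torch
import torch.distributed as dist

dist.init_process_group("nccl")  # e.g. torchrun --nproc-per-node=4
world_size = dist.get_world_size()

# A 1-D mesh over all ranks: the single dimension is the TP dimension.
device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,))
print(device_mesh)  # DeviceMesh over ranks [0, 1, 2, 3] when world_size == 4
```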
@@ -5013,6 +5054,42 @@
        return self.hf_quantizer.is_trainable

    def tensor_parallel(self, device_mesh):
        """
        Tensor parallelize the model across the given device mesh.

        Args:
            device_mesh (`torch.distributed.DeviceMesh`):
                The device mesh to use for tensor parallelism.
        """

        # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
        # No op if `_tp_plan` attribute does not exist under the module.
        # This is a helper function to be used with `model.apply` to recursively
        # parallelize a model.
        def tplize(mod: torch.nn.Module) -> None:
            tp_plan = getattr(mod, "_tp_plan", None)
            if tp_plan is None:
                return
            logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
            # In model configs, we use a neutral type (string) to specify
            # parallel styles, here we translate them into torch TP types.
            # Using tree_map because `tp_plan` is a dict.
            tp_plan = torch.utils._pytree.tree_map(
                translate_to_torch_parallel_style,
                tp_plan,
            )
            # Apply TP to current module.
            torch.distributed.tensor.parallel.parallelize_module(
                mod,
                device_mesh=device_mesh,
                parallelize_plan=tp_plan,
            )

        # `apply` is a native method of `nn.Module` that recursively applies a
        # function to every submodule.
        self.apply(tplize)

    @property
    @lru_cache
    def loss_function(self):
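`translate_to_torch_parallel_style` is added elsewhere in this PR (in `pytorch_utils.py`) and its body is not shown in this file. A minimal sketch of what such a string-to-style translation can look like; the exact set of supported styles is assumed here, not taken from this diff:

```python
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel


def translate_to_torch_parallel_style(style: str):
    """Map a neutral string style from a config TP plan to a torch TP style object (sketch)."""
    if style == "colwise":
        return ColwiseParallel()
    if style == "rowwise":
        return RowwiseParallel()
    raise ValueError(f"Unsupported parallel style: {style}")
```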