
Commit 20142ab

Simplify Tensor Parallel implementation with PyTorch TP (#34184)
* Simplify Tensor Parallel implementation with PyTorch TP
* Move tp_plan to config
* Lint
* Format and warning
* Disable copy-from check
* Conditionally get attr from config
* make fix-copies
* Move base_model_tp_plan to PretrainedConfig
* Move TP into from_pretrained
* Add device context for load
* Do not serialize
* Move _tp_plan setting to post_init
* Add has_tp_plan
* Add test_tp
* Add 'Multi-gpu inference' doc
* Add backward support for device type identification
* Auto-detect accelerator
* supports_tp_plan
* copyright year
* Fix copy
1 parent 7df93d6 commit 20142ab

File tree

18 files changed: +357 -92 lines changed


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -218,6 +218,8 @@
       title: CPU inference
     - local: perf_infer_gpu_one
       title: GPU inference
+    - local: perf_infer_gpu_multi
+      title: Multi-GPU inference
     title: Optimizing inference
   - local: big_models
     title: Instantiate a big model
```
docs/source/en/perf_infer_gpu_multi.md (new file)

Lines changed: 68 additions & 0 deletions
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Multi-GPU inference

Built-in Tensor Parallelism (TP) is now available with certain models using PyTorch. Tensor parallelism shards a model onto multiple GPUs, enabling larger model sizes, and parallelizes computations such as matrix multiplication.

To enable tensor parallelism, pass the argument `tp_plan="auto"` to [`~AutoModelForCausalLM.from_pretrained`]:

```python
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# Initialize distributed
rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)

# Retrieve tensor parallel model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    tp_plan="auto",
)

# Prepare input tokens
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Distributed run
outputs = model(inputs)
```

You can use `torchrun` to launch the above script with multiple processes, each mapping to a GPU:

```bash
torchrun --nproc-per-node 4 demo.py
```

PyTorch tensor parallel is currently supported for the following models:

* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)

You can request to add tensor parallel support for another model by opening a GitHub Issue or Pull Request.

### Expected speedups

You can benefit from considerable speedups for inference, especially for inputs with large batch sizes or long sequences.

For a single forward pass on [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) with a sequence length of 512 and various batch sizes, the expected speedup is as follows:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Meta-Llama-3-8B-Instruct, seqlen = 512, python, w_ compile.png">
</div>
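As a side note to the doc above (not part of the committed file): because the plans in this commit replicate the `lm_head` output across ranks (`colwise_rep`), each process ends up holding the full logits, so the forward output can be decoded locally. A minimal sketch continuing the variables from the snippet above:

```python
# Sketch only, reusing `outputs`, `rank`, and `tokenizer` from the doc example.
# Assumes the TP plan replicates the lm_head output (the "colwise_rep" style used
# in this commit), so every rank sees identical full-vocabulary logits.
next_token_id = outputs.logits[:, -1, :].argmax(dim=-1)
if rank == 0:
    print(tokenizer.decode(next_token_id[0]))
```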

docs/source/en/performance.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -53,7 +53,7 @@ sections we go through the steps to run inference on CPU and single/multi-GPU se

 * [Inference on a single CPU](perf_infer_cpu)
 * [Inference on a single GPU](perf_infer_gpu_one)
-* [Multi-GPU inference](perf_infer_gpu_one)
+* [Multi-GPU inference](perf_infer_gpu_multi)
 * [XLA Integration for TensorFlow Models](tf_xla)

```
src/transformers/configuration_utils.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -71,6 +71,8 @@ class PretrainedConfig(PushToHubMixin):
       outputs of the model during inference.
     - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
       naming of attributes.
+    - **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
+      parallel plan applied to the sub-module when `model.tensor_parallel` is called.

     Common attributes (present in all subclasses):

@@ -194,6 +196,7 @@ class PretrainedConfig(PushToHubMixin):
     sub_configs: Dict[str, "PretrainedConfig"] = {}
     is_composition: bool = False
     attribute_map: Dict[str, str] = {}
+    base_model_tp_plan: Optional[Dict[str, Any]] = None
     _auto_class: Optional[str] = None

     def __setattr__(self, key, value):
@@ -848,6 +851,9 @@ def to_diff_dict(self) -> Dict[str, Any]:

         if "_attn_implementation_internal" in serializable_config_dict:
             del serializable_config_dict["_attn_implementation_internal"]
+        # Do not serialize `base_model_tp_plan` for now
+        if "base_model_tp_plan" in serializable_config_dict:
+            del serializable_config_dict["base_model_tp_plan"]

         return serializable_config_dict

@@ -867,6 +873,9 @@ def to_dict(self) -> Dict[str, Any]:
             del output["_commit_hash"]
         if "_attn_implementation_internal" in output:
             del output["_attn_implementation_internal"]
+        # Do not serialize `base_model_tp_plan` for now
+        if "base_model_tp_plan" in output:
+            del output["base_model_tp_plan"]

         # Transformers version when serializing the model
         output["transformers_version"] = __version__
```
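Not part of this diff, but for orientation: a `base_model_tp_plan` maps sub-module FQN patterns to neutral parallel-style strings. A hypothetical plan for a Llama-style decoder might look like the sketch below; the real plans live in the individual model configs and may differ:

```python
# Hypothetical illustration of a `base_model_tp_plan` for a Llama-style decoder.
# Keys are sub-module FQN patterns ("*" standing in for the layer index); values
# are neutral strings that get translated into torch TP styles when TP is applied.
base_model_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.k_proj": "colwise",
    "layers.*.self_attn.v_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}
```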

src/transformers/modeling_utils.py

Lines changed: 116 additions & 26 deletions
```diff
@@ -55,6 +55,7 @@
     prune_conv1d_layer,
     prune_layer,
     prune_linear_layer,
+    translate_to_torch_parallel_style,
 )
 from .quantizers import AutoHfQuantizer, HfQuantizer
 from .quantizers.quantizers_utils import get_module_from_name
@@ -1326,6 +1327,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     # Has support for a `QuantoQuantizedCache` instance as `past_key_values`
     _supports_quantized_cache = False

+    # A tensor parallel plan to be applied to the model when TP is enabled. For
+    # top-level models, this attribute is currently defined in respective model
+    # code. For base models, this attribute comes from
+    # `config.base_model_tp_plan` during `post_init`.
+    _tp_plan = None
+
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
         """
@@ -1370,6 +1377,9 @@ def post_init(self):
         """
         self.init_weights()
         self._backward_compatibility_gradient_checkpointing()
+        # If current model is a base model, attach `base_model_tp_plan` from config
+        if self.base_model is self:
+            self._tp_plan = self.config.base_model_tp_plan

     def dequantize(self):
         """
@@ -3399,6 +3409,11 @@ def from_pretrained(
         # Cache path to the GGUF file
         gguf_path = None

+        tp_plan = kwargs.pop("tp_plan", None)
+        if tp_plan is not None and tp_plan != "auto":
+            # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
+            raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
+
         if is_fsdp_enabled():
             low_cpu_mem_usage = True

```
```diff
@@ -4000,6 +4015,7 @@ def from_pretrained(

         # Instantiate model.
         init_contexts = [no_init_weights(_enable=_fast_init)]
+        tp_device = None

         if is_deepspeed_zero3_enabled() and not is_quantized:
             import deepspeed
@@ -4012,6 +4028,16 @@ def from_pretrained(
                     f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
                 )
             init_contexts.append(init_empty_weights())
+        elif tp_plan is not None:
+            if not torch.distributed.is_initialized():
+                raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
+
+            # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
+            device_type = torch._C._get_accelerator().type
+            device_module = torch.get_device_module(device_type)
+            # Get device with index assuming equal number of devices per host
+            tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
+            init_contexts.append(tp_device)

         if is_deepspeed_zero3_enabled() and is_quantized:
             init_contexts.append(set_quantized_state())
```
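Appending `tp_device` to `init_contexts` works because a `torch.device` can itself be used as a context manager: modules constructed under it allocate their parameters directly on that device. A minimal illustration (assumes at least one CUDA device is present):

```python
import torch

# Parameters created inside a torch.device context land on that device with no
# explicit .to(...) call; this is what the tp_device init/load contexts rely on.
with torch.device("cuda", 0):
    layer = torch.nn.Linear(16, 16)

print(layer.weight.device)  # cuda:0
```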
```diff
@@ -4145,32 +4171,38 @@ def from_pretrained(
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)

-        (
-            model,
-            missing_keys,
-            unexpected_keys,
-            mismatched_keys,
-            offload_index,
-            error_msgs,
-        ) = cls._load_pretrained_model(
-            model,
-            state_dict,
-            loaded_state_dict_keys,  # XXX: rename?
-            resolved_archive_file,
-            pretrained_model_name_or_path,
-            ignore_mismatched_sizes=ignore_mismatched_sizes,
-            sharded_metadata=sharded_metadata,
-            _fast_init=_fast_init,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            device_map=device_map,
-            offload_folder=offload_folder,
-            offload_state_dict=offload_state_dict,
-            dtype=torch_dtype,
-            hf_quantizer=hf_quantizer,
-            keep_in_fp32_modules=keep_in_fp32_modules,
-            gguf_path=gguf_path,
-            weights_only=weights_only,
-        )
+        load_contexts = []
+        # Make sure we load onto targeted device
+        if tp_device is not None:
+            load_contexts.append(tp_device)
+
+        with ContextManagers(load_contexts):
+            (
+                model,
+                missing_keys,
+                unexpected_keys,
+                mismatched_keys,
+                offload_index,
+                error_msgs,
+            ) = cls._load_pretrained_model(
+                model,
+                state_dict,
+                loaded_state_dict_keys,  # XXX: rename?
+                resolved_archive_file,
+                pretrained_model_name_or_path,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+                sharded_metadata=sharded_metadata,
+                _fast_init=_fast_init,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                device_map=device_map,
+                offload_folder=offload_folder,
+                offload_state_dict=offload_state_dict,
+                dtype=torch_dtype,
+                hf_quantizer=hf_quantizer,
+                keep_in_fp32_modules=keep_in_fp32_modules,
+                gguf_path=gguf_path,
+                weights_only=weights_only,
+            )

         # make sure token embedding weights are still tied if needed
         model.tie_weights()
@@ -4254,6 +4286,16 @@ def from_pretrained(
             }
             return model, loading_info

+        if tp_plan is not None:
+            assert tp_device is not None, "tp_device not set!"
+            if not model.supports_tp_plan:
+                raise NotImplementedError("This model does not have a tensor parallel plan.")
+            # Assuming sharding the model onto the world
+            world_size = torch.distributed.get_world_size()
+            device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
+            # Apply Tensor Parallelism
+            model.tensor_parallel(device_mesh)
+
         return model

     @classmethod
@@ -4943,6 +4985,54 @@ def _is_quantized_training_enabled(self):

         return self.hf_quantizer.is_trainable

+    @property
+    def supports_tp_plan(self):
+        """
+        Returns whether the model has a tensor parallelism plan.
+        """
+        if self._tp_plan is not None:
+            return True
+        # Check if base model has a TP plan
+        if getattr(self.base_model, "_tp_plan", None) is not None:
+            return True
+        return False
+
+    def tensor_parallel(self, device_mesh):
+        """
+        Tensor parallelize the model across the given device mesh.
+
+        Args:
+            device_mesh (`torch.distributed.DeviceMesh`):
+                The device mesh to use for tensor parallelism.
+        """
+
+        # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
+        # No op if `_tp_plan` attribute does not exist under the module.
+        # This is a helper function to be used with `model.apply` to recursively
+        # parallelize a model.
+        def tplize(mod: torch.nn.Module) -> None:
+            tp_plan = getattr(mod, "_tp_plan", None)
+            if tp_plan is None:
+                return
+            logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
+            # In model configs, we use a neutral type (string) to specify
+            # parallel styles, here we translate them into torch TP types.
+            # Using tree_map because `tp_plan` is a dict.
+            tp_plan = torch.utils._pytree.tree_map(
+                translate_to_torch_parallel_style,
+                tp_plan,
+            )
+            # Apply TP to current module.
+            torch.distributed.tensor.parallel.parallelize_module(
+                mod,
+                device_mesh=device_mesh,
+                parallelize_plan=tp_plan,
+            )

+        # `apply` is a native method of `nn.Module` that recursively applies a
+        # function to every submodule.
+        self.apply(tplize)
+
     @property
     def loss_function(self):
         if getattr(self.config, "loss_type", None) is not None:
```
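For context, the plan application that `from_pretrained(tp_plan="auto")` performs can also be driven by hand through the new `supports_tp_plan` property and `tensor_parallel` method. A minimal sketch, assuming `torchrun` launches one process per GPU; note that unlike the `tp_plan="auto"` path above, this loads the full checkpoint on every rank before sharding, so it is only a functional illustration:

```python
import os

import torch
from transformers import AutoModelForCausalLM

# Sketch only: apply a model's TP plan manually instead of passing tp_plan="auto".
rank = int(os.environ["RANK"])
device = torch.device("cuda", rank)
torch.distributed.init_process_group("nccl", device_id=device)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct").to(device)

if model.supports_tp_plan:
    # One-dimensional mesh over all ranks, mirroring what from_pretrained does above.
    world_size = torch.distributed.get_world_size()
    device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,))
    model.tensor_parallel(device_mesh)
```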

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1068,7 +1068,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         return causal_mask


-# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
+# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
 class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]

```
src/transformers/models/gemma/modeling_gemma.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -720,7 +720,10 @@ def __init__(self, config: GemmaConfig):
             [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
         self.gradient_checkpointing = False
+        if getattr(config, "pretraining_tp", 1) != 1:
+            logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

         # Initialize weights and apply final processing
         self.post_init()
@@ -982,6 +985,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

 class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}

     def __init__(self, config):
         super().__init__(config)
```
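The `"colwise_rep"` string attached to `lm_head` here (and in the Gemma2/GLM changes below) is one of the neutral style names that `translate_to_torch_parallel_style` converts into a PyTorch parallel style before `parallelize_module` is called. That helper is not shown in this commit; the sketch below is only a rough guess at the kind of mapping it performs:

```python
# Assumption-laden sketch, not the actual transformers helper.
# On older PyTorch the placement types live under torch.distributed._tensor instead.
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel


def translate_to_torch_parallel_style(style: str):
    """Map a neutral plan string to a torch tensor-parallel style."""
    if style == "colwise":
        return ColwiseParallel()
    if style == "rowwise":
        return RowwiseParallel()
    if style == "colwise_rep":
        # Shard the weight column-wise but replicate the output on every rank,
        # which is what an lm_head needs so each process sees full logits.
        return ColwiseParallel(output_layouts=Replicate())
    raise ValueError(f"Unsupported parallel style: {style}")
```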

src/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -740,7 +740,10 @@ def __init__(self, config: Gemma2Config):
             [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
         self.gradient_checkpointing = False
+        if getattr(config, "pretraining_tp", 1) != 1:
+            logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

         # Initialize weights and apply final processing
         self.post_init()
@@ -961,6 +964,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

 class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}

     def __init__(self, config):
         super().__init__(config)
```

src/transformers/models/glm/modeling_glm.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -708,6 +708,8 @@ def __init__(self, config: GlmConfig):
             dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta
         )
         self.gradient_checkpointing = False
+        if getattr(config, "pretraining_tp", 1) != 1:
+            logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

         # Initialize weights and apply final processing
         self.post_init()
@@ -967,6 +969,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

 class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}

     def __init__(self, config: GlmConfig):
         super().__init__(config)
```
