
Commit c096d40

feat: support tensor parallel using Pytorch 2.0
Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent bf4572b commit c096d40

File tree

6 files changed: +79, -3 lines changed

src/accelerate/accelerator.py

Lines changed: 9 additions & 0 deletions

@@ -67,6 +67,7 @@
     ProjectConfiguration,
     RNGType,
     TorchDynamoPlugin,
+    TorchTensorParallelPlugin,
     apply_fp8_autowrap,
     check_os_kernel,
     clean_state_dict_for_safetensors,

@@ -188,6 +189,9 @@ class Accelerator:
         fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
             Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
             using *accelerate config*
+        torch_tp_plugin ([`~utils.TorchTensorParallelPlugin`], *optional*):
+            Tweak your torch tensor parallel related args using this argument. This argument is optional and can be
+            configured directly using *accelerate config*
         megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
             Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
             directly using *accelerate config*

@@ -254,6 +258,7 @@ def __init__(
         dataloader_config: DataLoaderConfiguration | None = None,
         deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
         fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
+        torch_tp_plugin: TorchTensorParallelPlugin | None = None,
         megatron_lm_plugin: MegatronLMPlugin | None = None,
         rng_types: list[str | RNGType] | None = None,
         log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,

@@ -418,6 +423,7 @@ def __init__(
             dynamo_plugin=dynamo_plugin,
             deepspeed_plugin=deepspeed_plugins,
             fsdp_plugin=fsdp_plugin,
+            torch_tp_plugin=torch_tp_plugin,
             megatron_lm_plugin=megatron_lm_plugin,
             _from_accelerator=True,
             **kwargs,

@@ -1461,6 +1467,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                 )
                 if self.ddp_handler is not None:
                     self.ddp_handler.register_comm_hook(model)
+        elif self.distributed_type == DistributedType.TP:
+            model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
         elif self.distributed_type == DistributedType.FSDP:
             # We need to fix the optimizer *before* sharding the model
             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

@@ -2117,6 +2125,7 @@ def prepare_data_loader(
             data_seed=self.dataloader_config.data_seed,
             non_blocking=self.non_blocking,
             use_stateful_dataloader=self.use_stateful_dataloader,
+            torch_device_mesh=self.state.torch_tp_plugin.torch_device_mesh if self.state.torch_tp_plugin else None,
         )
         self._dataloaders.append(prepared_data_loader)
         return prepared_data_loader
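
For orientation, here is a minimal usage sketch of the new `torch_tp_plugin` argument. It is not part of this commit: the toy model, `tp_size=2`, and the `torchrun` launch are assumptions, and since `TorchTensorParallelPlugin` builds a CUDA device mesh in `__post_init__`, the script has to run inside a distributed launch.

# launch with e.g.: torchrun --nproc-per-node 2 tp_sketch.py
import torch
from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin


class TPCapableModel(torch.nn.Module):
    """Stand-in for a model that knows how to shard itself (e.g. a recent transformers model)."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def tensor_parallel(self, tp_mesh):
        # A real implementation would shard parameters over `tp_mesh`,
        # e.g. via torch.distributed.tensor.parallel.parallelize_module.
        self.tp_mesh = tp_mesh

    def forward(self, x):
        return self.linear(x)


tp_plugin = TorchTensorParallelPlugin(tp_size=2)  # builds a 1-D "tp" device mesh
accelerator = Accelerator(torch_tp_plugin=tp_plugin)

# prepare_model() hits the new DistributedType.TP branch and calls
# model.tensor_parallel(torch_device_mesh["tp"]).
model = accelerator.prepare(TPCapableModel())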

src/accelerate/data_loader.py

Lines changed: 35 additions & 2 deletions

@@ -713,6 +713,7 @@ def __init__(
         _drop_last: bool = False,
         _non_blocking: bool = False,
         slice_fn=None,
+        torch_device_mesh=None,
         **kwargs,
     ):
         shuffle = False

@@ -732,15 +733,37 @@
         self._drop_last = _drop_last
         self._non_blocking = _non_blocking
         self.skip_batches = skip_batches
+        self.torch_device_mesh = torch_device_mesh

         self.slice_fn = slice_tensors if slice_fn is None else slice_fn
         self.iteration = 0

+        # if a device mesh is provided, extract each dimension (tp and dp);
+        # the device mesh will be used only if tp is involved,
+        # otherwise the default behaviour should be sufficient
+        self.submesh_tp = None
+        self.submesh_dp = None
+        if self.torch_device_mesh and "tp" in self.torch_device_mesh.mesh_dim_names:
+            # extract torch sub device mesh objects
+            self.submesh_tp = self.torch_device_mesh["tp"]
+            if "dp" in self.torch_device_mesh.mesh_dim_names:
+                self.submesh_dp = self.torch_device_mesh["dp"]
+        if self.submesh_tp and self.submesh_dp:
+            raise ValueError("TP + DDP / TP + FSDP is not yet supported")
+
     def _fetch_batches(self, iterator):
         batches, batch = None, None
         # On process 0, we gather the batch to dispatch.
         if self.state.process_index == 0:
+            # The procedure to support TP only is simpler, since we want to
+            # dispatch the same batch of samples across all ranks; this removes
+            # the complexity of handling multiple tp rank groups when a TP + DP
+            # combination is involved.
+
             try:
+                # for the TP case, avoid using split_batches, since that
+                # would mean the dataloader ends up emitting duplicates
+                # of batches.
                 if self.split_batches:
                     # One batch of the main iterator is dispatched and split.
                     self._update_state_dict()

@@ -749,9 +772,15 @@ def _fetch_batches(self, iterator):
                     # num_processes batches of the main iterator are concatenated then dispatched and split.
                     # We add the batches one by one so we have the remainder available when drop_last=False.
                     batches = []
-                    for _ in range(self.state.num_processes):
+                    if self.submesh_tp:
+                        # when using tp, extract a single batch and then replicate it
                         self._update_state_dict()
-                        batches.append(next(iterator))
+                        batch = next(iterator)
+                        batches = [batch] * self.state.num_processes
+                    else:
+                        for _ in range(self.state.num_processes):
+                            self._update_state_dict()
+                            batches.append(next(iterator))
                     try:
                         batch = concatenate(batches, dim=0)
                     except RuntimeError as e:

@@ -942,6 +971,7 @@ def prepare_data_loader(
     data_seed: Optional[int] = None,
     non_blocking: bool = False,
     use_stateful_dataloader: bool = False,
+    torch_device_mesh: torch.distributed.DeviceMesh = None,
 ) -> DataLoader:
     """
     Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.

@@ -1009,6 +1039,8 @@ def prepare_data_loader(
             "If set to true, the dataloader prepared by the Accelerator will be backed by "
             "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
             This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
+        torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):
+            PyTorch device mesh.


     Returns:

@@ -1144,6 +1176,7 @@ def prepare_data_loader(
             _non_blocking=non_blocking,
             slice_fn=slice_fn_for_dispatch,
             use_stateful_dataloader=use_stateful_dataloader,
+            torch_device_mesh=torch_device_mesh,
             **kwargs,
         )
     elif sampler_is_batch_sampler:
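
To make the dispatching change concrete, here is a small standalone sketch (an illustration, not code from the commit) of the fetch step above: with TP enabled, the dispatching process fetches one batch and repeats it `num_processes` times, so every rank receives identical samples; otherwise one distinct batch per process is fetched.

from itertools import islice


def fetch_for_dispatch(iterator, num_processes, tp_enabled):
    """Simplified sketch of the fetch step in DataLoaderDispatcher._fetch_batches."""
    if tp_enabled:
        # TP: fetch a single batch and replicate it for every rank.
        batch = next(iterator)
        return [batch] * num_processes
    # default: fetch one distinct batch per process (later concatenated and split).
    return list(islice(iterator, num_processes))


# toy "batches" are just lists of ints here
data = iter([[1, 2], [3, 4], [5, 6], [7, 8]])
print(fetch_for_dispatch(data, num_processes=2, tp_enabled=True))   # [[1, 2], [1, 2]]
print(fetch_for_dispatch(data, num_processes=2, tp_enabled=False))  # [[3, 4], [5, 6]]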

src/accelerate/state.py

Lines changed: 4 additions & 0 deletions

@@ -850,6 +850,7 @@ def __init__(
         dynamo_plugin=None,
         deepspeed_plugin=None,
         fsdp_plugin=None,
+        torch_tp_plugin=None,
         megatron_lm_plugin=None,
         _from_accelerator: bool = False,
         **kwargs,

@@ -864,6 +865,7 @@
         if not self.initialized:
             self.deepspeed_plugins = None
             self.use_ipex = None
+            self.torch_tp_plugin = torch_tp_plugin
             mixed_precision = (
                 parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
                 if mixed_precision is None

@@ -921,6 +923,8 @@
                     self.distributed_type = DistributedType.MEGATRON_LM
                     megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
                     self.megatron_lm_plugin = megatron_lm_plugin
+                if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
+                    self.distributed_type = DistributedType.TP
             elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
                 if is_ipex_available():
                     # check if user disables it explicitly
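
As a hedged illustration of the selection logic above (the launch command and the assertion are assumptions, not part of the commit), either passing a `TorchTensorParallelPlugin` or exporting `ACCELERATE_USE_TP=true` under a multi-GPU launch makes the state resolve to `DistributedType.TP`:

# run under a multi-GPU launch, e.g.:
#   ACCELERATE_USE_TP=true torchrun --nproc-per-node 2 check_state.py
from accelerate import Accelerator
from accelerate.utils import DistributedType, TorchTensorParallelPlugin

accelerator = Accelerator(torch_tp_plugin=TorchTensorParallelPlugin(tp_size=2))

# AcceleratorState.__init__ switches the distributed type to TP when the plugin
# (or the ACCELERATE_USE_TP env var) is set.
assert accelerator.state.distributed_type == DistributedType.TP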

src/accelerate/utils/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -57,6 +57,7 @@
     SageMakerDistributedType,
     TensorInformation,
     TorchDynamoPlugin,
+    TorchTensorParallelPlugin,
     add_model_config_to_megatron_parser,
 )
 from .environment import (

src/accelerate/utils/constants.py

Lines changed: 2 additions & 1 deletion

@@ -46,6 +46,7 @@
 ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
 XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
 MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
+BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"

 STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}


@@ -76,7 +77,7 @@
     "master_port",
 ]

-CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM"]
+CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM", "TP"]
 TORCH_DISTRIBUTED_OPERATION_TYPES = CUDA_DISTRIBUTED_TYPES + [
     "MULTI_NPU",
     "MULTI_MLU",

src/accelerate/utils/dataclasses.py

Lines changed: 28 additions & 0 deletions

@@ -30,6 +30,7 @@
 import torch

 from .constants import (
+    BETA_TP_AVAILABLE_PYTORCH_VERSION,
     FSDP_AUTO_WRAP_POLICY,
     FSDP_BACKWARD_PREFETCH,
     FSDP_SHARDING_STRATEGY,

@@ -540,6 +541,7 @@ class DistributedType(str, enum.Enum):
     MULTI_XPU = "MULTI_XPU"
     DEEPSPEED = "DEEPSPEED"
     FSDP = "FSDP"
+    TP = "TP"
     XLA = "XLA"
     MEGATRON_LM = "MEGATRON_LM"

@@ -1810,6 +1812,32 @@ def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=F
         self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy)


+@dataclass
+class TorchTensorParallelPlugin:
+    """
+    This plugin is used to enable tensor parallelism using PyTorch >= 2.0.
+    """
+
+    tp_size: int = field(
+        default=1,
+        metadata={"help": "tensor parallel size will be used in the device mesh preparation"},
+    )
+
+    # type has to be "torch.distributed.DeviceMesh"
+    torch_device_mesh: torch.distributed.DeviceMesh = field(default=None)
+
+    def __post_init__(self):
+        if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
+            raise ValueError(
+                f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel."
+            )
+        from torch.distributed.device_mesh import init_device_mesh
+
+        mesh_dim_name = "tp"
+        device = "cuda"  # support for other devices has to be investigated
+        self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))
+
+
 @dataclass
 class MegatronLMPlugin:
     """
