Commit aa7d872

feat: add support for tensor parallel flow using accelerate

Signed-off-by: Mehant Kammakomati <[email protected]>

1 parent: d3af76d

2 files changed: +22 −2 lines

src/transformers/trainer.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -233,6 +233,7 @@
         AutocastKwargs,
         DistributedDataParallelKwargs,
         DistributedType,
+        TorchTensorParallelPlugin,
         load_fsdp_model,
         load_fsdp_optimizer,
         save_fsdp_model,
@@ -5076,6 +5077,11 @@ def create_accelerator_and_postprocess(self):
             args["dataloader_config"] = dataloader_config
         else:
             args.update(accelerator_config)
+        # tp is initialized at Accelerator init phase so
+        # args should be prepared here
+        if self.args.tp_size > 1:
+            self.is_tp_enabled = True
+            args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)
 
         # create accelerator object
         self.accelerator = Accelerator(**args)
@@ -5090,7 +5096,7 @@ def create_accelerator_and_postprocess(self):
         # deepspeed and accelerate flags covering both trainer args and accelerate launcher
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
-
+        self.is_tp_enabled = getattr(self.accelerator.state, "tp_plugin", None) is not None
         # post accelerator creation setup
         if self.is_fsdp_enabled:
             fsdp_plugin = self.accelerator.state.fsdp_plugin
```
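In short, whenever `tp_size > 1` the Trainer now hands a `TorchTensorParallelPlugin` to the `Accelerator` constructor. Below is a minimal sketch of that flow outside the Trainer; it assumes an accelerate build that ships `TorchTensorParallelPlugin` under `accelerate.utils` (the import path suggested by the diff) and uses a placeholder `tp_size` value.

```python
# Minimal sketch of the flow added above. Assumptions: accelerate exposes
# TorchTensorParallelPlugin via accelerate.utils, and the script is launched
# with a multi-process/distributed setup so tensor parallelism can initialize.
from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin

tp_size = 2  # placeholder; any value > 1 turns on the tensor-parallel path

accelerator_kwargs = {}
if tp_size > 1:
    # TP is set up during Accelerator.__init__, so the plugin must be part of
    # the constructor kwargs, mirroring create_accelerator_and_postprocess above.
    accelerator_kwargs["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=tp_size)

accelerator = Accelerator(**accelerator_kwargs)
```

After construction, the Trainer derives `is_tp_enabled` from the accelerator state, in the same way the existing DeepSpeed and FSDP flags are detected.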

src/transformers/training_args.py

Lines changed: 15 additions & 1 deletion
```diff
@@ -566,7 +566,9 @@ class TrainingArguments:
             Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
             used when the xla flag is set to true, and an auto wrapping policy is specified through
             fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
-
+        tp_size (`int`, *optional*):
+            Use tp_size to enable pytorch 2.0 tensor parallelism. Set a value greater than 1 to activate TP. The same is
+            used to prepare device mesh internally.
         deepspeed (`str` or `dict`, *optional*):
             Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
             evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
@@ -1240,6 +1242,16 @@ class TrainingArguments:
             )
         },
     )
+    tp_size: Optional[int] = field(
+        default=0,
+        metadata={
+            "help": (
+                "Use tp_size to enable pytorch 2.0 tensor parallelism."
+                "Set a value greater than 1 to activate TP."
+                "The same is used to prepare device mesh internally."
+            )
+        },
+    )
     fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
         default=None,
         metadata={
@@ -1957,6 +1969,8 @@ def __post_init__(self):
             if self.fsdp_config["xla_fsdp_grad_ckpt"]:
                 warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
 
+        if self.tp_size > 1:
+            os.environ["ACCELERATE_USE_TP"] = "true"
         # accelerate integration for FSDP
         if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
            os.environ["ACCELERATE_USE_FSDP"] = "true"
```
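From the user's side, the new field is just another `TrainingArguments` knob. A hypothetical usage sketch with this patch applied (the `output_dir` and `tp_size` values are placeholders):

```python
# Hypothetical usage of the new tp_size argument; values below are placeholders.
import os

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",
    tp_size=2,  # default is 0; any value > 1 activates tensor parallelism
)

# Per the __post_init__ change above, constructing the arguments sets the
# accelerate environment flag; the Trainer later builds
# TorchTensorParallelPlugin(tp_size=2) when it creates its Accelerator.
assert os.environ.get("ACCELERATE_USE_TP") == "true"
```

As noted in the docstring change, the same `tp_size` value is used to prepare the device mesh internally.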
