@@ -566,7 +566,9 @@ class TrainingArguments:
                     Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
                     used when the xla flag is set to true, and an auto wrapping policy is specified through
                     fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
-
+        tp_size (`int`, *optional*):
+            Use tp_size to enable PyTorch 2.0 tensor parallelism. Set a value greater than 1 to activate TP. The same
+            value is used to prepare the device mesh internally.
         deepspeed (`str` or `dict`, *optional*):
             Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
             evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
@@ -1240,6 +1242,16 @@ class TrainingArguments:
             )
         },
     )
+    tp_size: Optional[int] = field(
+        default=0,
+        metadata={
+            "help": (
+                "Use tp_size to enable pytorch 2.0 tensor parallelism. "
+                "Set a value greater than 1 to activate TP. "
+                "The same is used to prepare device mesh internally."
+            )
+        },
+    )
     fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
         default=None,
         metadata={
@@ -1957,6 +1969,8 @@ def __post_init__(self):
             if self.fsdp_config["xla_fsdp_grad_ckpt"]:
                 warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")

+        if self.tp_size > 1:
+            os.environ["ACCELERATE_USE_TP"] = "true"
         # accelerate integration for FSDP
         if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
             os.environ["ACCELERATE_USE_FSDP"] = "true"
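A minimal usage sketch, not part of this commit: it assumes a transformers build that already includes the `tp_size` field added above, and only shows how the new argument would be passed; the `output_dir` and batch-size values are placeholder choices.

```python
# Sketch only: assumes tp_size exists on TrainingArguments as added in this diff.
# Setting tp_size > 1 enables PyTorch 2.0 tensor parallelism; the same value is
# used to prepare the device mesh internally (and ACCELERATE_USE_TP is set).
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="tp_run",              # placeholder output directory
    per_device_train_batch_size=8,    # placeholder batch size
    tp_size=2,                        # shard supported layers across 2 devices
)
print(args.tp_size)
```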