Skip to content

Commit 72d52c2

Browse files
committed
feat: support new tp refactor for training
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent d7c741a commit 72d52c2

File tree

10 files changed

+23
-67
lines changed

10 files changed

+23
-67
lines changed

src/accelerate/accelerator.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,7 @@ def __init__(
374374
if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
375375
raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
376376

377-
if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(
378-
torch_tp_plugin, TorchTensorParallelPlugin
379-
):
377+
if isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
380378
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
381379
raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")
382380

@@ -396,14 +394,8 @@ def __init__(
396394
if not is_torch_version(">=", FSDP2_PYTORCH_VERSION):
397395
raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}")
398396

399-
if torch_tp_plugin is None:
400-
torch_tp_plugin = (
401-
TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None
402-
)
403-
else:
404-
if not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
405-
raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
406-
os.environ["ACCELERATE_USE_TP"] = "true"
397+
if torch_tp_plugin is not None and not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
398+
raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
407399

408400
if megatron_lm_plugin is None: # init from env variables
409401
megatron_lm_plugin = (
@@ -1600,15 +1592,14 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
16001592
if self.ddp_handler is not None:
16011593
self.ddp_handler.register_comm_hook(model)
16021594
elif self.distributed_type == DistributedType.TP:
1595+
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
1596+
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
16031597
if hasattr(model, "supports_tp_plan") and not model.supports_tp_plan:
1604-
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
1605-
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
16061598
raise NotImplementedError(
16071599
"Provided model does not support tensor parallelism. \
16081600
Tensor parallelism plan can be added as base_model_tp_plan to model config class \
16091601
and _tp_plan attribute to model class."
16101602
)
1611-
model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
16121603
elif self.is_fsdp2:
16131604
model = fsdp2_prepare_model(self, model)
16141605

@@ -2225,8 +2216,7 @@ def _prepare_device_mesh(self):
22252216
return self.state.torch_tp_plugin.torch_device_mesh
22262217
elif self.distributed_type == DistributedType.DEEPSPEED and hasattr(self.state, "ds_device_mesh"):
22272218
return self.state.ds_device_mesh
2228-
else:
2229-
return None
2219+
return None
22302220

22312221
def _prepare_msamp(self, *args, device_placement):
22322222
if not is_msamp_available():

src/accelerate/commands/config/cluster.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def get_cluster_input():
382382
)
383383

384384
fsdp_config = {}
385-
tp_config = {}
385+
386386
if distributed_type in [
387387
DistributedType.MULTI_GPU,
388388
DistributedType.MULTI_NPU,
@@ -510,21 +510,7 @@ def get_cluster_input():
510510
default=False,
511511
error_message="Please enter yes or no.",
512512
)
513-
if not use_fsdp:
514-
use_tp = _ask_field(
515-
"Do you want to use TensorParallel? [yes/NO]: ",
516-
_convert_yes_no_to_bool,
517-
default=False,
518-
error_message="Please enter yes or no.",
519-
)
520-
if use_tp:
521-
distributed_type = DistributedType.TP
522-
if distributed_type == DistributedType.TP:
523-
tp_config["tp_size"] = _ask_field(
524-
"What should be your Tensor Parallel degree? [1]: ",
525-
int,
526-
default=1,
527-
)
513+
528514
megatron_lm_config = {}
529515
if distributed_type in [DistributedType.MULTI_GPU]:
530516
use_megatron_lm = _ask_field(
@@ -863,7 +849,6 @@ def get_cluster_input():
863849
fp8_config=fp8_config,
864850
deepspeed_config=deepspeed_config,
865851
fsdp_config=fsdp_config,
866-
tp_config=tp_config,
867852
megatron_lm_config=megatron_lm_config,
868853
ipex_config=ipex_config,
869854
mpirun_config=mpirun_config,

src/accelerate/commands/config/config_args.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,6 @@ class ClusterConfig(BaseConfig):
194194
deepspeed_config: dict = None
195195
# args for fsdp
196196
fsdp_config: dict = None
197-
# args for tp
198-
tp_config: dict = None
199197
# args for megatron_lm
200198
megatron_lm_config: dict = None
201199
# args for ipex

src/accelerate/commands/launch.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@
7575
"tpu": "TPU",
7676
"use_deepspeed": "DeepSpeed Arguments",
7777
"use_fsdp": "FSDP Arguments",
78-
"use_tp": "PyTorch TP Arguments",
7978
"use_megatron_lm": "Megatron-LM Arguments",
8079
"fp8_backend": "FP8 Arguments",
8180
}
@@ -264,12 +263,6 @@ def launch_command_parser(subparsers=None):
264263
action="store_true",
265264
help="Whether to use fsdp.",
266265
)
267-
paradigm_args.add_argument(
268-
"--use_tp",
269-
default=False,
270-
action="store_true",
271-
help="Whether to use PyTorch TP.",
272-
)
273266
paradigm_args.add_argument(
274267
"--use_megatron_lm",
275268
default=False,
@@ -611,15 +604,6 @@ def launch_command_parser(subparsers=None):
611604
help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
612605
)
613606

614-
# tp args
615-
tp_args = parser.add_argument_group("TP Arguments", "Arguments related to Tensor Parallelism using PyTorch.")
616-
tp_args.add_argument(
617-
"--tp_size",
618-
default=1,
619-
type=int,
620-
help="PyTorch Tensor Parallelism (TP) degree. Set a value greater than 1 to activate. (useful only when `use_tp` flag is passed)",
621-
)
622-
623607
# megatron_lm args
624608
megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
625609
megatron_lm_args.add_argument(
@@ -1001,9 +985,9 @@ def sagemaker_launcher(sagemaker_config: SageMakerConfig, args):
1001985

1002986
def _validate_launch_command(args):
1003987
# Sanity checks
1004-
if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp, args.use_tp]) > 1:
988+
if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1:
1005989
raise ValueError(
1006-
"You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp`, `--use_tp` at a time."
990+
"You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time."
1007991
)
1008992
if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
1009993
raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")
@@ -1020,7 +1004,6 @@ def _validate_launch_command(args):
10201004
and not args.tpu_use_cluster
10211005
and not args.use_deepspeed
10221006
and not args.use_fsdp
1023-
and not args.use_tp
10241007
and not args.use_megatron_lm
10251008
):
10261009
args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
@@ -1040,7 +1023,6 @@ def _validate_launch_command(args):
10401023
)
10411024
args.tpu = defaults.distributed_type == DistributedType.XLA
10421025
args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
1043-
args.use_tp = defaults.distributed_type == DistributedType.TP
10441026
args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
10451027
args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
10461028
if args.gpu_ids is None:
@@ -1191,8 +1173,6 @@ def launch_command(args):
11911173
deepspeed_launcher(args)
11921174
elif args.use_fsdp and not args.cpu:
11931175
multi_gpu_launcher(args)
1194-
elif args.use_tp and not args.cpu:
1195-
multi_gpu_launcher(args)
11961176
elif args.use_megatron_lm and not args.cpu:
11971177
multi_gpu_launcher(args)
11981178
elif args.multi_gpu and not args.cpu:

src/accelerate/state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -966,7 +966,7 @@ def __init__(
966966
self.distributed_type = DistributedType.MEGATRON_LM
967967
megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
968968
self.megatron_lm_plugin = megatron_lm_plugin
969-
if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
969+
if self.torch_tp_plugin is not None:
970970
self.distributed_type = DistributedType.TP
971971
elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
972972
if is_ipex_available():

src/accelerate/test_utils/scripts/external_deps/test_performance.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,8 @@ def training_function(config, args):
9191

9292
set_seed(seed)
9393
train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name)
94-
9594
# Instantiate the model (we build the model here so that the seed also control new weights initialization)
96-
model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True)
95+
model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True, tp_plan=args.tp_plan)
9796

9897
if args.add_pad_token:
9998
if model.config.pad_token_id is None:
@@ -255,6 +254,12 @@ def main():
255254
default=False,
256255
help="To add pad token if not exists.",
257256
)
257+
parser.add_argument(
258+
"--tp_plan",
259+
type=str,
260+
default=None,
261+
help="To use TP or not",
262+
)
258263
args = parser.parse_args()
259264
config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}
260265
training_function(config, args)

src/accelerate/utils/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
5050
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
5151
BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"
52-
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.47.0"
52+
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.50.0"
5353

5454
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
5555

src/accelerate/utils/dataclasses.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2028,7 +2028,6 @@ class TorchTensorParallelPlugin:
20282028
torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None)
20292029

20302030
def __post_init__(self):
2031-
self.tp_size = self.tp_size if os.environ.get("TP_SIZE", "1") == "1" else int(os.environ.get("TP_SIZE", "1"))
20322031
if self.tp_size == 1:
20332032
raise ValueError("Provide TP degree > 1.")
20342033

@@ -2046,6 +2045,8 @@ def __post_init__(self):
20462045

20472046
mesh_dim_name = "tp"
20482047

2048+
# device mesh is not used for model sharding
2049+
# it is only used for preparing data loader
20492050
self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))
20502051

20512052

src/accelerate/utils/launch.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,10 +306,6 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:
306306
current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
307307
current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
308308

309-
if args.use_tp:
310-
current_env["ACCELERATE_USE_TP"] = "true"
311-
current_env["TP_SIZE"] = str(args.tp_size)
312-
313309
if args.use_megatron_lm:
314310
prefix = "MEGATRON_LM_"
315311
current_env["ACCELERATE_USE_MEGATRON_LM"] = "true"

tests/tp/test_tp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,15 @@ def setUp(self):
4949
def test_working_of_tp(self):
5050
self.test_file_path = self.test_scripts_folder / "test_performance.py"
5151
cmd = get_launch_command(
52-
num_processes=self.test_tp_size, num_machines=1, machine_rank=0, use_tp=True, tp_size=self.test_tp_size
52+
num_processes=self.test_tp_size, num_machines=1, machine_rank=0, tp_size=self.test_tp_size
5353
)
5454
cmd.extend(
5555
[
5656
self.test_file_path,
5757
f"--output_dir={self.tmpdir}",
5858
f"--model_name_or_path={self.model_name_or_path}",
5959
"--add_pad_token=true",
60+
"--tp_plan='auto'",
6061
]
6162
)
6263
with patch_environment(omp_num_threads=1):

0 commit comments

Comments
 (0)