Skip to content

Commit 67adb47

Browse files
authored
(Part 1) fix: make TP training compatible with new transformers (#3457)
* feat: support new tp refactor for training Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com> * fix: @S1ro1 review cmt Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com> * fix: @S1ro1 review cmt - tp_plan flag docstr Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com> * fix: @SunMarc review cmt on unused flag Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com> * fix: pick approach 3 as discussed in the PR see #3457 (comment) for more details Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com> * fix: styling errors Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com> * fix: bump up transformers for tp_size feature Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com> --------- Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent ee4cab9 commit 67adb47

File tree

10 files changed

+57
-83
lines changed

10 files changed

+57
-83
lines changed

src/accelerate/accelerator.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,7 @@ def __init__(
374374
if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
375375
raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
376376

377-
if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(
378-
torch_tp_plugin, TorchTensorParallelPlugin
379-
):
377+
if isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
380378
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
381379
raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")
382380

@@ -396,14 +394,8 @@ def __init__(
396394
if not is_torch_version(">=", FSDP2_PYTORCH_VERSION):
397395
raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}")
398396

399-
if torch_tp_plugin is None:
400-
torch_tp_plugin = (
401-
TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None
402-
)
403-
else:
404-
if not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
405-
raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
406-
os.environ["ACCELERATE_USE_TP"] = "true"
397+
if torch_tp_plugin is not None and not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
398+
raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
407399

408400
if megatron_lm_plugin is None: # init from env variables
409401
megatron_lm_plugin = (
@@ -1598,15 +1590,17 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
15981590
if self.ddp_handler is not None:
15991591
self.ddp_handler.register_comm_hook(model)
16001592
elif self.distributed_type == DistributedType.TP:
1601-
if hasattr(model, "supports_tp_plan") and not model.supports_tp_plan:
1602-
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
1603-
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
1593+
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
1594+
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
1595+
if not hasattr(model, "tp_size"):
16041596
raise NotImplementedError(
1605-
"Provided model does not support tensor parallelism. \
1606-
Tensor parallelism plan can be added as base_model_tp_plan to model config class \
1607-
and _tp_plan attribute to model class."
1597+
"Model should undergo tensor parallel before passing it to accelerate."
1598+
"You can use .from_pretrained(..., tp_plan='auto') if the model supports"
1599+
)
1600+
if model.tp_size != self.state.torch_tp_plugin.tp_size:
1601+
raise ValueError(
1602+
f"tp_size in the plugin {self.state.torch_tp_plugin.tp_size} should be same as model's tp size {model.tp_size}"
16081603
)
1609-
model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
16101604
elif self.is_fsdp2:
16111605
model = fsdp2_prepare_model(self, model)
16121606

@@ -2223,8 +2217,7 @@ def _prepare_device_mesh(self):
22232217
return self.state.torch_tp_plugin.torch_device_mesh
22242218
elif self.distributed_type == DistributedType.DEEPSPEED and hasattr(self.state, "ds_device_mesh"):
22252219
return self.state.ds_device_mesh
2226-
else:
2227-
return None
2220+
return None
22282221

22292222
def _prepare_msamp(self, *args, device_placement):
22302223
if not is_msamp_available():

src/accelerate/commands/config/cluster.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def get_cluster_input():
369369
)
370370

371371
fsdp_config = {}
372-
tp_config = {}
372+
373373
if distributed_type in [
374374
DistributedType.MULTI_GPU,
375375
DistributedType.MULTI_NPU,
@@ -498,21 +498,7 @@ def get_cluster_input():
498498
default=False,
499499
error_message="Please enter yes or no.",
500500
)
501-
if not use_fsdp:
502-
use_tp = _ask_field(
503-
"Do you want to use TensorParallel? [yes/NO]: ",
504-
_convert_yes_no_to_bool,
505-
default=False,
506-
error_message="Please enter yes or no.",
507-
)
508-
if use_tp:
509-
distributed_type = DistributedType.TP
510-
if distributed_type == DistributedType.TP:
511-
tp_config["tp_size"] = _ask_field(
512-
"What should be your Tensor Parallel degree? [1]: ",
513-
int,
514-
default=1,
515-
)
501+
516502
megatron_lm_config = {}
517503
if distributed_type in [DistributedType.MULTI_GPU]:
518504
use_megatron_lm = _ask_field(
@@ -857,7 +843,6 @@ def get_cluster_input():
857843
fp8_config=fp8_config,
858844
deepspeed_config=deepspeed_config,
859845
fsdp_config=fsdp_config,
860-
tp_config=tp_config,
861846
megatron_lm_config=megatron_lm_config,
862847
ipex_config=ipex_config,
863848
mpirun_config=mpirun_config,

src/accelerate/commands/config/config_args.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,6 @@ class ClusterConfig(BaseConfig):
194194
deepspeed_config: dict = None
195195
# args for fsdp
196196
fsdp_config: dict = None
197-
# args for tp
198-
tp_config: dict = None
199197
# args for megatron_lm
200198
megatron_lm_config: dict = None
201199
# args for ipex
@@ -223,8 +221,6 @@ def __post_init__(self):
223221
self.deepspeed_config = {}
224222
if self.fsdp_config is None:
225223
self.fsdp_config = {}
226-
if self.tp_config is None:
227-
self.tp_config = {}
228224
if self.megatron_lm_config is None:
229225
self.megatron_lm_config = {}
230226
if self.ipex_config is None:

src/accelerate/commands/launch.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@
7575
"tpu": "TPU",
7676
"use_deepspeed": "DeepSpeed Arguments",
7777
"use_fsdp": "FSDP Arguments",
78-
"use_tp": "PyTorch TP Arguments",
7978
"use_megatron_lm": "Megatron-LM Arguments",
8079
"fp8_backend": "FP8 Arguments",
8180
}
@@ -264,12 +263,6 @@ def launch_command_parser(subparsers=None):
264263
action="store_true",
265264
help="Whether to use fsdp.",
266265
)
267-
paradigm_args.add_argument(
268-
"--use_tp",
269-
default=False,
270-
action="store_true",
271-
help="Whether to use PyTorch TP.",
272-
)
273266
paradigm_args.add_argument(
274267
"--use_megatron_lm",
275268
default=False,
@@ -612,15 +605,6 @@ def launch_command_parser(subparsers=None):
612605
help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
613606
)
614607

615-
# tp args
616-
tp_args = parser.add_argument_group("TP Arguments", "Arguments related to Tensor Parallelism using PyTorch.")
617-
tp_args.add_argument(
618-
"--tp_size",
619-
default=1,
620-
type=int,
621-
help="PyTorch Tensor Parallelism (TP) degree. Set a value greater than 1 to activate. (useful only when `use_tp` flag is passed)",
622-
)
623-
624608
# megatron_lm args
625609
megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
626610
megatron_lm_args.add_argument(
@@ -1002,9 +986,9 @@ def sagemaker_launcher(sagemaker_config: SageMakerConfig, args):
1002986

1003987
def _validate_launch_command(args):
1004988
# Sanity checks
1005-
if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp, args.use_tp]) > 1:
989+
if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1:
1006990
raise ValueError(
1007-
"You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp`, `--use_tp` at a time."
991+
"You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time."
1008992
)
1009993
if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
1010994
raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")
@@ -1021,7 +1005,6 @@ def _validate_launch_command(args):
10211005
and not args.tpu_use_cluster
10221006
and not args.use_deepspeed
10231007
and not args.use_fsdp
1024-
and not args.use_tp
10251008
and not args.use_megatron_lm
10261009
):
10271010
args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
@@ -1041,7 +1024,6 @@ def _validate_launch_command(args):
10411024
)
10421025
args.tpu = defaults.distributed_type == DistributedType.XLA
10431026
args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
1044-
args.use_tp = defaults.distributed_type == DistributedType.TP
10451027
args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
10461028
args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
10471029
if args.gpu_ids is None:
@@ -1195,8 +1177,6 @@ def launch_command(args):
11951177
deepspeed_launcher(args)
11961178
elif args.use_fsdp and not args.cpu:
11971179
multi_gpu_launcher(args)
1198-
elif args.use_tp and not args.cpu:
1199-
multi_gpu_launcher(args)
12001180
elif args.use_megatron_lm and not args.cpu:
12011181
multi_gpu_launcher(args)
12021182
elif args.multi_gpu and not args.cpu:

src/accelerate/state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ def __init__(
971971
self.distributed_type = DistributedType.MEGATRON_LM
972972
megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
973973
self.megatron_lm_plugin = megatron_lm_plugin
974-
if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
974+
if self.torch_tp_plugin is not None:
975975
self.distributed_type = DistributedType.TP
976976
elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
977977
if is_ipex_available():

src/accelerate/test_utils/scripts/external_deps/test_performance.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import argparse
1515
import json
1616
import os
17+
from contextlib import nullcontext
1718
from pathlib import Path
1819

1920
import evaluate
@@ -24,7 +25,7 @@
2425
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
2526

2627
from accelerate import Accelerator, DistributedType
27-
from accelerate.utils import SAFE_WEIGHTS_NAME, set_seed
28+
from accelerate.utils import SAFE_WEIGHTS_NAME, TorchTensorParallelPlugin, set_seed
2829
from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
2930

3031

@@ -80,7 +81,7 @@ def collate_fn(examples):
8081

8182
def training_function(config, args):
8283
# Initialize accelerator
83-
accelerator = Accelerator()
84+
accelerator = Accelerator(torch_tp_plugin=TorchTensorParallelPlugin(tp_size=args.tp_size))
8485

8586
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
8687
lr = config["lr"]
@@ -91,9 +92,10 @@ def training_function(config, args):
9192

9293
set_seed(seed)
9394
train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name)
94-
9595
# Instantiate the model (we build the model here so that the seed also control new weights initialization)
96-
model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True)
96+
model = AutoModelForSequenceClassification.from_pretrained(
97+
model_name, return_dict=True, tp_plan=args.tp_plan, tp_size=args.tp_size
98+
)
9799

98100
if args.add_pad_token:
99101
if model.config.pad_token_id is None:
@@ -150,7 +152,13 @@ def training_function(config, args):
150152
outputs = model(**batch)
151153
loss = outputs.loss
152154
accelerator.backward(loss)
153-
optimizer.step()
155+
context = nullcontext
156+
if args.tp_plan is not None:
157+
from torch.distributed._tensor.experimental import implicit_replication
158+
159+
context = implicit_replication
160+
with context():
161+
optimizer.step()
154162
lr_scheduler.step()
155163
optimizer.zero_grad()
156164

@@ -213,12 +221,15 @@ def training_function(config, args):
213221
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
214222
json.dump(performance_metric, f)
215223

216-
# Finally try saving the model
217-
accelerator.save_model(model, args.output_dir)
224+
# TODO: skip saving of the model test for TP until the feature lands
225+
if args.tp_plan is None:
226+
# Finally try saving the model
227+
accelerator.save_model(model, args.output_dir)
218228
accelerator.wait_for_everyone()
219-
assert Path(args.output_dir, SAFE_WEIGHTS_NAME).exists(), (
220-
"Model was not saved when calling `Accelerator.save_model`"
221-
)
229+
if args.tp_plan is None:
230+
assert Path(args.output_dir, SAFE_WEIGHTS_NAME).exists(), (
231+
"Model was not saved when calling `Accelerator.save_model`"
232+
)
222233
accelerator.end_training()
223234

224235

@@ -255,6 +266,18 @@ def main():
255266
default=False,
256267
help="To add pad token if not exists.",
257268
)
269+
parser.add_argument(
270+
"--tp_plan",
271+
type=str,
272+
default=None,
273+
help="pass 'auto' to use TP",
274+
)
275+
parser.add_argument(
276+
"--tp_size",
277+
type=int,
278+
default=None,
279+
help="TP size to be used to shard the model",
280+
)
258281
args = parser.parse_args()
259282
config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}
260283
training_function(config, args)

src/accelerate/utils/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
5050
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
5151
BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"
52-
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.47.0"
52+
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.52.0"
5353

5454
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
5555

src/accelerate/utils/dataclasses.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2034,7 +2034,6 @@ class TorchTensorParallelPlugin:
20342034
torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None)
20352035

20362036
def __post_init__(self):
2037-
self.tp_size = self.tp_size if os.environ.get("TP_SIZE", "1") == "1" else int(os.environ.get("TP_SIZE", "1"))
20382037
if self.tp_size == 1:
20392038
raise ValueError("Provide TP degree > 1.")
20402039

@@ -2052,6 +2051,8 @@ def __post_init__(self):
20522051

20532052
mesh_dim_name = "tp"
20542053

2054+
# device mesh is not used for model sharding
2055+
# it is only used for preparing data loader
20552056
self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))
20562057

20572058

src/accelerate/utils/launch.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,10 +305,6 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
305305
current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
306306
current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
307307

308-
if args.use_tp:
309-
current_env["ACCELERATE_USE_TP"] = "true"
310-
current_env["TP_SIZE"] = str(args.tp_size)
311-
312308
if args.use_megatron_lm:
313309
prefix = "MEGATRON_LM_"
314310
current_env["ACCELERATE_USE_MEGATRON_LM"] = "true"

tests/tp/test_tp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@ def setUp(self):
4848

4949
def test_working_of_tp(self):
5050
self.test_file_path = self.test_scripts_folder / "test_performance.py"
51-
cmd = get_launch_command(
52-
num_processes=self.test_tp_size, num_machines=1, machine_rank=0, use_tp=True, tp_size=self.test_tp_size
53-
)
51+
cmd = get_launch_command(num_processes=self.test_tp_size, num_machines=1, machine_rank=0)
5452
cmd.extend(
5553
[
5654
self.test_file_path,
5755
f"--output_dir={self.tmpdir}",
5856
f"--model_name_or_path={self.model_name_or_path}",
5957
"--add_pad_token=true",
58+
"--tp_plan=auto",
59+
f"--tp_size={self.test_tp_size}",
6060
]
6161
)
6262
with patch_environment(omp_num_threads=1):

0 commit comments

Comments
 (0)