
Commit 8c6cfb8

fix: test cases
Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 14fa148 commit 8c6cfb8

File tree

10 files changed, +124 -19 lines changed


src/accelerate/accelerator.py

Lines changed: 24 additions & 6 deletions
@@ -108,7 +108,12 @@
     save_fsdp_optimizer,
     wait_for_everyone,
 )
-from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME, BETA_TP_AVAILABLE_PYTORCH_VERSION
+from .utils.constants import (
+    BETA_TP_AVAILABLE_PYTORCH_VERSION,
+    BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,
+    FSDP_PYTORCH_VERSION,
+    PROFILE_PATTERN_NAME,
+)
 from .utils.modeling import get_state_dict_offloaded_model
 from .utils.other import is_compiled_module

@@ -359,10 +364,15 @@ def __init__(
         if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
             raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")

-        if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
+        if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(
+            torch_tp_plugin, TorchTensorParallelPlugin
+        ):
             if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
                 raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")

+            if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
+                raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
+
         if fsdp_plugin is None:  # init from env variables
             fsdp_plugin = (
                 FullyShardedDataParallelPlugin() if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" else None

@@ -373,12 +383,14 @@ def __init__(
             os.environ["ACCELERATE_USE_FSDP"] = "true"  # use FSDP if plugin is provided

         if torch_tp_plugin is None:
-            torch_tp_plugin = (TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None)
+            torch_tp_plugin = (
+                TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None
+            )
         else:
             if not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
                 raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
             os.environ["ACCELERATE_USE_TP"] = "true"
-
+
         if megatron_lm_plugin is None:  # init from env variables
             megatron_lm_plugin = (
                 MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None

@@ -1489,8 +1501,14 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
             if self.ddp_handler is not None:
                 self.ddp_handler.register_comm_hook(model)
         elif self.distributed_type == DistributedType.TP:
-            if not model.supports_tp_plan:
-                raise NotImplementedError("Provided model does not support tensor parallelism")
+            if hasattr(model, "supports_tp_plan") and not model.supports_tp_plan:
+                if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
+                    raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
+                raise NotImplementedError(
+                    "Provided model does not support tensor parallelism. \
+                    Tensor parallelism plan can be added as base_model_tp_plan to model config class \
+                    and _tp_plan attribute to model class."
+                )
             model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
         elif self.distributed_type == DistributedType.FSDP:
             # We need to fix the optimizer *before* sharding the model
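
For orientation, a minimal usage sketch of the path these checks guard (not part of this commit; the model name, TP degree, and the import location of TorchTensorParallelPlugin are assumptions, and the script would need to run under a multi-process `accelerate launch`):

    # Hedged sketch: exercises the new version guards and the TP prepare path.
    from accelerate import Accelerator
    from accelerate.utils import TorchTensorParallelPlugin  # import path assumed
    from transformers import AutoModelForCausalLM

    # Accelerator.__init__ now refuses this unless torch >= 2.3.0 and transformers >= 4.47.0.
    tp_plugin = TorchTensorParallelPlugin(tp_size=2)  # tp_size == 1 raises in __post_init__ (see dataclasses.py below)
    accelerator = Accelerator(torch_tp_plugin=tp_plugin)

    model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # placeholder checkpoint
    model = accelerator.prepare(model)  # prepare_model() calls model.tensor_parallel(mesh["tp"]) if the model has a TP plan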

src/accelerate/commands/config/cluster.py

Lines changed: 1 addition & 1 deletion
@@ -487,7 +487,7 @@ def get_cluster_input():
         distributed_type = DistributedType.TP
     if distributed_type == DistributedType.TP:
         tp_config["tp_size"] = _ask_field(
-            "What should be your Tensor Parallel degree? [1e8]: ",
+            "What should be your Tensor Parallel degree? [1]: ",
             int,
             default=1,
         )

src/accelerate/commands/launch.py

Lines changed: 1 addition & 1 deletion
@@ -594,7 +594,7 @@ def launch_command_parser(subparsers=None):
         type=str,
         help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
     )
-
+
     # tp args
     tp_args = parser.add_argument_group("TP Arguments", "Arguments related to Tensor Parallelism using PyToch.")
     tp_args.add_argument(

src/accelerate/data_loader.py

Lines changed: 11 additions & 9 deletions
@@ -752,11 +752,11 @@ def __init__(
         self.iteration = 0

         # if a device mesh is provided extract each dimension (dp, fsdp, tp)
-        # device mesh may hold any number of dimensions, however,
+        # device mesh may hold any number of dimensions, however,
         # below code is for targetted support for dp, fsdp and tp
-
-        # device mesh will be used only if there is tp involved
-        # or any multi-dimensional parallelism involving tp
+
+        # device mesh will be used only if there is tp involved
+        # or any multi-dimensional parallelism involving tp
         # (dp, tp) (fsdp, tp) (dp, fsdp, tp)
         # otherwise the default behavour not using device mesh should be sufficient
         # since multi dimensional parallelism devoid of tp would anyway need

@@ -789,8 +789,10 @@ def _fetch_batches(self, iterator):
         if self.split_batches:
             # One batch of the main iterator is dispatched and split.
             if self.submesh_tp:
-                logger.warning("Use of split_batches for TP would need the dataloader to produce duplicate batches,"
-                    "otherwise, use dispatch_batches=True instead.")
+                logger.warning(
+                    "Use of split_batches for TP would need the dataloader to produce duplicate batches,"
+                    "otherwise, use dispatch_batches=True instead."
+                )
             self._update_state_dict()
             batch = next(iterator)
         else:

@@ -996,7 +998,7 @@ def prepare_data_loader(
     data_seed: Optional[int] = None,
     non_blocking: bool = False,
     use_stateful_dataloader: bool = False,
-    torch_device_mesh: torch.distributed.DeviceMesh = None,
+    torch_device_mesh=None,
 ) -> DataLoader:
     """
     Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.

@@ -1090,7 +1092,7 @@ def prepare_data_loader(
     state = PartialState()
     if num_processes is None:
         num_processes = state.num_processes
-
+
     # when device mesh is used, specifically with TP
     # then there is need to update process_index and num_processes
     # to bring in the effect of generating same batch across TP ranks

@@ -1110,7 +1112,7 @@ def prepare_data_loader(
            submesh_dp_size = torch_device_mesh["dp"].size()
         if "fsdp" in torch_device_mesh.mesh_dim_names:
            submesh_fsdp_size = torch_device_mesh["fsdp"].size()
-        num_processes = (submesh_fsdp_size * submesh_dp_size)
+        num_processes = submesh_fsdp_size * submesh_dp_size
     if process_index is None:
         process_index = state.process_index
     if torch_device_mesh:
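
A small illustration of the process bookkeeping described in the comments above, for a hypothetical 4-GPU (dp=2, tp=2) mesh (an assumption for clarity, not code from the commit; it requires torch.distributed to be initialized across 4 ranks):

    # Hedged sketch: only non-TP mesh dimensions count as dataloader "processes",
    # so the ranks inside a TP group end up consuming identical batches.
    from torch.distributed.device_mesh import init_device_mesh

    mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

    num_processes = mesh["dp"].size()            # 2 rather than 4
    process_index = mesh["dp"].get_local_rank()  # shared by both TP ranks in a dp group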

src/accelerate/test_utils/scripts/external_deps/test_performance.py

Lines changed: 11 additions & 0 deletions
@@ -43,6 +43,7 @@ def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name:
         model_name (`str`, *optional*):
     """
     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
     datasets = load_dataset("glue", "mrpc")

     def tokenize_function(examples):

@@ -93,6 +94,10 @@ def training_function(config, args):
     # Instantiate the model (we build the model here so that the seed also control new weights initialization)
     model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True)

+    if args.add_pad_token:
+        if model.config.pad_token_id is None:
+            model.config.pad_token_id = 0
+
     # Instantiate optimizer
     optimizer_cls = (
         AdamW

@@ -243,6 +248,12 @@ def main():
         default=3,
         help="Number of train epochs.",
     )
+    parser.add_argument(
+        "--add_pad_token",
+        type=bool,
+        default=False,
+        help="To add pad token if not exists.",
+    )
     args = parser.parse_args()
     config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}
     training_function(config, args)
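
The new --add_pad_token flag exists because decoder-only checkpoints such as TinyLlama ship without a pad token, while batched MRPC classification needs one; a hedged sketch of the same fallback in isolation (checkpoint name taken from tests/tp/test_tp.py below):

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    if config.pad_token_id is None:
        config.pad_token_id = 0  # the same fallback the script now applies to model.config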

src/accelerate/test_utils/testing.py

Lines changed: 7 additions & 0 deletions
@@ -342,6 +342,13 @@ def require_deepspeed(test_case):
     return unittest.skipUnless(is_deepspeed_available(), "test requires DeepSpeed")(test_case)


+def require_tp(test_case):
+    """
+    Decorator marking a test that requires TP installed. These tests are skipped when TP isn't installed
+    """
+    return unittest.skipUnless(is_torch_version(">=", "2.3.0"), "test requires torch version >= 2.3.0")(test_case)
+
+
 def require_torch_min_version(test_case=None, version=None):
     """
     Decorator marking that a test requires a particular torch version to be tested. These tests are skipped when an
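
As a quick illustration (assumed usage, mirroring tests/tp/test_tp.py below), the new decorator stacks with the existing requirement markers:

    # Hedged sketch of how require_tp is intended to be applied.
    from accelerate.test_utils.testing import TempDirTestCase, require_multi_device, require_tp


    @require_tp
    @require_multi_device
    class MyTPTest(TempDirTestCase):
        def test_something(self):
            ...  # skipped entirely on torch < 2.3.0 or on single-device machines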

src/accelerate/utils/constants.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@
 XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
 MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
 BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"
+BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.47.0"

 STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}

src/accelerate/utils/dataclasses.py

Lines changed: 6 additions & 2 deletions
@@ -1826,10 +1826,14 @@ class TorchTensorParallelPlugin:
         metadata={"help": "tensor parallel size will be used in the device mesh preparation"},
     )

-    # type has to be "torch.distributed.DeviceMesh"
-    torch_device_mesh: torch.distributed.DeviceMesh = field(default=None)
+    # torch_device_mesh is of type "torch.distributed.DeviceMesh"
+    torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None)

     def __post_init__(self):
+        self.tp_size = self.tp_size if os.environ.get("TP_SIZE", "1") == "1" else int(os.environ.get("TP_SIZE", "1"))
+        if self.tp_size == 1:
+            raise ValueError("Provide TP degree > 1.")
+
         if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
             raise ValueError(
                 f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel."

src/accelerate/utils/launch.py

Lines changed: 1 addition & 0 deletions
@@ -286,6 +286,7 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:

     if args.use_tp:
         current_env["ACCELERATE_USE_TP"] = "true"
+        current_env["TP_SIZE"] = str(args.tp_size)

     if args.use_megatron_lm:
         prefix = "MEGATRON_LM_"

tests/tp/test_tp.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from accelerate.test_utils.testing import (
+    TempDirTestCase,
+    execute_subprocess_async,
+    get_launch_command,
+    path_in_accelerate_package,
+    require_multi_device,
+    require_non_torch_xla,
+    require_tp,
+    require_transformers,
+    slow,
+)
+from accelerate.utils import patch_environment
+
+
+@require_non_torch_xla
+@require_tp
+@require_multi_device
+@require_transformers
+@slow
+class TPIntegrationTest(TempDirTestCase):
+    test_scripts_folder = path_in_accelerate_package("test_utils", "scripts", "external_deps")
+
+    def setUp(self):
+        super().setUp()
+        self.test_tp_size = 2
+        self.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+        self.batch_size = 1
+        from transformers.trainer_utils import set_seed
+
+        set_seed(42)
+
+    def test_working_of_tp(self):
+        self.test_file_path = self.test_scripts_folder / "test_performance.py"
+        cmd = get_launch_command(
+            num_processes=self.test_tp_size, num_machines=1, machine_rank=0, use_tp=True, tp_size=self.test_tp_size
+        )
+        cmd.extend(
+            [
+                self.test_file_path,
+                f"--output_dir={self.tmpdir}",
+                f"--model_name_or_path={self.model_name_or_path}",
+                "--add_pad_token=true",
+            ]
+        )
+        with patch_environment(omp_num_threads=1):
+            execute_subprocess_async(cmd)
