Commit 974edd9
feat: add tests for cli usage of TP and plugin
Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 780ae7b commit 974edd9

File tree: 6 files changed (+108, -13 lines)


src/accelerate/accelerator.py

Lines changed: 8 additions & 4 deletions
@@ -108,7 +108,7 @@
     save_fsdp_optimizer,
     wait_for_everyone,
 )
-from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME, BETA_TP_AVAILABLE_PYTORCH_VERSION
+from .utils.constants import BETA_TP_AVAILABLE_PYTORCH_VERSION, FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME
 from .utils.modeling import get_state_dict_offloaded_model
 from .utils.other import is_compiled_module
 
@@ -349,7 +349,9 @@ def __init__(
             if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
                 raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
 
-        if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
+        if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(
+            torch_tp_plugin, TorchTensorParallelPlugin
+        ):
             if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
                 raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")
 
@@ -363,12 +365,14 @@ def __init__(
             os.environ["ACCELERATE_USE_FSDP"] = "true"  # use FSDP if plugin is provided
 
         if torch_tp_plugin is None:
-            torch_tp_plugin = (TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None)
+            torch_tp_plugin = (
+                TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None
+            )
         else:
             if not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
                 raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
             os.environ["ACCELERATE_USE_TP"] = "true"
-
+
         if megatron_lm_plugin is None:  # init from env variables
             megatron_lm_plugin = (
                 MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None
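
For context, a minimal sketch of how the new torch_tp_plugin argument could be driven from Python, based only on the wiring in this hunk. The tp_size field is assumed from the __post_init__ shown in the dataclasses diff below, and the snippet presumes a 2-process distributed launch so the plugin can build its device mesh; neither detail is asserted by the commit itself.

# Minimal sketch, not code from this commit.
from accelerate import Accelerator
from accelerate.utils.dataclasses import TorchTensorParallelPlugin

tp_plugin = TorchTensorParallelPlugin(tp_size=2)   # tp_size > 1 is required by the dataclass check below
accelerator = Accelerator(torch_tp_plugin=tp_plugin)  # passing a plugin also sets ACCELERATE_USE_TP=true

Alternatively, exporting ACCELERATE_USE_TP=true lets the Accelerator create the plugin from the environment, as the `torch_tp_plugin is None` branch above shows.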

src/accelerate/commands/launch.py

Lines changed: 1 addition & 1 deletion
@@ -595,7 +595,7 @@ def launch_command_parser(subparsers=None):
         type=str,
         help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
     )
-
+
     # tp args
     tp_args = parser.add_argument_group("TP Arguments", "Arguments related to Tensor Parallelism using PyToch.")
     tp_args.add_argument(
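
The hunk is cut off before the new arguments are registered, so the exact flag names are not visible here; the sketch below infers --use_tp and --tp_size from the get_launch_command(use_tp=True, tp_size=...) call in the new test file and uses a placeholder script name.

# Hypothetical launch assembled with the test helpers used in tests/tp/test_tp.py;
# flag names are inferred from the tests, not confirmed by this hunk.
from accelerate.test_utils.testing import execute_subprocess_async, get_launch_command

cmd = get_launch_command(num_processes=2, num_machines=1, machine_rank=0, use_tp=True, tp_size=2)
execute_subprocess_async(cmd + ["my_training_script.py"])  # placeholder script path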

src/accelerate/data_loader.py

Lines changed: 10 additions & 8 deletions
@@ -740,11 +740,11 @@ def __init__(
         self.iteration = 0
 
         # if a device mesh is provided extract each dimension (dp, fsdp, tp)
-        # device mesh may hold any number of dimensions, however, 
+        # device mesh may hold any number of dimensions, however,
         # below code is for targetted support for dp, fsdp and tp
-
-        # device mesh will be used only if there is tp involved 
-        # or any multi-dimensional parallelism involving tp 
+
+        # device mesh will be used only if there is tp involved
+        # or any multi-dimensional parallelism involving tp
         # (dp, tp) (fsdp, tp) (dp, fsdp, tp)
         # otherwise the default behavour not using device mesh should be sufficient
         # since multi dimensional parallelism devoid of tp would anyway need
@@ -777,8 +777,10 @@ def _fetch_batches(self, iterator):
                 if self.split_batches:
                     # One batch of the main iterator is dispatched and split.
                     if self.submesh_tp:
-                        logger.warning("Use of split_batches for TP would need the dataloader to produce duplicate batches,"
-                                       "otherwise, use dispatch_batches=True instead.")
+                        logger.warning(
+                            "Use of split_batches for TP would need the dataloader to produce duplicate batches,"
+                            "otherwise, use dispatch_batches=True instead."
+                        )
                     self._update_state_dict()
                     batch = next(iterator)
                 else:
@@ -1078,7 +1080,7 @@ def prepare_data_loader(
     state = PartialState()
     if num_processes is None:
         num_processes = state.num_processes
-
+
     # when device mesh is used, specifically with TP
     # then there is need to update process_index and num_processes
     # to bring in the effect of generating same batch across TP ranks
@@ -1098,7 +1100,7 @@ def prepare_data_loader(
             submesh_dp_size = torch_device_mesh["dp"].size()
         if "fsdp" in torch_device_mesh.mesh_dim_names:
             submesh_fsdp_size = torch_device_mesh["fsdp"].size()
-        num_processes = (submesh_fsdp_size * submesh_dp_size)
+        num_processes = submesh_fsdp_size * submesh_dp_size
     if process_index is None:
         process_index = state.process_index
     if torch_device_mesh:
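
To make the comment block above concrete: when TP is part of the mesh, only the non-TP dimensions count toward dataloader sharding, so all TP ranks within a data-parallel group receive the same batch. A rough illustration, assuming a 4-rank CUDA run with torch >= 2.3 (not code from this commit):

# Illustration only: a ("dp", "tp") mesh over 4 ranks.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
num_processes = mesh["dp"].size()            # 2 data-parallel shards feed distinct batches
process_index = mesh["dp"].get_local_rank()  # sharding index; identical across the tp ranks in a dp group

With an additional "fsdp" dimension, the hunk above multiplies the dp and fsdp submesh sizes to get the effective number of dataloader shards.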

src/accelerate/test_utils/testing.py

Lines changed: 7 additions & 0 deletions
@@ -339,6 +339,13 @@ def require_fsdp(test_case):
     return unittest.skipUnless(is_torch_version(">=", "1.12.0"), "test requires torch version >= 1.12.0")(test_case)
 
 
+def require_tp(test_case):
+    """
+    Decorator marking a test that requires TP support in PyTorch. These tests are skipped when the required torch version isn't available.
+    """
+    return unittest.skipUnless(is_torch_version(">=", "2.3.0"), "test requires torch version >= 2.3.0")(test_case)
+
+
 def require_torch_min_version(test_case=None, version=None):
     """
     Decorator marking that a test requires a particular torch version to be tested. These tests are skipped when an
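
A hypothetical standalone use of the new decorator, separate from the class-level usage in tests/tp/test_tp.py; it simply gates a test on the torch >= 2.3.0 check above.

# Hypothetical example, not part of this commit.
import unittest

from accelerate.test_utils.testing import require_tp

class DummyTPGate(unittest.TestCase):
    @require_tp
    def test_gate(self):
        self.assertTrue(True)

if __name__ == "__main__":
    unittest.main()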

src/accelerate/utils/dataclasses.py

Lines changed: 4 additions & 0 deletions
@@ -1827,6 +1827,10 @@ class TorchTensorParallelPlugin:
     torch_device_mesh: torch.distributed.DeviceMesh = field(default=None)
 
     def __post_init__(self):
+        self.tp_size = self.tp_size if os.environ.get("TP_SIZE", "1") == "1" else int(os.environ.get("TP_SIZE", "1"))
+        if self.tp_size == 1:
+            raise ValueError("Provide TP degree > 1.")
+
         if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
             raise ValueError(
                 f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel."

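The precedence added above is easy to misread: an explicit tp_size is kept only while TP_SIZE is unset or "1"; otherwise the environment variable wins, and a resolved degree of 1 is rejected. A standalone restatement of that logic with a hypothetical helper name, not part of the commit:

import os

def resolve_tp_size(tp_size: int = 1) -> int:
    # Mirrors the __post_init__ above: TP_SIZE overrides the field unless it is "1".
    env_value = os.environ.get("TP_SIZE", "1")
    resolved = tp_size if env_value == "1" else int(env_value)
    if resolved == 1:
        raise ValueError("Provide TP degree > 1.")
    return resolved

os.environ.pop("TP_SIZE", None)
assert resolve_tp_size(2) == 2   # explicit value kept when TP_SIZE is unset
os.environ["TP_SIZE"] = "4"
assert resolve_tp_size(2) == 4   # environment variable takes precedence
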
tests/tp/test_tp.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from transformers.testing_utils import mockenv_context
+from transformers.trainer_utils import set_seed
+
+from accelerate.test_utils.testing import (
+    AccelerateTestCase,
+    TempDirTestCase,
+    execute_subprocess_async,
+    get_launch_command,
+    path_in_accelerate_package,
+    require_multi_device,
+    require_non_cpu,
+    require_non_torch_xla,
+    require_tp,
+    slow,
+)
+from accelerate.utils import patch_environment
+from accelerate.utils.dataclasses import TorchTensorParallelPlugin
+
+
+set_seed(42)
+
+
+@require_tp
+@require_non_cpu
+@require_non_torch_xla
+class TPPluginIntegration(AccelerateTestCase):
+    def setUp(self):
+        super().setUp()
+
+        self.dist_env = dict(
+            MASTER_ADDR="localhost",
+            MASTER_PORT="10999",
+            RANK="0",
+            LOCAL_RANK="0",
+            WORLD_SIZE="1",
+        )
+
+        self.tp_env = dict(ACCELERATE_USE_TP="true", TP_SIZE="2", **self.dist_env)
+
+    def test_device_mesh_init(self):
+        with mockenv_context(**self.tp_env):
+            tp_plugin = TorchTensorParallelPlugin()
+            assert str(tp_plugin.torch_device_mesh["tp"].size()) == self.tp_env["TP_SIZE"]
+
+
+@require_non_torch_xla
+@require_tp
+@require_multi_device
+@slow
+class TPIntegrationTest(TempDirTestCase):
+    test_scripts_folder = path_in_accelerate_package("test_utils", "scripts", "external_deps")
+
+    def setUp(self):
+        super().setUp()
+        self.test_tp_size = 2
+
+    def test_working_of_tp(self):
+        self.test_file_path = self.test_scripts_folder / "test_performance.py"
+        cmd = get_launch_command(
+            num_processes=self.test_tp_size, num_machines=1, machine_rank=0, use_tp=True, tp_size=self.test_tp_size
+        )
+        with patch_environment(omp_num_threads=1):
+            execute_subprocess_async(cmd)
