
Commit 1ad75c9

beniericpintaoz-aws authored and committed
Simplify Config Class Names and DistributedRunner structures (#1573)
1 parent: ce55d45 · commit: 1ad75c9

22 files changed: +693 −413 lines changed

.gitignore

Lines changed: 2 additions & 2 deletions
@@ -33,8 +33,8 @@ env/
 *.html
 **/_repack_script_launcher.sh
 src/sagemaker/modules/train/container_drivers/sm_train.sh
-src/sagemaker/modules/train/container_drivers/sourcecodeconfig.json
-src/sagemaker/modules/train/container_drivers/distribution.json
+src/sagemaker/modules/train/container_drivers/sourcecode.json
+src/sagemaker/modules/train/container_drivers/distributed_runner.json
 tests/data/**/_repack_model.py
 tests/data/experiment/sagemaker-dev-1.0.tar.gz
 src/sagemaker/serve/tmp_workspace

src/sagemaker/modules/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -16,3 +16,7 @@
 from sagemaker_core.main.utils import logger as sagemaker_core_logger

 logger = sagemaker_core_logger
+
+from sagemaker.modules.train.model_trainer import (  # noqa: F401 E402 # pylint: disable=C0413
+    ModelTrainer,
+)
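
With this re-export in place, ModelTrainer can be imported from the package root as well as from its defining module. A minimal sketch (assuming the package is installed):

# Both imports resolve to the same class after this change.
from sagemaker.modules import ModelTrainer
from sagemaker.modules.train.model_trainer import ModelTrainer as ModelTrainerDirect

assert ModelTrainer is ModelTrainerDirect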

src/sagemaker/modules/configs.py

Lines changed: 15 additions & 111 deletions
@@ -21,7 +21,7 @@

 from __future__ import absolute_import

-from typing import Optional, Union, Dict, Any, List
+from typing import Optional, Union
 from pydantic import BaseModel, model_validator

 import sagemaker_core.shapes as shapes
@@ -54,15 +54,10 @@
     CheckpointConfig,
 )

-from sagemaker.modules import logger
 from sagemaker.modules.utils import convert_unassigned_to_none

 __all__ = [
-    "SourceCodeConfig",
-    "TorchDistributionConfig",
-    "MPIDistributionConfig",
-    "SMDistributedSettings",
-    "DistributionConfig",
+    "SourceCode",
     "StoppingCondition",
     "RetryStrategy",
     "OutputDataConfig",
@@ -87,107 +82,16 @@
     "InstanceGroup",
     "TensorBoardOutputConfig",
     "CheckpointConfig",
-    "ComputeConfig",
-    "NetworkingConfig",
+    "Compute",
+    "Networking",
     "InputData",
 ]


-class SMDistributedSettings(BaseModel):
-    """SMDistributedSettings.
+class SourceCode(BaseModel):
+    """SourceCode.

-    The SMDistributedSettings is used to configure distributed training when
-    using the smdistributed library.
-
-    Attributes:
-        enable_dataparallel (Optional[bool]):
-            Whether to enable data parallelism.
-        enable_modelparallel (Optional[bool]):
-            Whether to enable model parallelism.
-        modelparallel_parameters (Optional[Dict[str, Any]]):
-            The parameters for model parallelism.
-    """
-
-    enable_dataparallel: Optional[bool] = False
-    enable_modelparallel: Optional[bool] = False
-    modelparallel_parameters: Optional[Dict[str, Any]] = None
-
-
-class DistributionConfig(BaseModel):
-    """Base class for distribution configurations."""
-
-    _distribution_type: str
-
-
-class TorchDistributionConfig(DistributionConfig):
-    """TorchDistributionConfig.
-
-    The TorchDistributionConfig uses `torchrun` or `torch.distributed.launch` in the backend to
-    launch distributed training.
-
-    SMDistributed Library Information:
-      - `TorchDistributionConfig` can be used for SMModelParallel V2.
-      - For SMDataParallel or SMModelParallel V1, it is recommended to use the
-        `MPIDistributionConfig.`
-
-
-    Attributes:
-        smdistributed_settings (Optional[SMDistributedSettings]):
-            The settings for smdistributed library.
-        process_count_per_node (int):
-            The number of processes to run on each node in the training job.
-            Will default to the number of CPUs or GPUs available in the container.
-    """
-
-    _distribution_type: str = "torch_distributed"
-
-    smdistributed_settings: Optional[SMDistributedSettings] = None
-    process_count_per_node: Optional[int] = None
-
-    @model_validator(mode="after")
-    def _validate_model(cls, model):  # pylint: disable=E0213
-        """Validate the model."""
-        if (
-            getattr(model, "smddistributed_settings", None)
-            and model.smddistributed_settings.enable_dataparallel
-        ):
-            logger.warning(
-                "For smdistributed data parallelism, it is recommended to use "
-                + "MPIDistributionConfig."
-            )
-        return model
-
-
-class MPIDistributionConfig(DistributionConfig):
-    """MPIDistributionConfig.
-
-    The MPIDistributionConfig uses `mpirun` in the backend to launch distributed training.
-
-    SMDistributed Library Information:
-      - `MPIDistributionConfig` can be used for SMDataParallel and SMModelParallel V1.
-      - For SMModelParallel V2, it is recommended to use the `TorchDistributionConfig`.
-
-    Attributes:
-        smdistributed_settings (Optional[SMDistributedSettings]):
-            The settings for smdistributed library.
-        process_count_per_node (int):
-            The number of processes to run on each node in the training job.
-            Will default to the number of CPUs or GPUs available in the container.
-        mpi_additional_options (Optional[str]):
-            The custom MPI options to use for the training job.
-    """
-
-    _distribution_type: str = "mpi"
-
-    smdistributed_settings: Optional[SMDistributedSettings] = None
-    process_count_per_node: Optional[int] = None
-    mpi_additional_options: Optional[List[str]] = None
-
-
-class SourceCodeConfig(BaseModel):
-    """SourceCodeConfig.
-
-    This config allows the user to specify the source code location, dependencies,
+    The SourceCode class allows the user to specify the source code location, dependencies,
     entry script, or commands to be executed in the training job container.

     Attributes:
@@ -210,10 +114,10 @@ class SourceCodeConfig(BaseModel):
     command: Optional[str] = None


-class ComputeConfig(shapes.ResourceConfig):
-    """ComputeConfig.
+class Compute(shapes.ResourceConfig):
+    """Compute.

-    The ComputeConfig is a subclass of `sagemaker_core.shapes.ResourceConfig`
+    The Compute class is a subclass of `sagemaker_core.shapes.ResourceConfig`
     and allows the user to specify the compute resources for the training job.

     Attributes:
@@ -245,7 +149,7 @@ class ComputeConfig(shapes.ResourceConfig):
     enable_managed_spot_training: Optional[bool] = None

     @model_validator(mode="after")
-    def _model_validator(self) -> "ComputeConfig":
+    def _model_validator(self) -> "Compute":
         """Convert Unassigned values to None."""
         return convert_unassigned_to_none(self)
@@ -259,10 +163,10 @@ def _to_resource_config(self) -> shapes.ResourceConfig:
         return shapes.ResourceConfig(**filtered_dict)


-class NetworkingConfig(shapes.VpcConfig):
-    """NetworkingConfig.
+class Networking(shapes.VpcConfig):
+    """Networking.

-    The NetworkingConifg is a subclass of `sagemaker_core.shapes.VpcConfig ` and
+    The Networking class is a subclass of `sagemaker_core.shapes.VpcConfig ` and
     allows the user to specify the networking configuration for the training job.

     Attributes:
@@ -290,7 +194,7 @@ class NetworkingConfig(shapes.VpcConfig):
     enable_inter_container_traffic_encryption: Optional[bool] = None

     @model_validator(mode="after")
-    def _model_validator(self) -> "NetworkingConfig":
+    def _model_validator(self) -> "Networking":
         """Convert Unassigned values to None."""
         return convert_unassigned_to_none(self)
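
For orientation, a minimal sketch using the renamed classes. The ModelTrainer wiring is not shown in this file, and the Compute/Networking field names below come from the parent sagemaker_core shapes (ResourceConfig, VpcConfig) rather than from this hunk, so treat them as assumptions; the IDs and values are placeholders.

from sagemaker.modules.configs import SourceCode, Compute, Networking

# Source code to run in the training container; `command` is a field shown in this diff.
source_code = SourceCode(command="python train.py")

# Compute subclasses sagemaker_core.shapes.ResourceConfig (instance_type,
# instance_count, volume_size_in_gb are ResourceConfig fields).
compute = Compute(
    instance_type="ml.g5.12xlarge",
    instance_count=2,
    volume_size_in_gb=100,
)

# Networking subclasses sagemaker_core.shapes.VpcConfig (security_group_ids,
# subnets are VpcConfig fields); placeholder IDs.
networking = Networking(
    security_group_ids=["sg-0123456789abcdef0"],
    subnets=["subnet-0123456789abcdef0"],
)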

src/sagemaker/modules/constants.py

Lines changed: 2 additions & 2 deletions
@@ -25,8 +25,8 @@
     os.path.dirname(os.path.abspath(__file__)), "train/container_drivers"
 )

-SOURCE_CODE_CONFIG_JSON = "sourcecodeconfig.json"
-DISTRIBUTION_JSON = "distribution.json"
+SOURCE_CODE_JSON = "sourcecode.json"
+DISTRIBUTED_RUNNER_JSON = "distributed_runner.json"
 TRAIN_SCRIPT = "sm_train.sh"

 DEFAULT_CONTAINER_ENTRYPOINT = ["/bin/bash"]
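
For context, a hedged sketch of how a container driver might load the renamed JSON files at runtime. The /opt/ml/input/data/sm_drivers path is taken from templates.py in this commit; the read_json helper itself is hypothetical and not part of the change.

import json
import os

# Location where the driver scripts and config JSONs are staged in the container
# (path taken from the templates in this commit).
SM_DRIVERS_DIR = "/opt/ml/input/data/sm_drivers"

SOURCE_CODE_JSON = "sourcecode.json"
DISTRIBUTED_RUNNER_JSON = "distributed_runner.json"


def read_json(filename: str) -> dict:
    """Hypothetical helper: load one of the staged config files."""
    with open(os.path.join(SM_DRIVERS_DIR, filename)) as f:
        return json.load(f)


source_code = read_json(SOURCE_CODE_JSON)
distributed_runner = read_json(DISTRIBUTED_RUNNER_JSON)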

src/sagemaker/modules/distributed.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Distributed module."""
+from __future__ import absolute_import
+
+from typing import Optional, Dict, Any, List
+from pydantic import BaseModel, PrivateAttr
+
+
+class DistributedRunner(BaseModel):
+    """Base class for DistributedRunner Class"""
+
+    _type: str = PrivateAttr()
+
+    def model_dump(self, *args, **kwargs):
+        """Dump the model to a dictionary."""
+        result = super().model_dump(*args, **kwargs)
+        result["_type"] = self._type
+        return result
+
+
+class Torchrun(DistributedRunner):
+    """TorchDistribution.
+
+    The TorchDistribution runner uses `torchrun` or `torch.distributed.launch` in the backend to
+    launch distributed training.
+
+    Attributes:
+        process_count_per_node (int):
+            The number of processes to run on each node in the training job.
+            Will default to the number of GPUs available in the container.
+    """
+
+    _type: str = PrivateAttr(default="torchrun")
+
+    process_count_per_node: Optional[int] = None
+
+
+class TorchrunSMP(DistributedRunner):
+    """TorchrunSMP.
+
+    The TorchrunSMP runner uses `torchrun` or `torch.distributed.launch` in the backend
+    to launch distributed training. This strategy is used for a PyTorch job using the SageMaker
+    Model Parallelism library v2. For more information on the model parallelism parameters, see:
+    https://docs.aws.amazon.com/sagemaker/latest/dg/distributed-model-parallel-v2-reference.html#distributed-model-parallel-v2-reference-init-config
+
+    Attributes:
+        process_count_per_node (int):
+            The number of processes to run on each node in the training job.
+            Will default to the number of GPUs available in the container.
+        hybrid_shard_degree (Optional[int]):
+            Specifies a sharded parallelism degree for the model.
+        sm_activation_offloading (Optional[bool]):
+            Specifies whether to enable the SMP activation offloading implementation.
+        activation_loading_horizon (Optional[int]):
+            An integer specifying the activation offloading horizon type for FSDP. This is the
+            maximum number of checkpointed or offloaded layers whose inputs can be in the GPU
+            memory simultaneously.
+        fsdp_cache_flush_warnings (Optional[bool]):
+            Detects and warns if cache flushes happen in the PyTorch memory manager, because they
+            can degrade computational performance.
+        allow_empty_shards (Optional[bool]):
+            Whether to allow empty shards when sharding tensors if tensor is not divisible. This is
+            an experimental fix for crash during checkpointing in certain scenarios. Disabling this
+            falls back to the original PyTorch behavior.
+        tensor_parallel_degree (Optional[int]):
+            Specifies a tensor parallelism degree. The value must be between 1 and world_size.
+        context_parallel_degree (Optional[int]):
+            Specifies the context parallelism degree. The value must be between 1 and world_size,
+            and must be <= hybrid_shard_degree.
+        expert_parallel_degree (Optional[int]):
+            Specifies a expert parallelism degree. The value must be between 1 and world_size.
+        random_seed (Optional[int]):
+            A seed number for the random operations in distributed modules by SMP tensor
+            parallelism or expert parallelism.
+    """
+
+    _type: str = PrivateAttr(default="torchrun")
+
+    process_count_per_node: Optional[int] = None
+    hybrid_shard_degree: Optional[int] = None
+    sm_activation_offloading: Optional[bool] = None
+    activation_loading_horizon: Optional[int] = None
+    fsdp_cache_flush_warnings: Optional[bool] = None
+    allow_empty_shards: Optional[bool] = None
+    tensor_parallel_degree: Optional[int] = None
+    context_parallel_degree: Optional[int] = None
+    expert_parallel_degree: Optional[int] = None
+    random_seed: Optional[int] = None
+
+    def _to_mp_parameters_dict(self) -> Dict[str, Any]:
+        """Convert to a dictionary of MP parameters."""
+        mp_parameters = self.model_dump(exclude_none=True)
+        mp_parameters.pop("_type")
+        if mp_parameters.get("process_count_per_node") is not None:
+            mp_parameters.pop("process_count_per_node")
+        return mp_parameters
+
+
+class MPI(DistributedRunner):
+    """MPI.
+
+    The MPI runner uses `mpirun` in the backend to launch distributed training.
+
+    Attributes:
+        process_count_per_node (int):
+            The number of processes to run on each node in the training job.
+            Will default to the number of GPUs available in the container.
+        mpi_additional_options (Optional[str]):
+            The custom MPI options to use for the training job.
+    """
+
+    _type: str = PrivateAttr(default="mpi")
+
+    process_count_per_node: Optional[int] = None
+    mpi_additional_options: Optional[List[str]] = None
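
A minimal sketch of constructing the new runner classes. The field values are illustrative, and passing a runner to ModelTrainer (for example via a distributed_runner argument) is an assumption based on the naming in this commit rather than an API shown in this diff.

from sagemaker.modules.distributed import Torchrun, TorchrunSMP, MPI

# torchrun-based job; process count defaults to the GPUs in the container if omitted.
torchrun_runner = Torchrun(process_count_per_node=8)

# torchrun with SageMaker Model Parallelism v2 settings.
smp_runner = TorchrunSMP(tensor_parallel_degree=4, random_seed=1234)

# mpirun-based job with extra MPI flags.
mpi_runner = MPI(mpi_additional_options=["-x", "NCCL_DEBUG=INFO"])

# model_dump() carries the private _type marker that the container drivers read.
print(torchrun_runner.model_dump())  # {'process_count_per_node': 8, '_type': 'torchrun'}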

src/sagemaker/modules/templates.py

Lines changed: 8 additions & 8 deletions
@@ -19,13 +19,13 @@
 eval $CMD
 """

-EXECUTE_PYTORCH_DRIVER = """
-echo "Running PyTorch training driver"
-$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/pytorch_driver.py
+EXEUCTE_TORCHRUN_DRIVER = """
+echo "Running Torchrun driver"
+$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/torchrun_driver.py
 """

 EXECUTE_MPI_DRIVER = """
-echo "Running MPI training driver"
+echo "Running MPI driver"
 $SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/mpi_driver.py
 """

@@ -73,12 +73,12 @@
 cat /opt/ml/input/config/inputdataconfig.json
 echo

-echo "/opt/ml/input/data/sm_drivers/sourcecodeconfig.json"
-cat /opt/ml/input/data/sm_drivers/sourcecodeconfig.json
+echo "/opt/ml/input/data/sm_drivers/sourcecode.json"
+cat /opt/ml/input/data/sm_drivers/sourcecode.json
 echo

-echo "/opt/ml/input/data/sm_drivers/distribution.json"
-cat /opt/ml/input/data/sm_drivers/distribution.json
+echo "/opt/ml/input/data/sm_drivers/distributed_runner.json"
+cat /opt/ml/input/data/sm_drivers/distributed_runner.json
 echo

 echo "Setting up environment variables"

0 commit comments
