Skip to content

Commit 25486e6

Browse files
beniericpintaoz-aws
authored and committed
Add Distributed Training Support Model Trainer (#1536)
1 parent 4cd65a5 commit 25486e6

21 files changed

+1565
-243
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ env/
3232
.python-version
3333
*.html
3434
**/_repack_script_launcher.sh
35-
src/sagemaker/modules/scripts/train.sh
35+
src/sagemaker/modules/train/container_drivers/sm_train.sh
36+
src/sagemaker/modules/train/container_drivers/sourcecodeconfig.json
3637
tests/data/**/_repack_model.py
3738
tests/data/experiment/sagemaker-dev-1.0.tar.gz
3839
src/sagemaker/serve/tmp_workspace

src/sagemaker/modules/configs.py

Lines changed: 104 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121

2222
from __future__ import absolute_import
2323

24-
from typing import Optional
25-
from pydantic import BaseModel
24+
from typing import Optional, Dict, Any, List
25+
from pydantic import BaseModel, model_validator
2626

2727
from sagemaker_core.shapes import (
2828
ResourceConfig,
@@ -36,6 +36,8 @@
3636
VpcConfig,
3737
)
3838

39+
from sagemaker.modules import logger
40+
3941
__all__ = [
4042
"SourceCodeConfig",
4143
"ResourceConfig",
@@ -50,16 +52,104 @@
5052
]
5153

5254

55+
class SMDistributedSettings(BaseModel):
56+
"""SMDistributedSettings.
57+
58+
The SMDistributedSettings is used to configure distributed training when
59+
using the smdistributed library.
60+
61+
Attributes:
62+
enable_dataparallel (Optional[bool]):
63+
Whether to enable data parallelism.
64+
enable_modelparallel (Optional[bool]):
65+
Whether to enable model parallelism.
66+
modelparallel_parameters (Optional[Dict[str, Any]]):
67+
The parameters for model parallelism.
68+
"""
69+
70+
enable_dataparallel: Optional[bool] = False
71+
enable_modelparallel: Optional[bool] = False
72+
modelparallel_parameters: Optional[Dict[str, Any]] = None
73+
74+
75+
class DistributionConfig(BaseModel):
76+
"""Base class for distribution configurations."""
77+
78+
_distribution_type: str
79+
80+
81+
class TorchDistributionConfig(DistributionConfig):
82+
"""TorchDistributionConfig.
83+
84+
The TorchDistributionConfig uses `torchrun` or `torch.distributed.launch` in the backend to
85+
launch distributed training.
86+
87+
SMDistributed Library Information:
88+
- `TorchDistributionConfig` can be used for SMModelParallel V2.
89+
- For SMDataParallel or SMModelParallel V1, it is recommended to use the
90+
`MPIDistributionConfig.`
91+
92+
93+
Attributes:
94+
smdistributed_settings (Optional[SMDistributedSettings]):
95+
The settings for smdistributed library.
96+
process_count_per_node (int):
97+
The number of processes to run on each node in the training job.
98+
Will default to the number of CPUs or GPUs available in the container.
99+
"""
100+
101+
_distribution_type: str = "torch_distributed"
102+
103+
smdistributed_settings: Optional[SMDistributedSettings] = None
104+
process_count_per_node: Optional[int] = None
105+
106+
@model_validator(mode="after")
107+
def _validate_model(cls, model): # pylint: disable=E0213
108+
"""Validate the model."""
109+
if (
110+
getattr(model, "smddistributed_settings", None)
111+
and model.smddistributed_settings.enable_dataparallel
112+
):
113+
logger.warning(
114+
"For smdistributed data parallelism, it is recommended to use "
115+
+ "MPIDistributionConfig."
116+
)
117+
return model
118+
119+
120+
class MPIDistributionConfig(DistributionConfig):
121+
"""MPIDistributionConfig.
122+
123+
The MPIDistributionConfig uses `mpirun` in the backend to launch distributed training.
124+
125+
SMDistributed Library Information:
126+
- `MPIDistributionConfig` can be used for SMDataParallel and SMModelParallel V1.
127+
- For SMModelParallel V2, it is recommended to use the `TorchDistributionConfig`.
128+
129+
Attributes:
130+
smdistributed_settings (Optional[SMDistributedSettings]):
131+
The settings for smdistributed library.
132+
process_count_per_node (int):
133+
The number of processes to run on each node in the training job.
134+
Will default to the number of CPUs or GPUs available in the container.
135+
mpi_additional_options (Optional[str]):
136+
The custom MPI options to use for the training job.
137+
"""
138+
139+
_distribution_type: str = "mpi"
140+
141+
smdistributed_settings: Optional[SMDistributedSettings] = None
142+
process_count_per_node: Optional[int] = None
143+
mpi_additional_options: Optional[List[str]] = None
144+
145+
53146
class SourceCodeConfig(BaseModel):
54147
"""SourceCodeConfig.
55148
56149
This config allows the user to specify the source code location, dependencies,
57150
entry script, or commands to be executed in the training job container.
58151
59152
Attributes:
60-
command (Optional[str]):
61-
The command(s) to execute in the training job container. Example: "python my_script.py".
62-
If not specified, entry_script must be provided
63153
source_dir (Optional[str]):
64154
The local directory containing the source code to be used in the training job container.
65155
requirements (Optional[str]):
@@ -68,9 +158,17 @@ class SourceCodeConfig(BaseModel):
68158
entry_script (Optional[str]):
69159
The path within `source_dir` to the entry script that will be executed in the training
70160
job container. If not specified, command must be provided.
161+
command (Optional[str]):
162+
The command(s) to execute in the training job container. Example: "python my_script.py".
163+
If not specified, entry_script must be provided.
164+
distribution (Optional[Union[
165+
MPIDistributionConfig,
166+
TorchDistributionConfig,
167+
]]):
168+
The distribution configuration for the training job.
71169
"""
72170

73-
command: Optional[str] = None
74171
source_dir: Optional[str] = None
75172
requirements: Optional[str] = None
76173
entry_script: Optional[str] = None
174+
command: Optional[str] = None

src/sagemaker/modules/constants.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,21 @@
1616

1717
DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge"
1818

19-
SOURCE_CODE_CONTAINER_PATH = "/opt/ml/input/data/code"
20-
19+
SM_CODE = "sm_code"
2120
SM_CODE_CONTAINER_PATH = "/opt/ml/input/data/sm_code"
22-
SM_CODE_LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts")
23-
TRAIN_SCRIPT = "train.sh"
21+
22+
SM_DRIVERS = "sm_drivers"
23+
SM_DRIVERS_CONTAINER_PATH = "/opt/ml/input/data/sm_drivers"
24+
SM_DRIVERS_LOCAL_PATH = os.path.join(
25+
os.path.dirname(os.path.abspath(__file__)), "train/container_drivers"
26+
)
27+
28+
SOURCE_CODE_CONFIG_JSON = "sourcecodeconfig.json"
29+
TRAIN_SCRIPT = "sm_train.sh"
2430

2531
DEFAULT_CONTAINER_ENTRYPOINT = ["/bin/bash"]
2632
DEFAULT_CONTAINER_ARGUMENTS = [
2733
"-c",
28-
f"chmod +x {SM_CODE_CONTAINER_PATH}/{TRAIN_SCRIPT} "
29-
+ f"&& {SM_CODE_CONTAINER_PATH}/{TRAIN_SCRIPT}",
34+
f"chmod +x {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT} "
35+
+ f"&& {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT}",
3036
]

src/sagemaker/modules/templates.py

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,58 @@
1313
"""Templates module."""
1414
from __future__ import absolute_import
1515

16+
EXECUTE_BASE_COMMANDS = """
17+
CMD="{base_command}"
18+
echo "Running command: $CMD"
19+
eval $CMD
20+
"""
21+
22+
EXECUTE_PYTORCH_DRIVER = """
23+
echo "Running PyTorch training driver"
24+
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/pytorch_driver.py
25+
"""
26+
27+
EXECUTE_MPI_DRIVER = """
28+
echo "Running MPI training driver"
29+
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/mpi_driver.py
30+
"""
31+
1632
TRAIN_SCRIPT_TEMPLATE = """
1733
#!/bin/bash
34+
set -e
1835
echo "Starting training script"
1936
37+
handle_error() {{
38+
EXIT_STATUS=$?
39+
echo "An error occurred with exit code $EXIT_STATUS"
40+
if [ ! -s /opt/ml/output/failure ]; then
41+
echo "Training Execution failed. For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'.
42+
TrainingJob - $TRAINING_JOB_NAME" >> /opt/ml/output/failure
43+
fi
44+
exit $EXIT_STATUS
45+
}}
46+
47+
check_python() {{
48+
if command -v python3 &>/dev/null; then
49+
SM_PYTHON_CMD="python3"
50+
SM_PIP_CMD="pip3"
51+
echo "Found python3"
52+
elif command -v python &>/dev/null; then
53+
SM_PYTHON_CMD="python"
54+
SM_PIP_CMD="pip"
55+
echo "Found python"
56+
else
57+
echo "Python may not be installed"
58+
return 1
59+
fi
60+
}}
61+
62+
trap 'handle_error' ERR
63+
64+
check_python
65+
66+
$SM_PYTHON_CMD --version
67+
2068
echo "/opt/ml/input/config/resourceconfig.json:"
2169
cat /opt/ml/input/config/resourceconfig.json
2270
echo
@@ -29,27 +77,17 @@
2977
cat /opt/ml/input/config/hyperparameters.json
3078
echo
3179
80+
echo "/opt/ml/input/data/sm_drivers/sourcecodeconfig.json"
81+
cat /opt/ml/input/data/sm_drivers/sourcecodeconfig.json
82+
echo
83+
3284
echo "Setting up environment variables"
33-
python /opt/ml/input/data/sm_code/environment.py
34-
source /opt/ml/input/data/sm_code/sm_training.env
85+
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/scripts/environment.py
86+
source /opt/ml/input/data/sm_drivers/scripts/sm_training.env
3587
36-
python --version
3788
{working_dir}
3889
{install_requirements}
39-
CMD="{command}"
40-
echo "Running command: $CMD"
41-
eval $CMD
42-
EXIT_STATUS=$?
90+
{execute_driver}
4391
44-
if [ $EXIT_STATUS -ne 0 ]; then
45-
echo "Command failed with exit status $EXIT_STATUS"
46-
if [ ! -s /opt/ml/output/failure ]; then
47-
echo "Command failed with exit code $EXIT_STATUS.
48-
For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'.
49-
TrainingJob - $TRAINING_JOB_NAME" >> /opt/ml/output/failure
50-
fi
51-
exit $EXIT_STATUS
52-
else
53-
echo "Command succeeded"
54-
fi
92+
echo "Training Container Execution Completed"
5593
"""

0 commit comments

Comments (0)