@@ -22,7 +22,7 @@
 from __future__ import absolute_import
 
-from typing import Optional
-from pydantic import BaseModel
+from typing import Optional, Union, Dict, Any, List
+from pydantic import BaseModel, model_validator
 
 from sagemaker_core.shapes import (
     ResourceConfig,
@@ -36,6 +36,8 @@
     VpcConfig,
 )
 
+from sagemaker.modules import logger
+
 __all__ = [
     "SourceCodeConfig",
     "ResourceConfig",
@@ -50,16 +52,104 @@
 ]
 
 
+class SMDistributedSettings(BaseModel):
+    """SMDistributedSettings.
+
+    The SMDistributedSettings is used to configure distributed training when
+    using the smdistributed library.
+
+    Attributes:
+        enable_dataparallel (Optional[bool]):
+            Whether to enable data parallelism.
+        enable_modelparallel (Optional[bool]):
+            Whether to enable model parallelism.
+        modelparallel_parameters (Optional[Dict[str, Any]]):
+            The parameters for model parallelism.
+    """
+
+    enable_dataparallel: Optional[bool] = False
+    enable_modelparallel: Optional[bool] = False
+    modelparallel_parameters: Optional[Dict[str, Any]] = None
+
+
+class DistributionConfig(BaseModel):
+    """Base class for distribution configurations."""
+
+    _distribution_type: str
+
+
+class TorchDistributionConfig(DistributionConfig):
+    """TorchDistributionConfig.
+
+    The TorchDistributionConfig uses `torchrun` or `torch.distributed.launch` in the backend to
+    launch distributed training.
+
+    SMDistributed Library Information:
+        - `TorchDistributionConfig` can be used for SMModelParallel V2.
+        - For SMDataParallel or SMModelParallel V1, it is recommended to use the
+          `MPIDistributionConfig`.
+
+
+    Attributes:
+        smdistributed_settings (Optional[SMDistributedSettings]):
+            The settings for the smdistributed library.
+        process_count_per_node (Optional[int]):
+            The number of processes to run on each node in the training job.
+            Defaults to the number of CPUs or GPUs available in the container.
+    """
+
+    _distribution_type: str = "torch_distributed"
+
+    smdistributed_settings: Optional[SMDistributedSettings] = None
+    process_count_per_node: Optional[int] = None
+
+    @model_validator(mode="after")
+    def _validate_model(cls, model):  # pylint: disable=E0213
+        """Validate the model."""
+        if (
+            getattr(model, "smdistributed_settings", None)
+            and model.smdistributed_settings.enable_dataparallel
+        ):
+            logger.warning(
+                "For smdistributed data parallelism, it is recommended to use "
+                + "MPIDistributionConfig."
+            )
+        return model
+
+
+class MPIDistributionConfig(DistributionConfig):
+    """MPIDistributionConfig.
+
+    The MPIDistributionConfig uses `mpirun` in the backend to launch distributed training.
+
+    SMDistributed Library Information:
+        - `MPIDistributionConfig` can be used for SMDataParallel and SMModelParallel V1.
+        - For SMModelParallel V2, it is recommended to use the `TorchDistributionConfig`.
+
+    Attributes:
+        smdistributed_settings (Optional[SMDistributedSettings]):
+            The settings for the smdistributed library.
+        process_count_per_node (Optional[int]):
+            The number of processes to run on each node in the training job.
+            Defaults to the number of CPUs or GPUs available in the container.
+        mpi_additional_options (Optional[List[str]]):
+            The custom MPI options to use for the training job.
+    """
+
+    _distribution_type: str = "mpi"
+
+    smdistributed_settings: Optional[SMDistributedSettings] = None
+    process_count_per_node: Optional[int] = None
+    mpi_additional_options: Optional[List[str]] = None
+
+
 class SourceCodeConfig(BaseModel):
     """SourceCodeConfig.
 
     This config allows the user to specify the source code location, dependencies,
     entry script, or commands to be executed in the training job container.
 
     Attributes:
-        command (Optional[str]):
-            The command(s) to execute in the training job container. Example: "python my_script.py".
-            If not specified, entry_script must be provided
         source_dir (Optional[str]):
             The local directory containing the source code to be used in the training job container.
         requirements (Optional[str]):
@@ -68,9 +158,18 @@ class SourceCodeConfig(BaseModel):
         entry_script (Optional[str]):
             The path within `source_dir` to the entry script that will be executed in the training
             job container. If not specified, command must be provided.
+        command (Optional[str]):
+            The command(s) to execute in the training job container. Example: "python my_script.py".
+            If not specified, entry_script must be provided.
+        distribution (Optional[Union[
+            MPIDistributionConfig,
+            TorchDistributionConfig,
+        ]]):
+            The distribution configuration for the training job.
     """
 
-    command: Optional[str] = None
     source_dir: Optional[str] = None
     requirements: Optional[str] = None
     entry_script: Optional[str] = None
+    command: Optional[str] = None
+    distribution: Optional[Union[MPIDistributionConfig, TorchDistributionConfig]] = None
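
Taken together, the new configs compose as plain pydantic models. Below is a minimal usage sketch; the import path (`sagemaker.modules.configs`), the process count, and the model-parallel parameters are assumptions for illustration, not part of this diff:

```python
# Hypothetical usage of the classes added above; values are illustrative.
from sagemaker.modules.configs import (  # assumed import path
    SMDistributedSettings,
    TorchDistributionConfig,
)

# SMModelParallel V2 pairs with torch_distributed, per the docstrings.
torch_distribution = TorchDistributionConfig(
    process_count_per_node=8,
    smdistributed_settings=SMDistributedSettings(
        enable_modelparallel=True,
        modelparallel_parameters={"tensor_parallel_degree": 4},  # illustrative
    ),
)

# Enabling data parallelism under torch_distributed still constructs, but the
# `_validate_model` hook logs a warning recommending MPIDistributionConfig.
TorchDistributionConfig(
    smdistributed_settings=SMDistributedSettings(enable_dataparallel=True)
)
```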
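A distribution config then plugs into `SourceCodeConfig` through the `distribution` attribute its docstring describes. Another sketch under the same assumptions; treating `requirements` as the path to a pip requirements file is also an assumption, since that part of the docstring is collapsed in the diff:

```python
from sagemaker.modules.configs import (  # assumed import path
    MPIDistributionConfig,
    SMDistributedSettings,
    SourceCodeConfig,
)

source_code = SourceCodeConfig(
    source_dir="./train",             # local directory with the training code
    requirements="requirements.txt",  # assumed: pip requirements file path
    entry_script="train.py",          # either entry_script or command is required
    distribution=MPIDistributionConfig(  # recommended for SMDataParallel
        process_count_per_node=8,
        mpi_additional_options=["-verbose"],  # illustrative mpirun flag
        smdistributed_settings=SMDistributedSettings(enable_dataparallel=True),
    ),
)
```

Since `_distribution_type` has a leading underscore, pydantic treats it as a private attribute: each subclass pins it ("torch_distributed", "mpi") and callers never set it, which gives downstream code a stable discriminator for the union.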