21
21
22
22
from __future__ import absolute_import
23
23
24
- from typing import Optional , Union , Dict , Any , List
24
+ from typing import Optional , Union
25
25
from pydantic import BaseModel , model_validator
26
26
27
27
import sagemaker_core .shapes as shapes
54
54
CheckpointConfig ,
55
55
)
56
56
57
- from sagemaker .modules import logger
58
57
from sagemaker .modules .utils import convert_unassigned_to_none
59
58
60
59
__all__ = [
61
- "SourceCodeConfig" ,
62
- "TorchDistributionConfig" ,
63
- "MPIDistributionConfig" ,
64
- "SMDistributedSettings" ,
65
- "DistributionConfig" ,
60
+ "SourceCode" ,
66
61
"StoppingCondition" ,
67
62
"RetryStrategy" ,
68
63
"OutputDataConfig" ,
87
82
"InstanceGroup" ,
88
83
"TensorBoardOutputConfig" ,
89
84
"CheckpointConfig" ,
90
- "ComputeConfig " ,
91
- "NetworkingConfig " ,
85
+ "Compute " ,
86
+ "Networking " ,
92
87
"InputData" ,
93
88
]
94
89
95
90
96
- class SMDistributedSettings (BaseModel ):
97
- """SMDistributedSettings .
91
+ class SourceCode (BaseModel ):
92
+ """SourceCode .
98
93
99
- The SMDistributedSettings is used to configure distributed training when
100
- using the smdistributed library.
101
-
102
- Attributes:
103
- enable_dataparallel (Optional[bool]):
104
- Whether to enable data parallelism.
105
- enable_modelparallel (Optional[bool]):
106
- Whether to enable model parallelism.
107
- modelparallel_parameters (Optional[Dict[str, Any]]):
108
- The parameters for model parallelism.
109
- """
110
-
111
- enable_dataparallel : Optional [bool ] = False
112
- enable_modelparallel : Optional [bool ] = False
113
- modelparallel_parameters : Optional [Dict [str , Any ]] = None
114
-
115
-
116
- class DistributionConfig (BaseModel ):
117
- """Base class for distribution configurations."""
118
-
119
- _distribution_type : str
120
-
121
-
122
- class TorchDistributionConfig (DistributionConfig ):
123
- """TorchDistributionConfig.
124
-
125
- The TorchDistributionConfig uses `torchrun` or `torch.distributed.launch` in the backend to
126
- launch distributed training.
127
-
128
- SMDistributed Library Information:
129
- - `TorchDistributionConfig` can be used for SMModelParallel V2.
130
- - For SMDataParallel or SMModelParallel V1, it is recommended to use the
131
- `MPIDistributionConfig.`
132
-
133
-
134
- Attributes:
135
- smdistributed_settings (Optional[SMDistributedSettings]):
136
- The settings for smdistributed library.
137
- process_count_per_node (int):
138
- The number of processes to run on each node in the training job.
139
- Will default to the number of CPUs or GPUs available in the container.
140
- """
141
-
142
- _distribution_type : str = "torch_distributed"
143
-
144
- smdistributed_settings : Optional [SMDistributedSettings ] = None
145
- process_count_per_node : Optional [int ] = None
146
-
147
- @model_validator (mode = "after" )
148
- def _validate_model (cls , model ): # pylint: disable=E0213
149
- """Validate the model."""
150
- if (
151
- getattr (model , "smddistributed_settings" , None )
152
- and model .smddistributed_settings .enable_dataparallel
153
- ):
154
- logger .warning (
155
- "For smdistributed data parallelism, it is recommended to use "
156
- + "MPIDistributionConfig."
157
- )
158
- return model
159
-
160
-
161
- class MPIDistributionConfig (DistributionConfig ):
162
- """MPIDistributionConfig.
163
-
164
- The MPIDistributionConfig uses `mpirun` in the backend to launch distributed training.
165
-
166
- SMDistributed Library Information:
167
- - `MPIDistributionConfig` can be used for SMDataParallel and SMModelParallel V1.
168
- - For SMModelParallel V2, it is recommended to use the `TorchDistributionConfig`.
169
-
170
- Attributes:
171
- smdistributed_settings (Optional[SMDistributedSettings]):
172
- The settings for smdistributed library.
173
- process_count_per_node (int):
174
- The number of processes to run on each node in the training job.
175
- Will default to the number of CPUs or GPUs available in the container.
176
- mpi_additional_options (Optional[str]):
177
- The custom MPI options to use for the training job.
178
- """
179
-
180
- _distribution_type : str = "mpi"
181
-
182
- smdistributed_settings : Optional [SMDistributedSettings ] = None
183
- process_count_per_node : Optional [int ] = None
184
- mpi_additional_options : Optional [List [str ]] = None
185
-
186
-
187
- class SourceCodeConfig (BaseModel ):
188
- """SourceCodeConfig.
189
-
190
- This config allows the user to specify the source code location, dependencies,
94
+ The SourceCode class allows the user to specify the source code location, dependencies,
191
95
entry script, or commands to be executed in the training job container.
192
96
193
97
Attributes:
@@ -210,10 +114,10 @@ class SourceCodeConfig(BaseModel):
210
114
command : Optional [str ] = None
211
115
212
116
213
- class ComputeConfig (shapes .ResourceConfig ):
214
- """ComputeConfig .
117
+ class Compute (shapes .ResourceConfig ):
118
+ """Compute .
215
119
216
- The ComputeConfig is a subclass of `sagemaker_core.shapes.ResourceConfig`
120
+ The Compute class is a subclass of `sagemaker_core.shapes.ResourceConfig`
217
121
and allows the user to specify the compute resources for the training job.
218
122
219
123
Attributes:
@@ -245,7 +149,7 @@ class ComputeConfig(shapes.ResourceConfig):
245
149
enable_managed_spot_training : Optional [bool ] = None
246
150
247
151
@model_validator (mode = "after" )
248
- def _model_validator (self ) -> "ComputeConfig " :
152
+ def _model_validator (self ) -> "Compute " :
249
153
"""Convert Unassigned values to None."""
250
154
return convert_unassigned_to_none (self )
251
155
@@ -259,10 +163,10 @@ def _to_resource_config(self) -> shapes.ResourceConfig:
259
163
return shapes .ResourceConfig (** filtered_dict )
260
164
261
165
262
- class NetworkingConfig (shapes .VpcConfig ):
263
- """NetworkingConfig .
166
+ class Networking (shapes .VpcConfig ):
167
+ """Networking .
264
168
265
- The NetworkingConifg is a subclass of `sagemaker_core.shapes.VpcConfig ` and
169
+ The Networking class is a subclass of `sagemaker_core.shapes.VpcConfig ` and
266
170
allows the user to specify the networking configuration for the training job.
267
171
268
172
Attributes:
@@ -290,7 +194,7 @@ class NetworkingConfig(shapes.VpcConfig):
290
194
enable_inter_container_traffic_encryption : Optional [bool ] = None
291
195
292
196
@model_validator (mode = "after" )
293
- def _model_validator (self ) -> "NetworkingConfig " :
197
+ def _model_validator (self ) -> "Networking " :
294
198
"""Convert Unassigned values to None."""
295
199
return convert_unassigned_to_none (self )
296
200
0 commit comments