|
19 | 19 | import shutil
|
20 | 20 | from tempfile import TemporaryDirectory
|
21 | 21 |
|
22 |
| -from typing import Optional, List, Union, Dict, Any |
| 22 | +from typing import Optional, List, Union, Dict, Any, ClassVar |
| 23 | + |
| 24 | +from graphene.utils.str_converters import to_camel_case, to_snake_case |
| 25 | + |
23 | 26 | from sagemaker_core.main import resources
|
24 | 27 | from sagemaker_core.resources import TrainingJob
|
25 | 28 | from sagemaker_core.shapes import AlgorithmSpecification
|
26 | 29 |
|
27 | 30 | from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call
|
28 | 31 |
|
| 32 | +from sagemaker.config.config_schema import (_simple_path, SAGEMAKER, |
| 33 | + MODEL_TRAINER, MODULES, |
| 34 | + PYTHON_SDK, |
| 35 | + TRAINING_JOB_ENVIRONMENT_PATH, |
| 36 | + TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, |
| 37 | + TRAINING_JOB_VPC_CONFIG_PATH, |
| 38 | + TRAINING_JOB_SUBNETS_PATH, |
| 39 | + TRAINING_JOB_SECURITY_GROUP_IDS_PATH, |
| 40 | + TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH, |
| 41 | + TRAINING_JOB_PROFILE_CONFIG_PATH, |
| 42 | + TRAINING_JOB_RESOURCE_CONFIG_PATH, |
| 43 | + TRAINING_JOB_ROLE_ARN_PATH, |
| 44 | + TRAINING_JOB_TAGS_PATH) |
| 45 | + |
| 46 | +from sagemaker.utils import resolve_value_from_config |
29 | 47 | from sagemaker.modules import Session, get_execution_role
|
30 | 48 | from sagemaker.modules.configs import (
|
31 | 49 | Compute,
|
@@ -204,6 +222,125 @@ class ModelTrainer(BaseModel):
|
204 | 222 | tags: Optional[List[Tag]] = None
|
205 | 223 | local_container_root: Optional[str] = os.getcwd()
|
206 | 224 |
|
| 225 | + CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = ["role", |
| 226 | + "base_job_name", |
| 227 | + "source_code", |
| 228 | + "distributed_runner", |
| 229 | + "compute", |
| 230 | + "networking", |
| 231 | + "stopping_condition", |
| 232 | + "training_image", |
| 233 | + "training_image_config", |
| 234 | + "algorithm_name", |
| 235 | + "output_data_config", |
| 236 | + "checkpoint_config", |
| 237 | + "training_input_mode", |
| 238 | + "environment", |
| 239 | + "hyperparameters"] |
| 240 | + |
| 241 | + SERIALIZABLE_CONFIG_ATTRIBUTES: ClassVar[Any] = { |
| 242 | + "source_code": SourceCode, |
| 243 | + "distributed_runner": type(DistributedRunner), |
| 244 | + "compute": type(Compute), |
| 245 | + "networking": type(Networking), |
| 246 | + "stopping_condition": type(StoppingCondition), |
| 247 | + "training_image_config": type(TrainingImageConfig), |
| 248 | + "output_data_config": type(OutputDataConfig), |
| 249 | + "checkpoint_config": type(CheckpointConfig) |
| 250 | + } |
| 251 | + |
| 252 | + def _populate_intelligent_defaults(self): |
| 253 | + """Function to populate all the possible default configs |
| 254 | +
|
| 255 | + Model Trainer specific configs take precedence over the generic training job ones. |
| 256 | + """ |
| 257 | + self._populate_intelligent_defaults_from_model_trainer_space() |
| 258 | + self._populate_intelligent_defaults_from_training_job_space() |
| 259 | + |
| 260 | + def _populate_intelligent_defaults_from_training_job_space(self): |
| 261 | + """Function to populate all the possible default configs from Training Job Space""" |
| 262 | + if not self.environment: |
| 263 | + self.environment = resolve_value_from_config( |
| 264 | + config_path=TRAINING_JOB_ENVIRONMENT_PATH, |
| 265 | + sagemaker_session=self.sagemaker_session) |
| 266 | + |
| 267 | + default_enable_network_isolation = resolve_value_from_config( |
| 268 | + config_path=TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, |
| 269 | + sagemaker_session=self.sagemaker_session) |
| 270 | + default_vpc_config = resolve_value_from_config( |
| 271 | + config_path=TRAINING_JOB_VPC_CONFIG_PATH, |
| 272 | + sagemaker_session=self.sagemaker_session) |
| 273 | + |
| 274 | + if not self.networking: |
| 275 | + if (default_enable_network_isolation is not None |
| 276 | + or default_vpc_config is not None): |
| 277 | + self.networking = Networking( |
| 278 | + default_enable_network_isolation=default_enable_network_isolation, |
| 279 | + subnets=resolve_value_from_config(config_path=TRAINING_JOB_SUBNETS_PATH), |
| 280 | + security_group_ids=resolve_value_from_config( |
| 281 | + config_path=TRAINING_JOB_SECURITY_GROUP_IDS_PATH), |
| 282 | + ) |
| 283 | + else: |
| 284 | + if self.networking.enable_network_isolation is None: |
| 285 | + self.networking.enable_network_isolation = default_enable_network_isolation |
| 286 | + if self.networking.subnets is None: |
| 287 | + self.networking.subnets = ( |
| 288 | + resolve_value_from_config(config_path=TRAINING_JOB_SUBNETS_PATH)) |
| 289 | + if self.networking.security_group_ids is None: |
| 290 | + self.networking.subnets = ( |
| 291 | + resolve_value_from_config(config_path=TRAINING_JOB_SUBNETS_PATH)) |
| 292 | + |
| 293 | + if not self.output_data_config: |
| 294 | + default_output_data_config = resolve_value_from_config( |
| 295 | + config_path=TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH) |
| 296 | + if default_output_data_config: |
| 297 | + self.output_data_config = OutputDataConfig( |
| 298 | + **self._convert_keys_to_snake(default_output_data_config)) |
| 299 | + |
| 300 | + if not self._profiler_config: |
| 301 | + default_profiler_config = resolve_value_from_config( |
| 302 | + config_path=TRAINING_JOB_PROFILE_CONFIG_PATH) |
| 303 | + if default_profiler_config: |
| 304 | + self._profiler_config = ProfilerConfig( |
| 305 | + **self._convert_keys_to_snake(default_profiler_config)) |
| 306 | + |
| 307 | + if not self.compute: |
| 308 | + default_resource_config = resolve_value_from_config( |
| 309 | + config_path=TRAINING_JOB_RESOURCE_CONFIG_PATH) |
| 310 | + if default_resource_config: |
| 311 | + self.compute = Compute(**self._convert_keys_to_snake(default_resource_config)) |
| 312 | + |
| 313 | + if not self.role: |
| 314 | + self.role = resolve_value_from_config(config_path=TRAINING_JOB_ROLE_ARN_PATH) |
| 315 | + |
| 316 | + if not self.tags: |
| 317 | + self.tags = resolve_value_from_config(config_path=TRAINING_JOB_TAGS_PATH) |
| 318 | + |
| 319 | + def _convert_keys_to_snake(self, config: dict) -> dict: |
| 320 | + """Utility helper function that converts the keys of a dictionary into snake case""" |
| 321 | + return { |
| 322 | + to_snake_case(key): value |
| 323 | + for key, value in config.items() |
| 324 | + } |
| 325 | + |
| 326 | + def _populate_intelligent_defaults_from_model_trainer_space(self): |
| 327 | + """Function to populate all the possible default configs from Model Trainer Space""" |
| 328 | + |
| 329 | + for configurable_attribute in self.CONFIGURABLE_ATTRIBUTES: |
| 330 | + if getattr(self, configurable_attribute) is None: |
| 331 | + default_config = resolve_value_from_config( |
| 332 | + config_path=_simple_path(SAGEMAKER, |
| 333 | + PYTHON_SDK, |
| 334 | + MODULES, |
| 335 | + MODEL_TRAINER, |
| 336 | + to_camel_case(configurable_attribute)), |
| 337 | + sagemaker_session=self.sagemaker_session) |
| 338 | + if default_config is not None: |
| 339 | + if configurable_attribute in self.SERIALIZABLE_CONFIG_ATTRIBUTES: |
| 340 | + default_config = (self.SERIALIZABLE_CONFIG_ATTRIBUTES |
| 341 | + .get(configurable_attribute)(**default_config)) # noqa |
| 342 | + setattr(self, configurable_attribute, default_config) |
| 343 | + |
207 | 344 | # Created Artifacts
|
208 | 345 | _latest_training_job: Optional[resources.TrainingJob] = PrivateAttr(default=None)
|
209 | 346 |
|
@@ -374,6 +511,7 @@ def train(
|
374 | 511 | Whether to display the training container logs while training.
|
375 | 512 | Defaults to True.
|
376 | 513 | """
|
| 514 | + self._populate_intelligent_defaults() |
377 | 515 | if input_data_config:
|
378 | 516 | self.input_data_config = input_data_config
|
379 | 517 |
|
@@ -745,7 +883,8 @@ def with_debugger_settings(
|
745 | 883 | debug_hook_config (Optional[DebugHookConfig]):
|
746 | 884 | Configuration information for the Amazon SageMaker Debugger hook parameters,
|
747 | 885 | metric and tensor collections, and storage paths.
|
748 |
| - To learn more see: https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-createtrainingjob-api.html |
| 886 | + To learn more see: |
| 887 | + https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-createtrainingjob-api.html |
749 | 888 | debug_rule_configurations (Optional[List[DebugRuleConfiguration]]):
|
750 | 889 | Configuration information for Amazon SageMaker Debugger rules for debugging
|
751 | 890 | output ensors.
|
|
0 commit comments