Skip to content

Commit 784a18f

Browse files
nargokulpintaoz-aws
authored andcommitted
Intelligent defaults for Model Trainer (#1586)
* Intelligent defaults for Model Trainer * Codestyle Fixes * Unit test fixes * New unit tests * Codestyle fixes * Codestyle fixes * Parity support with Estimator * Refactor * CodeStyle fixes * Update to use self.sagemaker_session instead * Codestyle checks * Fix notebooks
1 parent 0a488fe commit 784a18f

File tree

4 files changed

+316
-8
lines changed

4 files changed

+316
-8
lines changed

src/sagemaker/config/config_schema.py

+64
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
REGION_NAME = "region_name"
117117
TELEMETRY_OPT_OUT = "TelemetryOptOut"
118118
NOTEBOOK_JOB = "NotebookJob"
119+
MODEL_TRAINER = "ModelTrainer"
119120

120121

121122
def _simple_path(*args: str):
@@ -142,6 +143,7 @@ def _simple_path(*args: str):
142143
)
143144
TRAINING_JOB_ROLE_ARN_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, ROLE_ARN)
144145
TRAINING_JOB_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, VPC_CONFIG)
146+
TRAINING_JOB_TAGS_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, TAGS)
145147
TRAINING_JOB_SECURITY_GROUP_IDS_PATH = _simple_path(
146148
TRAINING_JOB_VPC_CONFIG_PATH, SECURITY_GROUP_IDS
147149
)
@@ -656,6 +658,64 @@ def _simple_path(*args: str):
656658
"minItems": 1,
657659
"maxItems": 15,
658660
},
661+
"role": {
662+
TYPE: "string",
663+
"pattern": r"^arn:aws[a-z\-]*:iam::\d{12}:role/?[a-zA-Z_0-9+=,.@\-_/]+$",
664+
"minLength": 20,
665+
"maxLength": 2048,
666+
},
667+
"baseJobName": {
668+
TYPE: OBJECT,
669+
ADDITIONAL_PROPERTIES: True
670+
},
671+
"sourceCode": {
672+
TYPE: OBJECT,
673+
ADDITIONAL_PROPERTIES: True
674+
},
675+
"distributed_runner": {
676+
TYPE: OBJECT,
677+
ADDITIONAL_PROPERTIES: True
678+
},
679+
"compute": {
680+
TYPE: OBJECT,
681+
ADDITIONAL_PROPERTIES: True
682+
},
683+
"networking": {
684+
TYPE: OBJECT,
685+
ADDITIONAL_PROPERTIES: True
686+
},
687+
"stoppingCondition": {
688+
TYPE: OBJECT,
689+
ADDITIONAL_PROPERTIES: True
690+
},
691+
"trainingImage": {
692+
TYPE: OBJECT,
693+
ADDITIONAL_PROPERTIES: True
694+
},
695+
"trainingImageConfig": {
696+
TYPE: OBJECT,
697+
ADDITIONAL_PROPERTIES: True
698+
},
699+
"algorithmName": {
700+
TYPE: OBJECT,
701+
ADDITIONAL_PROPERTIES: True
702+
},
703+
"outputDataConfig": {
704+
TYPE: OBJECT,
705+
ADDITIONAL_PROPERTIES: True
706+
},
707+
"trainingInputMode": {
708+
TYPE: OBJECT,
709+
ADDITIONAL_PROPERTIES: True
710+
},
711+
"environment": {
712+
TYPE: OBJECT,
713+
ADDITIONAL_PROPERTIES: True
714+
},
715+
"hyperparameters": {
716+
TYPE: OBJECT,
717+
ADDITIONAL_PROPERTIES: True
718+
},
659719
},
660720
PROPERTIES: {
661721
SCHEMA_VERSION: {
@@ -709,6 +769,10 @@ def _simple_path(*args: str):
709769
},
710770
},
711771
},
772+
MODEL_TRAINER: {
773+
TYPE: OBJECT,
774+
ADDITIONAL_PROPERTIES: True
775+
},
712776
ESTIMATOR: {
713777
TYPE: OBJECT,
714778
ADDITIONAL_PROPERTIES: False,

src/sagemaker/modules/testing_notebooks/base_model_trainer.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
]
1111
},
1212
{
13-
"cell_type": "markdown",
1413
"metadata": {},
14+
"cell_type": "markdown",
1515
"source": [
1616
"# ModelTrainer\n",
1717
"The ModelTrainer is a new interface for training designed to tackle many of the challenges that exist in todays Estimator class. Some key features include:\n",

src/sagemaker/modules/train/model_trainer.py

+141-2
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,31 @@
1919
import shutil
2020
from tempfile import TemporaryDirectory
2121

22-
from typing import Optional, List, Union, Dict, Any
22+
from typing import Optional, List, Union, Dict, Any, ClassVar
23+
24+
from graphene.utils.str_converters import to_camel_case, to_snake_case
25+
2326
from sagemaker_core.main import resources
2427
from sagemaker_core.resources import TrainingJob
2528
from sagemaker_core.shapes import AlgorithmSpecification
2629

2730
from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call
2831

32+
from sagemaker.config.config_schema import (_simple_path, SAGEMAKER,
33+
MODEL_TRAINER, MODULES,
34+
PYTHON_SDK,
35+
TRAINING_JOB_ENVIRONMENT_PATH,
36+
TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
37+
TRAINING_JOB_VPC_CONFIG_PATH,
38+
TRAINING_JOB_SUBNETS_PATH,
39+
TRAINING_JOB_SECURITY_GROUP_IDS_PATH,
40+
TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH,
41+
TRAINING_JOB_PROFILE_CONFIG_PATH,
42+
TRAINING_JOB_RESOURCE_CONFIG_PATH,
43+
TRAINING_JOB_ROLE_ARN_PATH,
44+
TRAINING_JOB_TAGS_PATH)
45+
46+
from sagemaker.utils import resolve_value_from_config
2947
from sagemaker.modules import Session, get_execution_role
3048
from sagemaker.modules.configs import (
3149
Compute,
@@ -204,6 +222,125 @@ class ModelTrainer(BaseModel):
204222
tags: Optional[List[Tag]] = None
205223
local_container_root: Optional[str] = os.getcwd()
206224

225+
# Names of the ModelTrainer attributes that may be filled in from the config
# section addressed by SAGEMAKER / PYTHON_SDK / MODULES / MODEL_TRAINER.
CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = [
    "role",
    "base_job_name",
    "source_code",
    "distributed_runner",
    "compute",
    "networking",
    "stopping_condition",
    "training_image",
    "training_image_config",
    "algorithm_name",
    "output_data_config",
    "checkpoint_config",
    "training_input_mode",
    "environment",
    "hyperparameters",
]

# Maps a configurable attribute to the config class that is instantiated from the
# raw dictionary found in the config (``cls(**raw_dict)``). The values must be the
# classes themselves: ``type(SomeClass)`` evaluates to the metaclass, which cannot
# be called with the config's keyword arguments.
SERIALIZABLE_CONFIG_ATTRIBUTES: ClassVar[Any] = {
    "source_code": SourceCode,
    "distributed_runner": DistributedRunner,
    "compute": Compute,
    "networking": Networking,
    "stopping_condition": StoppingCondition,
    "training_image_config": TrainingImageConfig,
    "output_data_config": OutputDataConfig,
    "checkpoint_config": CheckpointConfig,
}
251+
252+
def _populate_intelligent_defaults(self):
    """Fill in attributes from every supported source of config defaults.

    The ModelTrainer-specific pass runs before the generic training job pass,
    so ModelTrainer-specific values take precedence over the generic ones.
    """
    # Order matters: the first pass to set an attribute wins.
    populate_steps = (
        self._populate_intelligent_defaults_from_model_trainer_space,
        self._populate_intelligent_defaults_from_training_job_space,
    )
    for populate in populate_steps:
        populate()
259+
260+
def _populate_intelligent_defaults_from_training_job_space(self):
    """Populate unset attributes from the generic training job config defaults.

    Only attributes that are currently unset are touched: falsy top-level
    attributes, and ``None`` fields on an existing ``networking`` object. Every
    lookup is made with ``self.sagemaker_session``.
    """

    def _from_config(config_path):
        # Single place for lookups so the session is passed on every call.
        return resolve_value_from_config(
            config_path=config_path, sagemaker_session=self.sagemaker_session
        )

    if not self.environment:
        self.environment = _from_config(TRAINING_JOB_ENVIRONMENT_PATH)

    default_enable_network_isolation = _from_config(
        TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH
    )
    default_vpc_config = _from_config(TRAINING_JOB_VPC_CONFIG_PATH)

    if not self.networking:
        # Only create a Networking object when the config has something to put in it.
        if default_enable_network_isolation is not None or default_vpc_config is not None:
            self.networking = Networking(
                enable_network_isolation=default_enable_network_isolation,
                subnets=_from_config(TRAINING_JOB_SUBNETS_PATH),
                security_group_ids=_from_config(TRAINING_JOB_SECURITY_GROUP_IDS_PATH),
            )
    else:
        # Fill only the individual fields that were left unset.
        if self.networking.enable_network_isolation is None:
            self.networking.enable_network_isolation = default_enable_network_isolation
        if self.networking.subnets is None:
            self.networking.subnets = _from_config(TRAINING_JOB_SUBNETS_PATH)
        if self.networking.security_group_ids is None:
            self.networking.security_group_ids = _from_config(
                TRAINING_JOB_SECURITY_GROUP_IDS_PATH
            )

    if not self.output_data_config:
        default_output_data_config = _from_config(TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH)
        if default_output_data_config:
            self.output_data_config = OutputDataConfig(
                **self._convert_keys_to_snake(default_output_data_config)
            )

    if not self._profiler_config:
        default_profiler_config = _from_config(TRAINING_JOB_PROFILE_CONFIG_PATH)
        if default_profiler_config:
            self._profiler_config = ProfilerConfig(
                **self._convert_keys_to_snake(default_profiler_config)
            )

    if not self.compute:
        default_resource_config = _from_config(TRAINING_JOB_RESOURCE_CONFIG_PATH)
        if default_resource_config:
            self.compute = Compute(**self._convert_keys_to_snake(default_resource_config))

    if not self.role:
        self.role = _from_config(TRAINING_JOB_ROLE_ARN_PATH)

    if not self.tags:
        self.tags = _from_config(TRAINING_JOB_TAGS_PATH)
318+
319+
def _convert_keys_to_snake(self, config: dict) -> dict:
    """Return a new dictionary with the top-level keys of ``config`` in snake case.

    Values are carried over unchanged; nested dictionaries are not descended into.
    """
    snake_cased = {}
    for original_key, value in config.items():
        snake_cased[to_snake_case(original_key)] = value
    return snake_cased
325+
326+
def _populate_intelligent_defaults_from_model_trainer_space(self):
    """Populate unset attributes from the ModelTrainer-specific config section.

    Each name in ``CONFIGURABLE_ATTRIBUTES`` whose current value is ``None`` is
    looked up under SAGEMAKER / PYTHON_SDK / MODULES / MODEL_TRAINER using the
    camel-case form of the attribute name. Values for attributes listed in
    ``SERIALIZABLE_CONFIG_ATTRIBUTES`` are passed as keyword arguments to the
    mapped callable before being assigned.
    """
    for attribute_name in self.CONFIGURABLE_ATTRIBUTES:
        # Anything that already has a value is left untouched.
        if getattr(self, attribute_name) is not None:
            continue

        config_path = _simple_path(
            SAGEMAKER, PYTHON_SDK, MODULES, MODEL_TRAINER, to_camel_case(attribute_name)
        )
        resolved = resolve_value_from_config(
            config_path=config_path, sagemaker_session=self.sagemaker_session
        )
        if resolved is None:
            continue

        if attribute_name in self.SERIALIZABLE_CONFIG_ATTRIBUTES:
            config_type = self.SERIALIZABLE_CONFIG_ATTRIBUTES.get(attribute_name)
            resolved = config_type(**resolved)
        setattr(self, attribute_name, resolved)
343+
207344
# Created Artifacts
208345
_latest_training_job: Optional[resources.TrainingJob] = PrivateAttr(default=None)
209346

@@ -374,6 +511,7 @@ def train(
374511
Whether to display the training container logs while training.
375512
Defaults to True.
376513
"""
514+
self._populate_intelligent_defaults()
377515
if input_data_config:
378516
self.input_data_config = input_data_config
379517

@@ -745,7 +883,8 @@ def with_debugger_settings(
745883
debug_hook_config (Optional[DebugHookConfig]):
746884
Configuration information for the Amazon SageMaker Debugger hook parameters,
747885
metric and tensor collections, and storage paths.
748-
To learn more see: https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-createtrainingjob-api.html
886+
To learn more see:
887+
https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-createtrainingjob-api.html
749888
debug_rule_configurations (Optional[List[DebugRuleConfiguration]]):
750889
Configuration information for Amazon SageMaker Debugger rules for debugging
751890
output tensors.

0 commit comments

Comments
 (0)