Skip to content

Commit a9dd628

Browse files
benieric
pintaoz-aws authored and committed
Update kandinsky in ModelTrainer and allow setting requirements (#1587)
1 parent fa02963 commit a9dd628

File tree

5 files changed

+259
-31
lines changed

5 files changed

+259
-31
lines changed

src/sagemaker/modules/train/model_trainer.py

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
EXECUTE_BASIC_SCRIPT_DRIVER,
8888
)
8989
from sagemaker.modules import logger
90-
from sagemaker.modules.train.sm_recipes.utils import get_args_from_recipe, _determine_device_type
90+
from sagemaker.modules.train.sm_recipes.utils import _get_args_from_recipe, _determine_device_type
9191

9292

9393
class Mode(Enum):
@@ -154,7 +154,7 @@ class ModelTrainer(BaseModel):
154154
see: https://docs.aws.amazon.com/sagemaker/latest/dg-ecr-paths/sagemaker-algo-docker-registry-paths
155155
training_image_config (Optional[TrainingImageConfig]):
156156
Training image Config. This is the configuration to use an image from a private
157-
Docker registry for a traininob.
157+
Docker registry for a training job.
158158
output_data_config (Optional[OutputDataConfig]):
159159
The output data configuration. This is used to specify the output data location
160160
for the training job.
@@ -481,10 +481,6 @@ def train(
481481
)
482482
self._latest_training_job = training_job
483483

484-
# Clean up the temporary directory if it exists
485-
if self._temp_recipe_train_dir is not None:
486-
self._temp_recipe_train_dir.cleanup()
487-
488484
if wait:
489485
training_job.wait(logs=logs)
490486
if logs and not wait:
@@ -816,11 +812,18 @@ def from_recipe(
816812
training_recipe: str,
817813
compute: Compute,
818814
recipe_overrides: Optional[Dict[str, Any]] = None,
815+
requirements: Optional[str] = None,
819816
training_image: Optional[str] = None,
817+
training_image_config: Optional[TrainingImageConfig] = None,
818+
output_data_config: Optional[OutputDataConfig] = None,
819+
input_data_config: Optional[List[Union[Channel, InputData]]] = None,
820+
checkpoint_config: Optional[CheckpointConfig] = None,
821+
training_input_mode: Optional[str] = "File",
822+
environment: Optional[Dict[str, str]] = None,
823+
tags: Optional[List[Tag]] = None,
820824
session: Optional[Session] = None,
821825
role: Optional[str] = None,
822826
base_job_name: Optional[str] = None,
823-
**kwargs,
824827
) -> "ModelTrainer":
825828
"""Create a ModelTrainer from a training recipe.
826829
@@ -833,9 +836,33 @@ def from_recipe(
833836
the training job. If not specified, will default to 1 instance of ml.m5.xlarge.
834837
recipe_overrides (Optional[Dict[str, Any]]):
835838
The recipe overrides. This is used to override the default recipe parameters.
839+
requirements (Optional[str]):
840+
The path to a requirements file to install in the training job container.
836841
training_image (Optional[str]):
837842
The training image URI to use for the training job container. If not specified,
838843
the training image will be determined from the recipe.
844+
training_image_config (Optional[TrainingImageConfig]):
845+
Training image Config. This is the configuration to use an image from a private
846+
Docker registry for a training job.
847+
output_data_config (Optional[OutputDataConfig]):
848+
The output data configuration. This is used to specify the output data location
849+
for the training job.
850+
If not specified, will default to `s3://<default_bucket>/<base_job_name>/output/`.
851+
input_data_config (Optional[List[Union[Channel, InputData]]]):
852+
The input data config for the training job.
853+
Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
854+
string, local file path string, S3DataSource object, or FileSystemDataSource object.
855+
checkpoint_config (Optional[CheckpointConfig]):
856+
Contains information about the output location for managed spot training checkpoint
857+
data.
858+
training_input_mode (Optional[str]):
859+
The input mode for the training job. Valid values are "Pipe", "File", "FastFile".
860+
Defaults to "File".
861+
environment (Optional[Dict[str, str]]):
862+
The environment variables for the training job.
863+
tags (Optional[List[Tag]]):
864+
An array of key-value pairs. You can use tags to categorize your AWS resources
865+
in different ways, for example, by purpose, owner, or environment.
839866
session (Optional[Session]):
840867
The SageMaker session.
841868
If not specified, a new session will be created.
@@ -846,9 +873,6 @@ def from_recipe(
846873
The base name for the training job.
847874
If not specified, a default name will be generated using the algorithm name
848875
or training image.
849-
kwargs:
850-
Additional keyword arguments to pass to the ModelTrainer constructor.
851-
852876
"""
853877
if compute.instance_type is None:
854878
raise ValueError(
@@ -865,20 +889,37 @@ def from_recipe(
865889
session = Session()
866890
logger.warning("Session not provided. Using default Session.")
867891
if role is None:
868-
role = get_execution_role()
892+
role = get_execution_role(sagemaker_session=session)
869893
logger.warning(f"Role not provided. Using default role:\n{role}")
870894

871-
model_trainer_args, recipe_train_dir = get_args_from_recipe(
895+
# The training recipe is used to prepare the following args:
896+
# - source_code
897+
# - training_image
898+
# - distributed_runner
899+
# - compute
900+
# - hyperparameters
901+
model_trainer_args, recipe_train_dir = _get_args_from_recipe(
872902
training_recipe=training_recipe,
873903
recipe_overrides=recipe_overrides,
904+
requirements=requirements,
874905
compute=compute,
875-
session=session,
906+
region_name=session.boto_region_name,
876907
)
877908
if training_image is not None:
878909
model_trainer_args["training_image"] = training_image
879910

880911
model_trainer = cls(
881-
session=session, role=role, base_job_name=base_job_name, **model_trainer_args, **kwargs
912+
session=session,
913+
role=role,
914+
base_job_name=base_job_name,
915+
training_image_config=training_image_config,
916+
output_data_config=output_data_config,
917+
input_data_config=input_data_config,
918+
checkpoint_config=checkpoint_config,
919+
training_input_mode=training_input_mode,
920+
environment=environment,
921+
tags=tags,
922+
**model_trainer_args,
882923
)
883924

884925
model_trainer._temp_recipe_train_dir = recipe_train_dir

src/sagemaker/modules/train/sm_recipes/training_recipes.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
2-
"adapter_repo": "[email protected]-adapter:benieric/private-sagemaker-hyperpod-training-adapter-for-nemo-staging.git",
3-
"launcher_repo": "[email protected]-launcher:benieric/private-sagemaker-hyperpod-recipes-staging.git",
2+
"adapter_repo": "[email protected]:aws/private-sagemaker-hyperpod-training-adapter-for-nemo-staging.git",
3+
"launcher_repo": "[email protected]:aws/private-sagemaker-hyperpod-recipes-staging.git",
44
"neuron_dist_repo": "https://github.com/aws-neuron/neuronx-distributed-training.git",
55
"gpu_image" : {
66
"framework": "pytorch-smp",

src/sagemaker/modules/train/sm_recipes/utils.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from omegaconf import OmegaConf, dictconfig
2626

2727
from sagemaker.image_uris import retrieve
28-
from sagemaker import Session
2928

3029
from sagemaker.modules import logger
3130
from sagemaker.modules.utils import _run_clone_command_silent
@@ -66,7 +65,7 @@ def _load_recipes_cfg() -> str:
6665

6766
def _load_base_recipe(
6867
training_recipe: str,
69-
recipe_overrides: Optional[Dict[str, Any]],
68+
recipe_overrides: Optional[Dict[str, Any]] = None,
7069
training_recipes_cfg: Optional[Dict[str, Any]] = None,
7170
) -> Dict[str, Any]:
7271
"""Load recipe and apply overrides."""
@@ -195,7 +194,6 @@ def _configure_trainium_args(
195194

196195
_run_clone_command_silent(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir.name)
197196

198-
# Set SourceCodeConfig
199197
source_code.source_dir = os.path.join(recipe_train_dir.name, "examples")
200198
source_code.entry_script = "training_orchestrator.py"
201199
neuron_image_cfg = training_recipes_cfg.get("neuron_image")
@@ -220,11 +218,12 @@ def _configure_trainium_args(
220218
return args
221219

222220

223-
def get_args_from_recipe(
221+
def _get_args_from_recipe(
224222
training_recipe: str,
225223
compute: Compute,
226-
session: Session,
224+
region_name: str,
227225
recipe_overrides: Optional[Dict[str, Any]],
226+
requirements: Optional[str],
228227
) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]:
229228
"""Get arguments for ModelTrainer from a training recipe.
230229
@@ -233,7 +232,7 @@ def get_args_from_recipe(
233232
{
234233
"source_code": SourceCode,
235234
"training_image": str,
236-
"distributed_runner": Dict[str, Any],
235+
"distributed_runner": DistributedRunner,
237236
"compute": Compute,
238237
"hyperparameters": Dict[str, Any],
239238
}
@@ -244,15 +243,16 @@ def get_args_from_recipe(
244243
Name of the training recipe or path to the recipe file.
245244
compute (Compute):
246245
Compute configuration for training.
247-
session (Session):
248-
Session object for training.
246+
region_name (str):
247+
Name of the AWS region.
249248
recipe_overrides (Optional[Dict[str, Any]]):
250249
Overrides for the training recipe.
250+
requirements (Optional[str]):
251+
Path to the requirements file.
251252
"""
252253
if compute.instance_type is None:
253254
raise ValueError("Must set `instance_type` in compute when using training recipes.")
254255

255-
region_name = session.boto_region_name
256256
training_recipes_cfg = _load_recipes_cfg()
257257
recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg)
258258

@@ -262,18 +262,20 @@ def get_args_from_recipe(
262262
# Set instance_count
263263
if compute.instance_count and "num_nodes" in recipe["trainer"]:
264264
logger.warning(
265-
"Using instance_count in compute to set number "
266-
" of nodes. Ignoring trainer -> num_nodes in recipe."
265+
f"Using Compute to set instance_count:\n{compute}."
266+
"\nIgnoring trainer -> num_nodes in recipe."
267267
)
268268
if compute.instance_count is None:
269269
if "num_nodes" not in recipe["trainer"]:
270270
raise ValueError(
271-
"Must set either instance_count argument for estimator or"
272-
"set trainer -> num_nodes in recipe."
271+
"Must provide Compute with instance_count or" " set trainer -> num_nodes in recipe."
273272
)
274273
compute.instance_count = recipe["trainer"]["num_nodes"]
275274

276-
# Get Training Image, SourceCodeConfig, and DistributionConfig args
275+
if requirements and not os.path.isfile(requirements):
276+
raise ValueError(f"Recipe requirements file {requirements} not found.")
277+
278+
# Get Training Image, SourceCode, and DistributedRunner args
277279
device_type = _determine_device_type(compute.instance_type)
278280
recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
279281
if device_type == "gpu":
@@ -299,7 +301,12 @@ def get_args_from_recipe(
299301
config=final_recipe, f=os.path.join(args["source_code"].source_dir, "recipe.yaml")
300302
)
301303

302-
# Update args with compute_config and hyperparameters
304+
# If recipe_requirements is provided, copy it to source_dir
305+
if requirements:
306+
shutil.copy(requirements, args["source_code"].source_dir)
307+
args["source_code"].requirements = os.path.basename(requirements)
308+
309+
# Update args with compute and hyperparameters
303310
args.update(
304311
{
305312
"compute": compute,

tests/unit/sagemaker/modules/train/sm_recipes/__init__.py

Whitespace-only changes.

0 commit comments

Comments (0)