8787 EXECUTE_BASIC_SCRIPT_DRIVER ,
8888)
8989from sagemaker .modules import logger
90- from sagemaker .modules .train .sm_recipes .utils import get_args_from_recipe , _determine_device_type
90+ from sagemaker .modules .train .sm_recipes .utils import _get_args_from_recipe , _determine_device_type
9191
9292
9393class Mode (Enum ):
@@ -154,7 +154,7 @@ class ModelTrainer(BaseModel):
154154 see: https://docs.aws.amazon.com/sagemaker/latest/dg-ecr-paths/sagemaker-algo-docker-registry-paths
155155 training_image_config (Optional[TrainingImageConfig]):
156156 Training image Config. This is the configuration to use an image from a private
157- Docker registry for a traininob .
157+ Docker registry for a training job .
158158 output_data_config (Optional[OutputDataConfig]):
159159 The output data configuration. This is used to specify the output data location
160160 for the training job.
@@ -481,10 +481,6 @@ def train(
481481 )
482482 self ._latest_training_job = training_job
483483
484- # Clean up the temporary directory if it exists
485- if self ._temp_recipe_train_dir is not None :
486- self ._temp_recipe_train_dir .cleanup ()
487-
488484 if wait :
489485 training_job .wait (logs = logs )
490486 if logs and not wait :
@@ -816,11 +812,18 @@ def from_recipe(
816812 training_recipe : str ,
817813 compute : Compute ,
818814 recipe_overrides : Optional [Dict [str , Any ]] = None ,
815+ requirements : Optional [str ] = None ,
819816 training_image : Optional [str ] = None ,
817+ training_image_config : Optional [TrainingImageConfig ] = None ,
818+ output_data_config : Optional [OutputDataConfig ] = None ,
819+ input_data_config : Optional [List [Union [Channel , InputData ]]] = None ,
820+ checkpoint_config : Optional [CheckpointConfig ] = None ,
821+ training_input_mode : Optional [str ] = "File" ,
822+ environment : Optional [Dict [str , str ]] = None ,
823+ tags : Optional [List [Tag ]] = None ,
820824 session : Optional [Session ] = None ,
821825 role : Optional [str ] = None ,
822826 base_job_name : Optional [str ] = None ,
823- ** kwargs ,
824827 ) -> "ModelTrainer" :
825828 """Create a ModelTrainer from a training recipe.
826829
@@ -833,9 +836,33 @@ def from_recipe(
833836 the training job. If not specified, will default to 1 instance of ml.m5.xlarge.
834837 recipe_overrides (Optional[Dict[str, Any]]):
835838 The recipe overrides. This is used to override the default recipe parameters.
839+ requirements (Optional[str]):
840+ The path to a requirements file to install in the training job container.
836841 training_image (Optional[str]):
837842 The training image URI to use for the training job container. If not specified,
838843 the training image will be determined from the recipe.
844+ training_image_config (Optional[TrainingImageConfig]):
845+ Training image Config. This is the configuration to use an image from a private
846+ Docker registry for a training job.
847+ output_data_config (Optional[OutputDataConfig]):
848+ The output data configuration. This is used to specify the output data location
849+ for the training job.
850+ If not specified, will default to `s3://<default_bucket>/<base_job_name>/output/`.
851+ input_data_config (Optional[List[Union[Channel, InputData]]]):
852+ The input data config for the training job.
853+ Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
854+ string, local file path string, S3DataSource object, or FileSystemDataSource object.
855+ checkpoint_config (Optional[CheckpointConfig]):
856+ Contains information about the output location for managed spot training checkpoint
857+ data.
858+ training_input_mode (Optional[str]):
859+ The input mode for the training job. Valid values are "Pipe", "File", "FastFile".
860+ Defaults to "File".
861+ environment (Optional[Dict[str, str]]):
862+ The environment variables for the training job.
863+ tags (Optional[List[Tag]]):
864+ An array of key-value pairs. You can use tags to categorize your AWS resources
865+ in different ways, for example, by purpose, owner, or environment.
839866 session (Optional[Session]):
840867 The SageMaker session.
841868 If not specified, a new session will be created.
@@ -846,9 +873,6 @@ def from_recipe(
846873 The base name for the training job.
847874 If not specified, a default name will be generated using the algorithm name
848875 or training image.
849- kwargs:
850- Additional keyword arguments to pass to the ModelTrainer constructor.
851-
852876 """
853877 if compute .instance_type is None :
854878 raise ValueError (
@@ -865,20 +889,37 @@ def from_recipe(
865889 session = Session ()
866890 logger .warning ("Session not provided. Using default Session." )
867891 if role is None :
868- role = get_execution_role ()
892+ role = get_execution_role (sagemaker_session = session )
869893 logger .warning (f"Role not provided. Using default role:\n { role } " )
870894
871- model_trainer_args , recipe_train_dir = get_args_from_recipe (
895+ # The training recipe is used to prepare the following args:
896+ # - source_code
897+ # - training_image
898+ # - distributed_runner
899+ # - compute
900+ # - hyperparameters
901+ model_trainer_args , recipe_train_dir = _get_args_from_recipe (
872902 training_recipe = training_recipe ,
873903 recipe_overrides = recipe_overrides ,
904+ requirements = requirements ,
874905 compute = compute ,
875- session = session ,
906+ region_name = session . boto_region_name ,
876907 )
877908 if training_image is not None :
878909 model_trainer_args ["training_image" ] = training_image
879910
880911 model_trainer = cls (
881- session = session , role = role , base_job_name = base_job_name , ** model_trainer_args , ** kwargs
912+ session = session ,
913+ role = role ,
914+ base_job_name = base_job_name ,
915+ training_image_config = training_image_config ,
916+ output_data_config = output_data_config ,
917+ input_data_config = input_data_config ,
918+ checkpoint_config = checkpoint_config ,
919+ training_input_mode = training_input_mode ,
920+ environment = environment ,
921+ tags = tags ,
922+ ** model_trainer_args ,
882923 )
883924
884925 model_trainer ._temp_recipe_train_dir = recipe_train_dir
0 commit comments