87
87
EXECUTE_BASIC_SCRIPT_DRIVER ,
88
88
)
89
89
from sagemaker .modules import logger
90
- from sagemaker .modules .train .sm_recipes .utils import get_args_from_recipe , _determine_device_type
90
+ from sagemaker .modules .train .sm_recipes .utils import _get_args_from_recipe , _determine_device_type
91
91
92
92
93
93
class Mode (Enum ):
@@ -154,7 +154,7 @@ class ModelTrainer(BaseModel):
154
154
see: https://docs.aws.amazon.com/sagemaker/latest/dg-ecr-paths/sagemaker-algo-docker-registry-paths
155
155
training_image_config (Optional[TrainingImageConfig]):
156
156
Training image Config. This is the configuration to use an image from a private
157
- Docker registry for a traininob .
157
+ Docker registry for a training job .
158
158
output_data_config (Optional[OutputDataConfig]):
159
159
The output data configuration. This is used to specify the output data location
160
160
for the training job.
@@ -481,10 +481,6 @@ def train(
481
481
)
482
482
self ._latest_training_job = training_job
483
483
484
- # Clean up the temporary directory if it exists
485
- if self ._temp_recipe_train_dir is not None :
486
- self ._temp_recipe_train_dir .cleanup ()
487
-
488
484
if wait :
489
485
training_job .wait (logs = logs )
490
486
if logs and not wait :
@@ -816,11 +812,18 @@ def from_recipe(
816
812
training_recipe : str ,
817
813
compute : Compute ,
818
814
recipe_overrides : Optional [Dict [str , Any ]] = None ,
815
+ requirements : Optional [str ] = None ,
819
816
training_image : Optional [str ] = None ,
817
+ training_image_config : Optional [TrainingImageConfig ] = None ,
818
+ output_data_config : Optional [OutputDataConfig ] = None ,
819
+ input_data_config : Optional [List [Union [Channel , InputData ]]] = None ,
820
+ checkpoint_config : Optional [CheckpointConfig ] = None ,
821
+ training_input_mode : Optional [str ] = "File" ,
822
+ environment : Optional [Dict [str , str ]] = None ,
823
+ tags : Optional [List [Tag ]] = None ,
820
824
session : Optional [Session ] = None ,
821
825
role : Optional [str ] = None ,
822
826
base_job_name : Optional [str ] = None ,
823
- ** kwargs ,
824
827
) -> "ModelTrainer" :
825
828
"""Create a ModelTrainer from a training recipe.
826
829
@@ -833,9 +836,33 @@ def from_recipe(
833
836
the training job. If not specified, will default to 1 instance of ml.m5.xlarge.
834
837
recipe_overrides (Optional[Dict[str, Any]]):
835
838
The recipe overrides. This is used to override the default recipe parameters.
839
+ requirements (Optional[str]):
840
+ The path to a requirements file to install in the training job container.
836
841
training_image (Optional[str]):
837
842
The training image URI to use for the training job container. If not specified,
838
843
the training image will be determined from the recipe.
844
+ training_image_config (Optional[TrainingImageConfig]):
845
+ Training image Config. This is the configuration to use an image from a private
846
+ Docker registry for a training job.
847
+ output_data_config (Optional[OutputDataConfig]):
848
+ The output data configuration. This is used to specify the output data location
849
+ for the training job.
850
+ If not specified, will default to `s3://<default_bucket>/<base_job_name>/output/`.
851
+ input_data_config (Optional[List[Union[Channel, InputData]]]):
852
+ The input data config for the training job.
853
+ Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
854
+ string, local file path string, S3DataSource object, or FileSystemDataSource object.
855
+ checkpoint_config (Optional[CheckpointConfig]):
856
+ Contains information about the output location for managed spot training checkpoint
857
+ data.
858
+ training_input_mode (Optional[str]):
859
+ The input mode for the training job. Valid values are "Pipe", "File", "FastFile".
860
+ Defaults to "File".
861
+ environment (Optional[Dict[str, str]]):
862
+ The environment variables for the training job.
863
+ tags (Optional[List[Tag]]):
864
+ An array of key-value pairs. You can use tags to categorize your AWS resources
865
+ in different ways, for example, by purpose, owner, or environment.
839
866
session (Optional[Session]):
840
867
The SageMaker session.
841
868
If not specified, a new session will be created.
@@ -846,9 +873,6 @@ def from_recipe(
846
873
The base name for the training job.
847
874
If not specified, a default name will be generated using the algorithm name
848
875
or training image.
849
- kwargs:
850
- Additional keyword arguments to pass to the ModelTrainer constructor.
851
-
852
876
"""
853
877
if compute .instance_type is None :
854
878
raise ValueError (
@@ -865,20 +889,37 @@ def from_recipe(
865
889
session = Session ()
866
890
logger .warning ("Session not provided. Using default Session." )
867
891
if role is None :
868
- role = get_execution_role ()
892
+ role = get_execution_role (sagemaker_session = session )
869
893
logger .warning (f"Role not provided. Using default role:\n { role } " )
870
894
871
- model_trainer_args , recipe_train_dir = get_args_from_recipe (
895
+ # The training recipe is used to prepare the following args:
896
+ # - source_code
897
+ # - training_image
898
+ # - distributed_runner
899
+ # - compute
900
+ # - hyperparameters
901
+ model_trainer_args , recipe_train_dir = _get_args_from_recipe (
872
902
training_recipe = training_recipe ,
873
903
recipe_overrides = recipe_overrides ,
904
+ requirements = requirements ,
874
905
compute = compute ,
875
- session = session ,
906
+ region_name = session . boto_region_name ,
876
907
)
877
908
if training_image is not None :
878
909
model_trainer_args ["training_image" ] = training_image
879
910
880
911
model_trainer = cls (
881
- session = session , role = role , base_job_name = base_job_name , ** model_trainer_args , ** kwargs
912
+ session = session ,
913
+ role = role ,
914
+ base_job_name = base_job_name ,
915
+ training_image_config = training_image_config ,
916
+ output_data_config = output_data_config ,
917
+ input_data_config = input_data_config ,
918
+ checkpoint_config = checkpoint_config ,
919
+ training_input_mode = training_input_mode ,
920
+ environment = environment ,
921
+ tags = tags ,
922
+ ** model_trainer_args ,
882
923
)
883
924
884
925
model_trainer ._temp_recipe_train_dir = recipe_train_dir
0 commit comments