13
13
"""ImageSpec class module."""
14
14
from __future__ import absolute_import
15
15
16
+ import re
17
+ from enum import Enum
16
18
from typing import Optional
17
19
18
- from sagemaker import image_uris , Session
19
- from sagemaker .serverless import ServerlessInferenceConfig
20
- from sagemaker .training_compiler .config import TrainingCompilerConfig
20
+ from sagemaker import utils
21
+ from sagemaker .image_uris import _validate_version_and_set_if_needed , _version_for_config , \
22
+ _config_for_framework_and_scope , _validate_py_version_and_set_if_needed , _registry_from_region , ECR_URI_TEMPLATE , \
23
+ _get_latest_versions , _validate_instance_deprecation , _get_image_tag , _validate_arg
24
+ from packaging .version import Version
25
+
26
+ DEFAULT_TOLERATE_MODEL = False
27
+
28
+
29
+ class Framework (Enum ):
30
+ HUGGING_FACE = "huggingface"
31
+ HUGGING_FACE_NEURON = "huggingface-neuron"
32
+ HUGGING_FACE_NEURON_X = "huggingface-neuronx"
33
+ HUGGING_FACE_LLM = "huggingface-llm"
34
+ HUGGING_FACE_TEI_GPU = "huggingface-tei"
35
+ HUGGING_FACE_TEI_CPU = "huggingface-tei-cpu"
36
+ HUGGING_FACE_LLM_NEURONX = "huggingface-llm-neuronx"
37
+ HUGGING_FACE_TRAINING_COMPILER = "huggingface-training-compiler"
38
+ XGBOOST = "xgboost"
39
+ XG_BOOST_NEO = "xg-boost-neo"
40
+ SKLEARN = "sklearn"
41
+ PYTORCH = "pytorch"
42
+ PYTORCH_TRAINING_COMPILER = "pytorch-training-compiler"
43
+ DATA_WRANGLER = "data-wrangler"
44
+ STABILITYAI = "stabilityai"
45
+ SAGEMAKER_TRITONSERVER = "sagemaker-tritonserver"
46
+
47
+
48
+ class ImageScope (Enum ):
49
+ TRAINING = "training"
50
+ INFERENCE = "inference"
51
+ INFERENCE_GRAVITON = "inference-graviton"
52
+
53
+
54
+ class Processor (Enum ):
55
+ INF = "inf"
56
+ NEURON = "neuron"
57
+ GPU = "gpu"
58
+ CPU = "cpu"
59
+ TRN = "trn"
21
60
22
61
23
62
class ImageSpec :
24
63
"""ImageSpec class to get image URI for a specific framework version."""
25
64
26
- def __init__ (
27
- self ,
28
- framework_name : str ,
29
- version : str ,
30
- image_scope : Optional [str ] = None ,
31
- instance_type : Optional [str ] = None ,
32
- py_version : Optional [str ] = None ,
33
- region : Optional [str ] = "us-west-2" ,
34
- accelerator_type : Optional [str ] = None ,
35
- container_version : Optional [str ] = None ,
36
- distribution : Optional [dict ] = None ,
37
- base_framework_version : Optional [str ] = None ,
38
- training_compiler_config : Optional [TrainingCompilerConfig ] = None ,
39
- model_id : Optional [str ] = None ,
40
- model_version : Optional [str ] = None ,
41
- hub_arn : Optional [str ] = None ,
42
- tolerate_vulnerable_model : Optional [bool ] = False ,
43
- tolerate_deprecated_model : Optional [bool ] = False ,
44
- sdk_version : Optional [str ] = None ,
45
- inference_tool : Optional [str ] = None ,
46
- serverless_inference_config : Optional [ServerlessInferenceConfig ] = None ,
47
- config_name : Optional [str ] = None ,
48
- sagemaker_session : Optional [Session ] = None ,
49
- ):
50
- self .framework_name = framework_name
65
+ def __init__ (self ,
66
+ framework : Framework ,
67
+ processor : Optional [Processor ] = Processor .CPU ,
68
+ region : Optional [str ] = "us-west-2" ,
69
+ version = None ,
70
+ py_version = None ,
71
+ instance_type = None ,
72
+ accelerator_type = None ,
73
+ image_scope : ImageScope = ImageScope .TRAINING ,
74
+ container_version = None ,
75
+ distribution = None ,
76
+ base_framework_version = None ,
77
+ sdk_version = None ,
78
+ inference_tool = None ):
79
+ self .framework = framework
80
+ self .processor = processor
51
81
self .version = version
52
82
self .image_scope = image_scope
53
83
self .instance_type = instance_type
@@ -57,45 +87,175 @@ def __init__(
57
87
self .container_version = container_version
58
88
self .distribution = distribution
59
89
self .base_framework_version = base_framework_version
60
- self .training_compiler_config = training_compiler_config
61
- self .model_id = model_id
62
- self .model_version = model_version
63
- self .hub_arn = hub_arn
64
- self .tolerate_vulnerable_model = tolerate_vulnerable_model
65
- self .tolerate_deprecated_model = tolerate_deprecated_model
66
90
self .sdk_version = sdk_version
67
91
self .inference_tool = inference_tool
68
- self .serverless_inference_config = serverless_inference_config
69
- self .config_name = config_name
70
- self .sagemaker_session = sagemaker_session
71
-
72
- def get_image_uri (
73
- self , image_scope : Optional [str ] = None , instance_type : Optional [str ] = None
74
- ) -> str :
75
- """Get image URI for a specific framework version."""
76
-
77
- self .image_scope = image_scope or self .image_scope
78
- self .instance_type = instance_type or self .instance_type
79
- return image_uris .retrieve (
80
- framework = self .framework_name ,
81
- image_scope = self .image_scope ,
82
- instance_type = self .instance_type ,
83
- py_version = self .py_version ,
84
- region = self .region ,
85
- version = self .version ,
86
- accelerator_type = self .accelerator_type ,
87
- container_version = self .container_version ,
88
- distribution = self .distribution ,
89
- base_framework_version = self .base_framework_version ,
90
- training_compiler_config = self .training_compiler_config ,
91
- model_id = self .model_id ,
92
- model_version = self .model_version ,
93
- hub_arn = self .hub_arn ,
94
- tolerate_vulnerable_model = self .tolerate_vulnerable_model ,
95
- tolerate_deprecated_model = self .tolerate_deprecated_model ,
96
- sdk_version = self .sdk_version ,
97
- inference_tool = self .inference_tool ,
98
- serverless_inference_config = self .serverless_inference_config ,
99
- config_name = self .config_name ,
100
- sagemaker_session = self .sagemaker_session ,
101
- )
92
+
93
+ def update_image_spec (self , ** kwargs ):
94
+ for key , value in kwargs .items ():
95
+ if hasattr (self , key ):
96
+ setattr (self , key , value )
97
+
98
+ def retrieve (self ) -> str :
99
+ """Retrieves the ECR URI for the Docker image matching the given arguments.
100
+
101
+ Ideally this function should not be called directly, rather it should be called from the
102
+ fit() function inside framework estimator.
103
+
104
+ Args:
105
+ framework (Framework): The name of the framework or algorithm.
106
+ processor (Processor): The name of the processor (CPU, GPU, etc.).
107
+ region (str): The AWS region.
108
+ version (str): The framework or algorithm version. This is required if there is
109
+ more than one supported version for the given framework or algorithm.
110
+ py_version (str): The Python version. This is required if there is
111
+ more than one supported Python version for the given framework version.
112
+ instance_type (str): The SageMaker instance type. For supported types, see
113
+ https://aws.amazon.com/sagemaker/pricing. This is required if
114
+ there are different images for different processor types.
115
+ accelerator_type (str): Elastic Inference accelerator type. For more, see
116
+ https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
117
+ image_scope (str): The image type, i.e. what it is used for.
118
+ Valid values: "training", "inference", "inference_graviton", "eia".
119
+ If ``accelerator_type`` is set, ``image_scope`` is ignored.
120
+ container_version (str): the version of docker image.
121
+ Ideally the value of parameter should be created inside the framework.
122
+ For custom use, see the list of supported container versions:
123
+ https://github.com/aws/deep-learning-containers/blob/master/available_images.md
124
+ (default: None).
125
+ distribution (dict): A dictionary with information on how to run distributed training
126
+ sdk_version (str): the version of python-sdk that will be used in the image retrieval.
127
+ (default: None).
128
+ inference_tool (str): the tool that will be used to aid in the inference.
129
+ Valid values: "neuron, neuronx, None"
130
+ (default: None).
131
+
132
+ Returns:
133
+ str: The ECR URI for the corresponding SageMaker Docker image.
134
+
135
+ Raises:
136
+ NotImplementedError: If the scope is not supported.
137
+ ValueError: If the combination of arguments specified is not supported or
138
+ any PipelineVariable object is passed in.
139
+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
140
+ known security vulnerabilities.
141
+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
142
+ """
143
+ config = _config_for_framework_and_scope (self .framework .value ,
144
+ self .image_scope .value ,
145
+ self .accelerator_type )
146
+
147
+ original_version = self .version
148
+ try :
149
+ version = _validate_version_and_set_if_needed (self .version , config , self .framework .value )
150
+ except ValueError :
151
+ version = None
152
+ if not version :
153
+ version = self ._fetch_latest_version_from_config (config )
154
+
155
+ version_config = config ["versions" ][_version_for_config (version , config )]
156
+
157
+ if "huggingface" in self .framework .value :
158
+ if version_config .get ("version_aliases" ):
159
+ full_base_framework_version = version_config ["version_aliases" ].get (
160
+ self .base_framework_version , self .base_framework_version
161
+ )
162
+ _validate_arg (full_base_framework_version , list (version_config .keys ()), "base framework" )
163
+ version_config = version_config .get (full_base_framework_version )
164
+
165
+ self .py_version = _validate_py_version_and_set_if_needed (self .py_version ,
166
+ version_config ,
167
+ self .framework .value )
168
+ version_config = version_config .get (self .py_version ) or version_config
169
+
170
+ registry = _registry_from_region (self .region , version_config ["registries" ])
171
+ endpoint_data = utils ._botocore_resolver ().construct_endpoint ("ecr" , self .region )
172
+ if self .region == "il-central-1" and not endpoint_data :
173
+ endpoint_data = {"hostname" : "ecr.{}.amazonaws.com" .format (self .region )}
174
+ hostname = endpoint_data ["hostname" ]
175
+
176
+ repo = version_config ["repository" ]
177
+
178
+ # if container version is available in .json file, utilize that
179
+ if version_config .get ("container_version" ):
180
+ self .container_version = version_config ["container_version" ][self .processor .value ]
181
+
182
+ # Append sdk version in case of trainium instances
183
+ if repo in ["pytorch-training-neuron" ]:
184
+ if not self .sdk_version :
185
+ sdk_version = _get_latest_versions (version_config ["sdk_versions" ])
186
+ self .container_version = self .sdk_version + "-" + self .container_version
187
+
188
+ if self .framework == Framework .HUGGING_FACE :
189
+ pt_or_tf_version = (
190
+ re .compile ("^(pytorch|tensorflow)(.*)$" ).match (self .base_framework_version ).group (2 )
191
+ )
192
+ _version = original_version
193
+
194
+ if repo in [
195
+ "huggingface-pytorch-trcomp-training" ,
196
+ "huggingface-tensorflow-trcomp-training" ,
197
+ ]:
198
+ _version = version
199
+ if repo in [
200
+ "huggingface-pytorch-inference-neuron" ,
201
+ "huggingface-pytorch-inference-neuronx" ,
202
+ ]:
203
+ if not sdk_version :
204
+ self .sdk_version = _get_latest_versions (version_config ["sdk_versions" ])
205
+ self .container_version = self .sdk_version + "-" + self .container_version
206
+ if config .get ("version_aliases" ).get (original_version ):
207
+ _version = config .get ("version_aliases" )[original_version ]
208
+ if (
209
+ config .get ("versions" , {})
210
+ .get (_version , {})
211
+ .get ("version_aliases" , {})
212
+ .get (self .base_framework_version , {})
213
+ ):
214
+ _base_framework_version = config .get ("versions" )[_version ]["version_aliases" ][
215
+ self .base_framework_version
216
+ ]
217
+ pt_or_tf_version = (
218
+ re .compile ("^(pytorch|tensorflow)(.*)$" ).match (_base_framework_version ).group (2 )
219
+ )
220
+
221
+ tag_prefix = f"{ pt_or_tf_version } -transformers{ _version } "
222
+ else :
223
+ tag_prefix = version_config .get ("tag_prefix" , version )
224
+
225
+ if repo == f"{ self .framework .value } -inference-graviton" :
226
+ self .container_version = f"{ self .container_version } -sagemaker"
227
+ _validate_instance_deprecation (self .framework ,
228
+ self .instance_type ,
229
+ version )
230
+
231
+ tag = _get_image_tag (
232
+ self .container_version ,
233
+ self .distribution ,
234
+ self .image_scope .value ,
235
+ self .framework ,
236
+ self .inference_tool ,
237
+ self .instance_type ,
238
+ self .processor .value ,
239
+ self .py_version ,
240
+ tag_prefix ,
241
+ version )
242
+
243
+ if tag :
244
+ repo += ":{}" .format (tag )
245
+
246
+ return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo )
247
+
248
+ def _fetch_latest_version_from_config (self ,
249
+ framework_config : dict ) -> str :
250
+ if self .image_scope .value in framework_config :
251
+ if image_scope_config := framework_config [self .image_scope .value ]:
252
+ if version_aliases := image_scope_config ["version_aliases" ]:
253
+ if latest_version := version_aliases ["latest" ]:
254
+ return latest_version
255
+ versions = list (framework_config ["versions" ].keys ())
256
+ top_version = versions [0 ]
257
+ bottom_version = versions [- 1 ]
258
+
259
+ if Version (top_version ) >= Version (bottom_version ):
260
+ return top_version
261
+ return bottom_version
0 commit comments