Skip to content

Commit fba3285

Browse files
nargokulpintaoz-aws
authored andcommitted
Image Spec refactoring and updates (#1525)
* Image Spec refactoring and updates * Unit tests and update function for Image Spec * Fix hugging face test * Fix Tests
1 parent 9774f9e commit fba3285

File tree

3 files changed

+270
-68
lines changed

3 files changed

+270
-68
lines changed

src/sagemaker/modules/image_spec.py

Lines changed: 228 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -13,41 +13,71 @@
1313
"""ImageSpec class module."""
1414
from __future__ import absolute_import
1515

16+
import re
17+
from enum import Enum
1618
from typing import Optional
1719

18-
from sagemaker import image_uris, Session
19-
from sagemaker.serverless import ServerlessInferenceConfig
20-
from sagemaker.training_compiler.config import TrainingCompilerConfig
20+
from sagemaker import utils
21+
from sagemaker.image_uris import _validate_version_and_set_if_needed, _version_for_config, \
22+
_config_for_framework_and_scope, _validate_py_version_and_set_if_needed, _registry_from_region, ECR_URI_TEMPLATE, \
23+
_get_latest_versions, _validate_instance_deprecation, _get_image_tag, _validate_arg
24+
from packaging.version import Version
25+
26+
DEFAULT_TOLERATE_MODEL = False
27+
28+
29+
class Framework(Enum):
30+
HUGGING_FACE = "huggingface"
31+
HUGGING_FACE_NEURON = "huggingface-neuron"
32+
HUGGING_FACE_NEURON_X = "huggingface-neuronx"
33+
HUGGING_FACE_LLM = "huggingface-llm"
34+
HUGGING_FACE_TEI_GPU = "huggingface-tei"
35+
HUGGING_FACE_TEI_CPU = "huggingface-tei-cpu"
36+
HUGGING_FACE_LLM_NEURONX = "huggingface-llm-neuronx"
37+
HUGGING_FACE_TRAINING_COMPILER = "huggingface-training-compiler"
38+
XGBOOST = "xgboost"
39+
XG_BOOST_NEO = "xg-boost-neo"
40+
SKLEARN = "sklearn"
41+
PYTORCH = "pytorch"
42+
PYTORCH_TRAINING_COMPILER = "pytorch-training-compiler"
43+
DATA_WRANGLER = "data-wrangler"
44+
STABILITYAI = "stabilityai"
45+
SAGEMAKER_TRITONSERVER = "sagemaker-tritonserver"
46+
47+
48+
class ImageScope(Enum):
49+
TRAINING = "training"
50+
INFERENCE = "inference"
51+
INFERENCE_GRAVITON = "inference-graviton"
52+
53+
54+
class Processor(Enum):
55+
INF = "inf"
56+
NEURON = "neuron"
57+
GPU = "gpu"
58+
CPU = "cpu"
59+
TRN = "trn"
2160

2261

2362
class ImageSpec:
2463
"""ImageSpec class to get image URI for a specific framework version."""
2564

26-
def __init__(
27-
self,
28-
framework_name: str,
29-
version: str,
30-
image_scope: Optional[str] = None,
31-
instance_type: Optional[str] = None,
32-
py_version: Optional[str] = None,
33-
region: Optional[str] = "us-west-2",
34-
accelerator_type: Optional[str] = None,
35-
container_version: Optional[str] = None,
36-
distribution: Optional[dict] = None,
37-
base_framework_version: Optional[str] = None,
38-
training_compiler_config: Optional[TrainingCompilerConfig] = None,
39-
model_id: Optional[str] = None,
40-
model_version: Optional[str] = None,
41-
hub_arn: Optional[str] = None,
42-
tolerate_vulnerable_model: Optional[bool] = False,
43-
tolerate_deprecated_model: Optional[bool] = False,
44-
sdk_version: Optional[str] = None,
45-
inference_tool: Optional[str] = None,
46-
serverless_inference_config: Optional[ServerlessInferenceConfig] = None,
47-
config_name: Optional[str] = None,
48-
sagemaker_session: Optional[Session] = None,
49-
):
50-
self.framework_name = framework_name
65+
def __init__(self,
66+
framework: Framework,
67+
processor: Optional[Processor] = Processor.CPU,
68+
region: Optional[str] = "us-west-2",
69+
version=None,
70+
py_version=None,
71+
instance_type=None,
72+
accelerator_type=None,
73+
image_scope: ImageScope = ImageScope.TRAINING,
74+
container_version=None,
75+
distribution=None,
76+
base_framework_version=None,
77+
sdk_version=None,
78+
inference_tool=None):
79+
self.framework = framework
80+
self.processor = processor
5181
self.version = version
5282
self.image_scope = image_scope
5383
self.instance_type = instance_type
@@ -57,45 +87,175 @@ def __init__(
5787
self.container_version = container_version
5888
self.distribution = distribution
5989
self.base_framework_version = base_framework_version
60-
self.training_compiler_config = training_compiler_config
61-
self.model_id = model_id
62-
self.model_version = model_version
63-
self.hub_arn = hub_arn
64-
self.tolerate_vulnerable_model = tolerate_vulnerable_model
65-
self.tolerate_deprecated_model = tolerate_deprecated_model
6690
self.sdk_version = sdk_version
6791
self.inference_tool = inference_tool
68-
self.serverless_inference_config = serverless_inference_config
69-
self.config_name = config_name
70-
self.sagemaker_session = sagemaker_session
71-
72-
def get_image_uri(
73-
self, image_scope: Optional[str] = None, instance_type: Optional[str] = None
74-
) -> str:
75-
"""Get image URI for a specific framework version."""
76-
77-
self.image_scope = image_scope or self.image_scope
78-
self.instance_type = instance_type or self.instance_type
79-
return image_uris.retrieve(
80-
framework=self.framework_name,
81-
image_scope=self.image_scope,
82-
instance_type=self.instance_type,
83-
py_version=self.py_version,
84-
region=self.region,
85-
version=self.version,
86-
accelerator_type=self.accelerator_type,
87-
container_version=self.container_version,
88-
distribution=self.distribution,
89-
base_framework_version=self.base_framework_version,
90-
training_compiler_config=self.training_compiler_config,
91-
model_id=self.model_id,
92-
model_version=self.model_version,
93-
hub_arn=self.hub_arn,
94-
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
95-
tolerate_deprecated_model=self.tolerate_deprecated_model,
96-
sdk_version=self.sdk_version,
97-
inference_tool=self.inference_tool,
98-
serverless_inference_config=self.serverless_inference_config,
99-
config_name=self.config_name,
100-
sagemaker_session=self.sagemaker_session,
101-
)
92+
93+
def update_image_spec(self, **kwargs):
94+
for key, value in kwargs.items():
95+
if hasattr(self, key):
96+
setattr(self, key, value)
97+
98+
def retrieve(self) -> str:
99+
"""Retrieves the ECR URI for the Docker image matching the given arguments.
100+
101+
Ideally this function should not be called directly, rather it should be called from the
102+
fit() function inside framework estimator.
103+
104+
Args:
105+
framework (Framework): The name of the framework or algorithm.
106+
processor (Processor): The name of the processor (CPU, GPU, etc.).
107+
region (str): The AWS region.
108+
version (str): The framework or algorithm version. This is required if there is
109+
more than one supported version for the given framework or algorithm.
110+
py_version (str): The Python version. This is required if there is
111+
more than one supported Python version for the given framework version.
112+
instance_type (str): The SageMaker instance type. For supported types, see
113+
https://aws.amazon.com/sagemaker/pricing. This is required if
114+
there are different images for different processor types.
115+
accelerator_type (str): Elastic Inference accelerator type. For more, see
116+
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
117+
image_scope (str): The image type, i.e. what it is used for.
118+
Valid values: "training", "inference", "inference_graviton", "eia".
119+
If ``accelerator_type`` is set, ``image_scope`` is ignored.
120+
container_version (str): the version of docker image.
121+
Ideally the value of parameter should be created inside the framework.
122+
For custom use, see the list of supported container versions:
123+
https://github.com/aws/deep-learning-containers/blob/master/available_images.md
124+
(default: None).
125+
distribution (dict): A dictionary with information on how to run distributed training
126+
sdk_version (str): the version of python-sdk that will be used in the image retrieval.
127+
(default: None).
128+
inference_tool (str): the tool that will be used to aid in the inference.
129+
Valid values: "neuron, neuronx, None"
130+
(default: None).
131+
132+
Returns:
133+
str: The ECR URI for the corresponding SageMaker Docker image.
134+
135+
Raises:
136+
NotImplementedError: If the scope is not supported.
137+
ValueError: If the combination of arguments specified is not supported or
138+
any PipelineVariable object is passed in.
139+
VulnerableJumpStartModelError: If any of the dependencies required by the script have
140+
known security vulnerabilities.
141+
DeprecatedJumpStartModelError: If the version of the model is deprecated.
142+
"""
143+
config = _config_for_framework_and_scope(self.framework.value,
144+
self.image_scope.value,
145+
self.accelerator_type)
146+
147+
original_version = self.version
148+
try:
149+
version = _validate_version_and_set_if_needed(self.version, config, self.framework.value)
150+
except ValueError:
151+
version = None
152+
if not version:
153+
version = self._fetch_latest_version_from_config(config)
154+
155+
version_config = config["versions"][_version_for_config(version, config)]
156+
157+
if "huggingface" in self.framework.value:
158+
if version_config.get("version_aliases"):
159+
full_base_framework_version = version_config["version_aliases"].get(
160+
self.base_framework_version, self.base_framework_version
161+
)
162+
_validate_arg(full_base_framework_version, list(version_config.keys()), "base framework")
163+
version_config = version_config.get(full_base_framework_version)
164+
165+
self.py_version = _validate_py_version_and_set_if_needed(self.py_version,
166+
version_config,
167+
self.framework.value)
168+
version_config = version_config.get(self.py_version) or version_config
169+
170+
registry = _registry_from_region(self.region, version_config["registries"])
171+
endpoint_data = utils._botocore_resolver().construct_endpoint("ecr", self.region)
172+
if self.region == "il-central-1" and not endpoint_data:
173+
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(self.region)}
174+
hostname = endpoint_data["hostname"]
175+
176+
repo = version_config["repository"]
177+
178+
# if container version is available in .json file, utilize that
179+
if version_config.get("container_version"):
180+
self.container_version = version_config["container_version"][self.processor.value]
181+
182+
# Append sdk version in case of trainium instances
183+
if repo in ["pytorch-training-neuron"]:
184+
if not self.sdk_version:
185+
sdk_version = _get_latest_versions(version_config["sdk_versions"])
186+
self.container_version = self.sdk_version + "-" + self.container_version
187+
188+
if self.framework == Framework.HUGGING_FACE:
189+
pt_or_tf_version = (
190+
re.compile("^(pytorch|tensorflow)(.*)$").match(self.base_framework_version).group(2)
191+
)
192+
_version = original_version
193+
194+
if repo in [
195+
"huggingface-pytorch-trcomp-training",
196+
"huggingface-tensorflow-trcomp-training",
197+
]:
198+
_version = version
199+
if repo in [
200+
"huggingface-pytorch-inference-neuron",
201+
"huggingface-pytorch-inference-neuronx",
202+
]:
203+
if not sdk_version:
204+
self.sdk_version = _get_latest_versions(version_config["sdk_versions"])
205+
self.container_version = self.sdk_version + "-" + self.container_version
206+
if config.get("version_aliases").get(original_version):
207+
_version = config.get("version_aliases")[original_version]
208+
if (
209+
config.get("versions", {})
210+
.get(_version, {})
211+
.get("version_aliases", {})
212+
.get(self.base_framework_version, {})
213+
):
214+
_base_framework_version = config.get("versions")[_version]["version_aliases"][
215+
self.base_framework_version
216+
]
217+
pt_or_tf_version = (
218+
re.compile("^(pytorch|tensorflow)(.*)$").match(_base_framework_version).group(2)
219+
)
220+
221+
tag_prefix = f"{pt_or_tf_version}-transformers{_version}"
222+
else:
223+
tag_prefix = version_config.get("tag_prefix", version)
224+
225+
if repo == f"{self.framework.value}-inference-graviton":
226+
self.container_version = f"{self.container_version}-sagemaker"
227+
_validate_instance_deprecation(self.framework,
228+
self.instance_type,
229+
version)
230+
231+
tag = _get_image_tag(
232+
self.container_version,
233+
self.distribution,
234+
self.image_scope.value,
235+
self.framework,
236+
self.inference_tool,
237+
self.instance_type,
238+
self.processor.value,
239+
self.py_version,
240+
tag_prefix,
241+
version)
242+
243+
if tag:
244+
repo += ":{}".format(tag)
245+
246+
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)
247+
248+
def _fetch_latest_version_from_config(self,
249+
framework_config: dict) -> str:
250+
if self.image_scope.value in framework_config:
251+
if image_scope_config := framework_config[self.image_scope.value]:
252+
if version_aliases := image_scope_config["version_aliases"]:
253+
if latest_version := version_aliases["latest"]:
254+
return latest_version
255+
versions = list(framework_config["versions"].keys())
256+
top_version = versions[0]
257+
bottom_version = versions[-1]
258+
259+
if Version(top_version) >= Version(bottom_version):
260+
return top_version
261+
return bottom_version

tests/unit/sagemaker/modules/__init__.py

Whitespace-only changes.
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import unittest
2+
3+
from sagemaker.modules.image_spec import ImageSpec, Framework, Processor
4+
5+
6+
class TestImageSpec(unittest.TestCase):
7+
8+
def test_image_spec_update(self):
9+
image_spec = ImageSpec(framework=Framework.HUGGING_FACE)
10+
assert image_spec.version == None
11+
image_spec.update_image_spec(version="v3")
12+
assert image_spec.version == "v3"
13+
14+
def test_image_spec_retrive(self):
15+
# Asserting substrings because full string uri can change with newer versions
16+
image_spec = ImageSpec(framework=Framework.XGBOOST)
17+
xgboost_uri = image_spec.retrieve()
18+
assert "dkr.ecr.us-west-2.amazonaws.com" in xgboost_uri
19+
assert "sagemaker-xgboost" in xgboost_uri
20+
21+
image_spec = ImageSpec(framework=Framework.HUGGING_FACE,
22+
processor=Processor.GPU,
23+
base_framework_version="pytorch2.1.0")
24+
hugging_face_uri = image_spec.retrieve()
25+
assert "dkr.ecr.us-west-2.amazonaws.com" in hugging_face_uri
26+
assert "huggingface-pytorch-training" in hugging_face_uri
27+
28+
image_spec = ImageSpec(framework=Framework.PYTORCH)
29+
pytorch_uri = image_spec.retrieve()
30+
assert "dkr.ecr.us-west-2.amazonaws.com" in pytorch_uri
31+
assert "pytorch-training" in pytorch_uri
32+
33+
image_spec = ImageSpec(framework=Framework.SKLEARN)
34+
sklearn_uri = image_spec.retrieve()
35+
assert "dkr.ecr.us-west-2.amazonaws.com" in sklearn_uri
36+
assert "sagemaker-scikit-learn" in sklearn_uri
37+
38+
def test_image_spec_retrive_with_version(self):
39+
image_spec = ImageSpec(framework=Framework.XGBOOST,
40+
version="0.90-1")
41+
xgboost_uri = image_spec.retrieve()
42+
assert xgboost_uri == '246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-1-cpu-py3'

0 commit comments

Comments
 (0)