Skip to content

Commit e2e3cb2

Browse files
authored
change: use image_uris.retrieve for Neo and Inferentia images (#1734)
1 parent 3e3bee8 commit e2e3cb2

14 files changed

+382
-171
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"processors": ["inf"],
3+
"scope": ["inference"],
4+
"versions": {
5+
"1.5.1": {
6+
"py_versions": ["py3"],
7+
"registries": {
8+
"us-east-1": "785573368785",
9+
"us-west-2": "301217895009"
10+
},
11+
"repository": "sagemaker-neo-mxnet"
12+
}
13+
}
14+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"processors": ["inf"],
3+
"scope": ["inference"],
4+
"versions": {
5+
"1.15.0": {
6+
"py_versions": ["py3"],
7+
"registries": {
8+
"us-east-1": "785573368785",
9+
"us-west-2": "301217895009"
10+
},
11+
"repository": "sagemaker-neo-tensorflow"
12+
}
13+
}
14+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
{
2+
"processors": ["cpu", "gpu"],
3+
"scope": ["inference"],
4+
"version_aliases": {
5+
"0.12.1": "1.5",
6+
"1.0.0": "1.5",
7+
"1.1.0": "1.5",
8+
"1.2": "1.5",
9+
"1.2.0": "1.5",
10+
"1.2.1": "1.5",
11+
"1.3": "1.5",
12+
"1.3.0": "1.5",
13+
"1.4": "1.5",
14+
"1.4.0": "1.5",
15+
"1.4.1": "1.5"
16+
},
17+
"versions": {
18+
"1.5": {
19+
"py_versions": ["py3"],
20+
"registries": {
21+
"ap-east-1": "110948597952",
22+
"ap-northeast-1": "941853720454",
23+
"ap-northeast-2": "151534178276",
24+
"ap-south-1": "763008648453",
25+
"ap-southeast-1": "324986816169",
26+
"ap-southeast-2": "355873309152",
27+
"ca-central-1": "464438896020",
28+
"cn-north-1": "472730292857",
29+
"cn-northwest-1": "474822919863",
30+
"eu-central-1": "746233611703",
31+
"eu-north-1": "601324751636",
32+
"eu-west-1": "802834080501",
33+
"eu-west-2": "205493899709",
34+
"eu-west-3": "254080097072",
35+
"me-south-1": "836785723513",
36+
"sa-east-1": "756306329178",
37+
"us-east-1": "785573368785",
38+
"us-east-2": "007439368137",
39+
"us-gov-west-1": "263933020539",
40+
"us-west-1": "710691900526",
41+
"us-west-2": "301217895009"
42+
},
43+
"repository": "sagemaker-neo-mxnet"
44+
}
45+
}
46+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
{
2+
"processors": ["cpu", "gpu"],
3+
"scope": ["inference"],
4+
"version_aliases": {
5+
"0.4.0": "1.4.0",
6+
"1.0.0": "1.4.0",
7+
"1.1.0": "1.4.0",
8+
"1.2.0": "1.4.0",
9+
"1.3.0": "1.4.0"
10+
},
11+
"versions": {
12+
"1.4.0": {
13+
"py_versions": ["py3"],
14+
"registries": {
15+
"ap-east-1": "110948597952",
16+
"ap-northeast-1": "941853720454",
17+
"ap-northeast-2": "151534178276",
18+
"ap-south-1": "763008648453",
19+
"ap-southeast-1": "324986816169",
20+
"ap-southeast-2": "355873309152",
21+
"ca-central-1": "464438896020",
22+
"cn-north-1": "472730292857",
23+
"cn-northwest-1": "474822919863",
24+
"eu-central-1": "746233611703",
25+
"eu-north-1": "601324751636",
26+
"eu-west-1": "802834080501",
27+
"eu-west-2": "205493899709",
28+
"eu-west-3": "254080097072",
29+
"me-south-1": "836785723513",
30+
"sa-east-1": "756306329178",
31+
"us-east-1": "785573368785",
32+
"us-east-2": "007439368137",
33+
"us-gov-west-1": "263933020539",
34+
"us-west-1": "710691900526",
35+
"us-west-2": "301217895009"
36+
},
37+
"repository": "sagemaker-neo-pytorch"
38+
}
39+
}
40+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
{
2+
"processors": ["cpu", "gpu"],
3+
"scope": ["inference"],
4+
"version_aliases": {
5+
"1.4.1": "1.15.0",
6+
"1.5.0": "1.15.0",
7+
"1.6.0": "1.15.0",
8+
"1.7.0": "1.15.0",
9+
"1.8.0": "1.15.0",
10+
"1.9.0": "1.15.0",
11+
"1.10.0": "1.15.0",
12+
"1.11.0": "1.15.0",
13+
"1.12.0": "1.15.0",
14+
"1.13.0": "1.15.0",
15+
"1.14.0": "1.15.0"
16+
},
17+
"versions": {
18+
"1.15.0": {
19+
"py_versions": ["py3"],
20+
"registries": {
21+
"ap-east-1": "110948597952",
22+
"ap-northeast-1": "941853720454",
23+
"ap-northeast-2": "151534178276",
24+
"ap-south-1": "763008648453",
25+
"ap-southeast-1": "324986816169",
26+
"ap-southeast-2": "355873309152",
27+
"ca-central-1": "464438896020",
28+
"cn-north-1": "472730292857",
29+
"cn-northwest-1": "474822919863",
30+
"eu-central-1": "746233611703",
31+
"eu-north-1": "601324751636",
32+
"eu-west-1": "802834080501",
33+
"eu-west-2": "205493899709",
34+
"eu-west-3": "254080097072",
35+
"me-south-1": "836785723513",
36+
"sa-east-1": "756306329178",
37+
"us-east-1": "785573368785",
38+
"us-east-2": "007439368137",
39+
"us-gov-west-1": "263933020539",
40+
"us-west-1": "710691900526",
41+
"us-west-2": "301217895009"
42+
},
43+
"repository": "sagemaker-neo-tensorflow"
44+
}
45+
}
46+
}

src/sagemaker/image_uris.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import json
1717
import logging
1818
import os
19+
import re
1920

2021
from sagemaker import utils
2122

@@ -166,14 +167,28 @@ def _processor(instance_type, available_processors):
166167

167168
if instance_type.startswith("local"):
168169
processor = "cpu" if instance_type == "local" else "gpu"
169-
elif not instance_type.startswith("ml."):
170-
raise ValueError(
171-
"Invalid SageMaker instance type: {}. For options, see: "
172-
"https://aws.amazon.com/sagemaker/pricing/instance-types".format(instance_type)
173-
)
174170
else:
175-
family = instance_type.split(".")[1]
176-
processor = "gpu" if family[0] in ("g", "p") else "cpu"
171+
# looks for either "ml.<family>.<size>" or "ml_<family>"
172+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
173+
if match:
174+
family = match[1]
175+
176+
# For some frameworks, we have optimized images for specific families, e.g. c5 or p3.
177+
# In those cases, we use the family name in the image tag. In other cases, we use
178+
# 'cpu' or 'gpu'.
179+
if family in available_processors:
180+
processor = family
181+
elif family.startswith("inf"):
182+
processor = "inf"
183+
elif family[0] in ("g", "p"):
184+
processor = "gpu"
185+
else:
186+
processor = "cpu"
187+
else:
188+
raise ValueError(
189+
"Invalid SageMaker instance type: {}. For options, see: "
190+
"https://aws.amazon.com/sagemaker/pricing/instance-types".format(instance_type)
191+
)
177192

178193
_validate_arg(processor, available_processors, "processor")
179194
return processor

src/sagemaker/model.py

Lines changed: 20 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import os
1919

2020
import sagemaker
21-
from sagemaker import fw_utils, local, session, utils, git_utils
21+
from sagemaker import fw_utils, image_uris, local, session, utils, git_utils
2222
from sagemaker.fw_utils import UploadedCode
2323
from sagemaker.transformer import Transformer
2424

@@ -28,32 +28,6 @@
2828
["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost", "tflite"]
2929
)
3030

31-
NEO_IMAGE_ACCOUNT = {
32-
"us-west-1": "710691900526",
33-
"us-west-2": "301217895009",
34-
"us-east-1": "785573368785",
35-
"us-east-2": "007439368137",
36-
"eu-west-1": "802834080501",
37-
"eu-west-2": "205493899709",
38-
"eu-west-3": "254080097072",
39-
"eu-central-1": "746233611703",
40-
"eu-north-1": "601324751636",
41-
"ap-northeast-1": "941853720454",
42-
"ap-northeast-2": "151534178276",
43-
"ap-east-1": "110948597952",
44-
"ap-southeast-1": "324986816169",
45-
"ap-southeast-2": "355873309152",
46-
"ap-south-1": "763008648453",
47-
"sa-east-1": "756306329178",
48-
"ca-central-1": "464438896020",
49-
"me-south-1": "836785723513",
50-
"cn-north-1": "472730292857",
51-
"cn-northwest-1": "474822919863",
52-
"us-gov-west-1": "263933020539",
53-
}
54-
55-
INFERENTIA_INSTANCE_PREFIX = "ml_inf"
56-
5731

5832
class Model(object):
5933
"""A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
@@ -243,7 +217,7 @@ def _compilation_job_config(
243217
"DataInputConfig": input_shape
244218
if not isinstance(input_shape, dict)
245219
else json.dumps(input_shape),
246-
"Framework": framework,
220+
"Framework": framework.upper(),
247221
}
248222
role = self.sagemaker_session.expand_role(role)
249223
output_model_config = {
@@ -260,64 +234,23 @@ def _compilation_job_config(
260234
"job_name": job_name,
261235
}
262236

263-
def check_neo_region(self, region):
264-
"""Check whether this ``Model`` is in a region where Neo is supported.
265-
266-
Args:
267-
region (str): Specifies the region where you want to execute compilation
268-
269-
Returns:
270-
bool: boolean value indicating whether neo is available in the specified
271-
region
272-
"""
273-
if region in NEO_IMAGE_ACCOUNT:
274-
return True
275-
return False
237+
def _compilation_image_uri(self, region, target_instance_type, framework, framework_version):
238+
"""Retrieve the Neo or Inferentia image URI.
276239
277-
def _neo_image_account(self, region):
278-
"""
279-
Args:
280-
region:
281-
"""
282-
if region not in NEO_IMAGE_ACCOUNT:
283-
raise ValueError(
284-
"Neo is not currently supported in {}, "
285-
"valid regions: {}".format(region, NEO_IMAGE_ACCOUNT.keys())
286-
)
287-
return NEO_IMAGE_ACCOUNT[region]
288-
289-
def _neo_image_uri(self, region, target_instance_type, framework, framework_version):
290-
"""
291240
Args:
292-
region:
293-
target_instance_type:
294-
framework:
295-
framework_version:
296-
"""
297-
return fw_utils.create_image_uri(
298-
region,
299-
"neo-" + framework.lower(),
300-
target_instance_type.replace("_", "."),
301-
framework_version,
302-
py_version="py3",
303-
account=self._neo_image_account(region),
304-
)
305-
306-
def _inferentia_image_uri(self, region, target_instance_type, framework, framework_version):
241+
region (str): The AWS region.
242+
target_instance_type (str): Identifies the device on which you want to run
243+
your model after compilation, for example: ml_c5. For valid values, see
244+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
245+
framework (str): The framework name.
246+
framework_version (str): The framework version.
307247
"""
308-
Args:
309-
region:
310-
target_instance_type:
311-
framework:
312-
framework_version:
313-
"""
314-
return fw_utils.create_image_uri(
248+
framework_prefix = "inferentia-" if target_instance_type.startswith("ml_inf") else "neo-"
249+
return image_uris.retrieve(
250+
"{}{}".format(framework_prefix, framework),
315251
region,
316-
"neo-" + framework.lower(),
317-
target_instance_type.replace("_", "."),
318-
framework_version,
319-
py_version="py3",
320-
account=self._neo_image_account(region),
252+
instance_type=target_instance_type,
253+
version=framework_version,
321254
)
322255

323256
def compile(
@@ -361,7 +294,7 @@ def compile(
361294
sagemaker.model.Model: A SageMaker ``Model`` object. See
362295
:func:`~sagemaker.model.Model` for full details.
363296
"""
364-
framework = self._framework() or framework
297+
framework = framework or self._framework()
365298
if framework is None:
366299
raise ValueError(
367300
"You must specify framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS)
@@ -375,8 +308,7 @@ def compile(
375308
if self.model_data is None:
376309
raise ValueError("You must provide an S3 path to the compressed model artifacts.")
377310

378-
framework = framework.upper()
379-
framework_version = self._get_framework_version() or framework_version
311+
framework_version = framework_version or self._get_framework_version()
380312

381313
self._init_sagemaker_session_if_does_not_exist(target_instance_family)
382314
config = self._compilation_job_config(
@@ -392,16 +324,9 @@ def compile(
392324
self.sagemaker_session.compile_model(**config)
393325
job_status = self.sagemaker_session.wait_for_compilation_job(job_name)
394326
self.model_data = job_status["ModelArtifacts"]["S3ModelArtifacts"]
327+
395328
if target_instance_family.startswith("ml_"):
396-
self.image_uri = self._neo_image_uri(
397-
self.sagemaker_session.boto_region_name,
398-
target_instance_family,
399-
framework,
400-
framework_version,
401-
)
402-
self._is_compiled_model = True
403-
elif target_instance_family.startswith(INFERENTIA_INSTANCE_PREFIX):
404-
self.image_uri = self._inferentia_image_uri(
329+
self.image_uri = self._compilation_image_uri(
405330
self.sagemaker_session.boto_region_name,
406331
target_instance_family,
407332
framework,
@@ -414,6 +339,7 @@ def compile(
414339
"please deploy the model manually.",
415340
target_instance_family,
416341
)
342+
417343
return self
418344

419345
def deploy(

0 commit comments

Comments
 (0)