Skip to content

Commit d4612f8

Browse files
authored
feature: start new module for retrieving prebuilt SageMaker image URIs (#1701)
This commit also introduces configuration for Chainer URIs.
1 parent fa53af0 commit d4612f8

File tree

8 files changed

+525
-59
lines changed

8 files changed

+525
-59
lines changed

MANIFEST.in

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
recursive-include src/sagemaker *
1+
recursive-include src/sagemaker *.py
2+
3+
include src/sagemaker/image_uri_config/*.json
24

35
include VERSION
46
include LICENSE.txt

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def read_version():
8383
packages=find_packages("src"),
8484
package_dir={"": "src"},
8585
py_modules=[os.path.splitext(os.path.basename(path))[0] for path in glob("src/*.py")],
86+
include_package_data=True,
8687
long_description=read("README.rst"),
8788
author="Amazon Web Services",
8889
url="https://github.com/aws/sagemaker-python-sdk/",
src/sagemaker/image_uri_config/chainer.json

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
{
2+
"processors": ["cpu", "gpu"],
3+
"version_aliases": {
4+
"4.0": "4.0.0",
5+
"4.1": "4.1.0",
6+
"5.0": "5.0.0"
7+
},
8+
"versions": {
9+
"4.0.0": {
10+
"registries": {
11+
"ap-east-1": "057415533634",
12+
"ap-northeast-1": "520713654638",
13+
"ap-northeast-2": "520713654638",
14+
"ap-south-1": "520713654638",
15+
"ap-southeast-1": "520713654638",
16+
"ap-southeast-2": "520713654638",
17+
"ca-central-1": "520713654638",
18+
"cn-north-1": "422961961927",
19+
"cn-northwest-1": "423003514399",
20+
"eu-central-1": "520713654638",
21+
"eu-north-1": "520713654638",
22+
"eu-west-1": "520713654638",
23+
"eu-west-2": "520713654638",
24+
"eu-west-3": "520713654638",
25+
"me-south-1": "724002660598",
26+
"sa-east-1": "520713654638",
27+
"us-east-1": "520713654638",
28+
"us-east-2": "520713654638",
29+
"us-gov-west-1": "246785580436",
30+
"us-iso-east-1": "744548109606",
31+
"us-west-1": "520713654638",
32+
"us-west-2": "520713654638"
33+
},
34+
"repository": "sagemaker-chainer",
35+
"py_versions": ["py2", "py3"]
36+
},
37+
"4.1.0": {
38+
"registries": {
39+
"ap-east-1": "057415533634",
40+
"ap-northeast-1": "520713654638",
41+
"ap-northeast-2": "520713654638",
42+
"ap-south-1": "520713654638",
43+
"ap-southeast-1": "520713654638",
44+
"ap-southeast-2": "520713654638",
45+
"ca-central-1": "520713654638",
46+
"cn-north-1": "422961961927",
47+
"cn-northwest-1": "423003514399",
48+
"eu-central-1": "520713654638",
49+
"eu-north-1": "520713654638",
50+
"eu-west-1": "520713654638",
51+
"eu-west-2": "520713654638",
52+
"eu-west-3": "520713654638",
53+
"me-south-1": "724002660598",
54+
"sa-east-1": "520713654638",
55+
"us-east-1": "520713654638",
56+
"us-east-2": "520713654638",
57+
"us-gov-west-1": "246785580436",
58+
"us-iso-east-1": "744548109606",
59+
"us-west-1": "520713654638",
60+
"us-west-2": "520713654638"
61+
},
62+
"repository": "sagemaker-chainer",
63+
"py_versions": ["py2", "py3"]
64+
},
65+
"5.0.0": {
66+
"registries": {
67+
"ap-east-1": "057415533634",
68+
"ap-northeast-1": "520713654638",
69+
"ap-northeast-2": "520713654638",
70+
"ap-south-1": "520713654638",
71+
"ap-southeast-1": "520713654638",
72+
"ap-southeast-2": "520713654638",
73+
"ca-central-1": "520713654638",
74+
"cn-north-1": "422961961927",
75+
"cn-northwest-1": "423003514399",
76+
"eu-central-1": "520713654638",
77+
"eu-north-1": "520713654638",
78+
"eu-west-1": "520713654638",
79+
"eu-west-2": "520713654638",
80+
"eu-west-3": "520713654638",
81+
"me-south-1": "724002660598",
82+
"sa-east-1": "520713654638",
83+
"us-east-1": "520713654638",
84+
"us-east-2": "520713654638",
85+
"us-gov-west-1": "246785580436",
86+
"us-iso-east-1": "744548109606",
87+
"us-west-1": "520713654638",
88+
"us-west-2": "520713654638"
89+
},
90+
"repository": "sagemaker-chainer",
91+
"py_versions": ["py2", "py3"]
92+
}
93+
}
94+
}

src/sagemaker/image_uris.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Functions for generating ECR image URIs for pre-built SageMaker Docker images."""
14+
from __future__ import absolute_import
15+
16+
import json
17+
import os
18+
19+
from sagemaker import utils
20+
21+
# Layout of an ECR image URI. ``retrieve`` fills ``registry`` with the account
# number looked up for the region, ``hostname`` with the ECR endpoint hostname
# resolved through botocore, and ``repository``/``tag`` from the framework config.
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}:{tag}"
22+
23+
24+
def retrieve(framework, region, version=None, py_version=None, instance_type=None):
    """Builds the ECR URI of the prebuilt SageMaker Docker image for the given arguments.

    Args:
        framework (str): The name of the framework. It selects the JSON config
            that is loaded through ``config_for_framework``.
        region (str): The AWS region. It selects both the registry (account
            number) and the ECR endpoint hostname.
        version (str): The framework version, either a full version or one of
            the aliases listed in the framework config.
        py_version (str): The Python version. It must be one of the Python
            versions listed for the selected framework version.
        instance_type (str): The SageMaker instance type. For supported types, see
            https://aws.amazon.com/sagemaker/pricing/instance-types. It determines
            the processor type used in the image tag.

    Returns:
        str: the ECR URI for the corresponding SageMaker Docker image.

    Raises:
        ValueError: If the framework version, Python version, processor type, or
            region is not supported given the other arguments.
    """
    framework_config = config_for_framework(framework)
    config_key = _version_for_config(version, framework_config, framework)
    settings = framework_config["versions"][config_key]

    account = _registry_from_region(region, settings["registries"])
    endpoint = utils._botocore_resolver().construct_endpoint("ecr", region)
    ecr_host = endpoint["hostname"]

    repo_name = settings["repository"]

    _validate_py_version(py_version, settings["py_versions"], framework, version)
    processor = _processor(instance_type, framework_config["processors"])

    # The tag carries the version exactly as the caller passed it (alias or
    # full version), not the key that was used to look up the config.
    image_tag = "{}-{}-{}".format(version, processor, py_version)

    return ECR_URI_TEMPLATE.format(
        registry=account, hostname=ecr_host, repository=repo_name, tag=image_tag
    )
57+
58+
59+
def config_for_framework(framework):
    """Reads and parses the image URI JSON config of the given framework.

    The file is looked up as ``image_uri_config/<framework>.json`` in the
    directory that contains this module.

    Args:
        framework (str): The name of the framework.

    Returns:
        The parsed JSON content of the framework's config file.
    """
    module_dir = os.path.dirname(__file__)
    config_name = "{}.json".format(framework)
    config_path = os.path.join(module_dir, "image_uri_config", config_name)

    with open(config_path) as config_file:
        parsed = json.load(config_file)
    return parsed
64+
65+
66+
def _version_for_config(version, config, framework):
    """Resolves ``version`` to the key of its entry in ``config["versions"]``.

    Args:
        version (str): The framework version or a version alias.
        config (dict): The framework config; ``"version_aliases"`` is optional,
            ``"versions"`` is required.
        framework (str): The framework name, used only in the error message.

    Returns:
        str: The alias target if ``version`` is an alias, otherwise ``version``
            itself when it is a key of ``config["versions"]``.

    Raises:
        ValueError: If ``version`` is neither an alias nor a listed version.
    """
    alias_map = config["version_aliases"] if "version_aliases" in config else {}
    if version in alias_map:
        return alias_map[version]

    supported = config["versions"]
    if version in supported:
        return version

    message = (
        "Unsupported {} version: {}. "
        "You may need to upgrade your SDK version (pip install -U sagemaker) for newer versions. "
        "Supported version(s): {}."
    ).format(framework, version, ", ".join(supported))
    raise ValueError(message)
81+
82+
83+
def _registry_from_region(region, registry_dict):
    """Looks up the ECR registry (AWS account number) for the given region.

    Args:
        region (str): The AWS region.
        registry_dict (dict): Mapping of region name to registry account number.

    Returns:
        The value stored in ``registry_dict`` for ``region``.

    Raises:
        ValueError: If ``region`` is not a key of ``registry_dict``.
    """
    if region in registry_dict:
        return registry_dict[region]

    message = (
        "Unsupported region: {}. You may need to upgrade "
        "your SDK version (pip install -U sagemaker) for newer regions. "
        "Supported region(s): {}."
    ).format(region, ", ".join(registry_dict))
    raise ValueError(message)
94+
95+
96+
def _processor(instance_type, available_processors):
    """Returns the processor type for the given instance type.

    Args:
        instance_type (str): The SageMaker instance type. ``"local"`` maps to
            ``"cpu"`` and any other value starting with ``"local"`` maps to
            ``"gpu"``. Otherwise the value must look like ``"ml.<family>..."``;
            families starting with ``g`` or ``p`` map to ``"gpu"`` and all
            others to ``"cpu"``.
        available_processors (list[str]): The processor types that the
            framework config supports.

    Returns:
        str: ``"cpu"`` or ``"gpu"``.

    Raises:
        ValueError: If the instance type is empty or ``None``, is not a valid
            SageMaker instance type, or maps to a processor type that is not
            in ``available_processors``.
    """
    # ``retrieve`` defaults instance_type to None; fail with a clear ValueError
    # instead of an AttributeError from ``None.startswith``.
    if not instance_type:
        raise ValueError(
            "Empty SageMaker instance type. For options, see: "
            "https://aws.amazon.com/sagemaker/pricing/instance-types"
        )

    if instance_type.startswith("local"):
        processor = "cpu" if instance_type == "local" else "gpu"
    else:
        # An empty family covers both a missing "ml." prefix and a bare "ml.",
        # which previously raised IndexError on ``family[0]``.
        family = instance_type.split(".")[1] if instance_type.startswith("ml.") else ""
        if not family:
            raise ValueError(
                "Invalid SageMaker instance type: {}. See: "
                "https://aws.amazon.com/sagemaker/pricing/instance-types".format(instance_type)
            )
        processor = "gpu" if family[0] in ("g", "p") else "cpu"

    if processor in available_processors:
        return processor

    raise ValueError(
        "Unsupported processor type: {} (for {}). "
        "Supported type(s): {}.".format(processor, instance_type, ", ".join(available_processors))
    )
116+
117+
118+
def _validate_py_version(py_version, available_versions, framework, fw_version):
    """Raises an error unless the Python version is one of the supported versions.

    Args:
        py_version (str): The Python version to check.
        available_versions (list[str]): The supported Python versions.
        framework (str): The framework name, used only in the error message.
        fw_version (str): The framework version, used only in the error message.

    Raises:
        ValueError: If ``py_version`` is not in ``available_versions``.
    """
    if py_version in available_versions:
        return

    message = (
        "Unsupported Python version for {} {}: {}. You may need to upgrade "
        "your SDK version (pip install -U sagemaker) for newer versions. "
        "Supported Python version(s): {}."
    ).format(framework, fw_version, py_version, ", ".join(available_versions))
    raise ValueError(message)

tests/conftest.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from botocore.config import Config
2222
from packaging.version import Version
2323

24-
from sagemaker import Session, utils
24+
from sagemaker import Session, image_uris, utils
2525
from sagemaker.local import LocalSession
2626
from sagemaker.rl import RLEstimator
2727

@@ -110,11 +110,6 @@ def custom_bucket_name(boto_session):
110110
return "{}-{}-{}".format(CUSTOM_BUCKET_NAME_PREFIX, region, account)
111111

112112

113-
@pytest.fixture(scope="module", params=["4.0", "4.0.0", "4.1", "4.1.0", "5.0", "5.0.0"])
114-
def chainer_version(request):
115-
return request.param
116-
117-
118113
@pytest.fixture(scope="module", params=["py2", "py3"])
119114
def chainer_py_version(request):
120115
return request.param
@@ -405,3 +400,12 @@ def pytest_generate_tests(metafunc):
405400
):
406401
params.append("ml.p2.xlarge")
407402
metafunc.parametrize("instance_type", params, scope="session")
403+
404+
for fw in ("chainer",):
405+
fixture_name = "{}_version".format(fw)
406+
if fixture_name in metafunc.fixturenames:
407+
config = image_uris.config_for_framework(fw)
408+
versions = list(config["versions"].keys()) + list(
409+
config.get("version_aliases", {}).keys()
410+
)
411+
metafunc.parametrize(fixture_name, versions, scope="session")
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker import image_uris
16+
17+
ACCOUNT = "520713654638"
18+
DOMAIN = "amazonaws.com"
19+
# Positional fields, in the order _expected_uri supplies them: account, region,
# domain, repository, framework version, processor, Python version.
IMAGE_URI_FORMAT = "{}.dkr.ecr.{}.{}/{}:{}-{}-{}"
20+
REGION = "us-west-2"
21+
22+
ALTERNATE_REGION_DOMAIN_AND_ACCOUNTS = (
23+
("ap-east-1", DOMAIN, "057415533634"),
24+
("cn-north-1", "amazonaws.com.cn", "422961961927"),
25+
("cn-northwest-1", "amazonaws.com.cn", "423003514399"),
26+
("me-south-1", DOMAIN, "724002660598"),
27+
("us-gov-west-1", DOMAIN, "246785580436"),
28+
("us-iso-east-1", "c2s.ic.gov", "744548109606"),
29+
)
30+
31+
32+
def _expected_uri(
    repo, fw_version, processor, py_version, account=ACCOUNT, region=REGION, domain=DOMAIN
):
    """Formats the image URI a test expects for the given repository and tag parts.

    The account, region and domain default to the module-level ``ACCOUNT``,
    ``REGION`` and ``DOMAIN`` values.
    """
    registry_fields = (account, region, domain)
    image_fields = (repo, fw_version, processor, py_version)
    return IMAGE_URI_FORMAT.format(*(registry_fields + image_fields))
36+
37+
38+
def test_chainer(chainer_version, chainer_py_version):
    """Checks the Chainer image URIs for CPU and GPU instance types.

    For each instance type, the URI is checked for the default region first and
    then for every region listed in ``ALTERNATE_REGION_DOMAIN_AND_ACCOUNTS``.
    """
    repository = "sagemaker-chainer"
    # Default location first, followed by the regions that use a different
    # account and/or domain.
    locations = ((REGION, DOMAIN, ACCOUNT),) + tuple(ALTERNATE_REGION_DOMAIN_AND_ACCOUNTS)

    for instance_type, processor in (("ml.c4.xlarge", "cpu"), ("ml.p2.xlarge", "gpu")):
        for region, domain, account in locations:
            uri = image_uris.retrieve(
                framework="chainer",
                region=region,
                version=chainer_version,
                py_version=chainer_py_version,
                instance_type=instance_type,
            )
            expected = _expected_uri(
                repository,
                chainer_version,
                processor,
                chainer_py_version,
                account=account,
                region=region,
                domain=domain,
            )
            assert expected == uri

0 commit comments

Comments
 (0)