Skip to content

Commit 5d21a03

Browse files
authored
feat: retrieve jumpstart estimator and predictor without specifying model id (infer from tags) (#4304)
* feat: retrieve jumpstart estimator and predictor without specifying model id (infer from tags)
* fix: pylint
* chore: add support for ic-based endpoints
* chore: update docstrings
* chore: add integ tests
* chore: add support in conftest for ic endpoints
* fix: delete inference components
* chore: address tagging and ic determination comments
* chore: address PR comments
* fix: docstring
* chore: improve docs
* chore: address comments
* fix: list_and_paginate_inference_component_names_associated_with_endpoint in integ test cleanup
* fix: boto3 session region
* fix: boto_session
* fix: sagemaker session to delete IC
1 parent f8d90ce commit 5d21a03

File tree

17 files changed

+1225
-38
lines changed

17 files changed

+1225
-38
lines changed

src/sagemaker/jumpstart/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,9 @@
240240
"Unable to create default JumpStart SageMaker Session due to the following error: %s.",
241241
str(e),
242242
)
243+
244+
# Additional resource-tag keys that may carry a JumpStart model ID.
EXTRA_MODEL_ID_TAGS = [
    "sm-jumpstart-id",
    "sagemaker-studio:jumpstart-model-id",
]

# Additional resource-tag keys that may carry a JumpStart model version.
EXTRA_MODEL_VERSION_TAGS = [
    "sm-jumpstart-model-version",
    "sagemaker-studio:jumpstart-model-version",
]

src/sagemaker/jumpstart/estimator.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs
3535
from sagemaker.jumpstart.factory.model import get_default_predictor
36+
from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job
3637
from sagemaker.jumpstart.utils import (
3738
is_valid_model_id,
3839
resolve_model_sagemaker_config_field,
@@ -668,8 +669,8 @@ def fit(
668669
def attach(
669670
cls,
670671
training_job_name: str,
671-
model_id: str,
672-
model_version: str = "*",
672+
model_id: Optional[str] = None,
673+
model_version: Optional[str] = None,
673674
sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
674675
model_channel_name: str = "model",
675676
) -> "JumpStartEstimator":
@@ -711,8 +712,20 @@ def attach(
711712
Returns:
712713
Instance of the calling ``JumpStartEstimator`` Class with the attached
713714
training job.
715+
716+
Raises:
717+
ValueError: if the model ID or version cannot be inferred from the training job.
718+
714719
"""
715720

721+
if model_id is None:
722+
723+
model_id, model_version = get_model_id_version_from_training_job(
724+
training_job_name=training_job_name, sagemaker_session=sagemaker_session
725+
)
726+
727+
model_version = model_version or "*"
728+
716729
return cls._attach(
717730
training_job_name=training_job_name,
718731
sagemaker_session=sagemaker_session,
src/sagemaker/jumpstart/session_utils.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module stores SageMaker Session utilities for JumpStart models."""
14+
15+
from __future__ import absolute_import
16+
17+
from typing import Optional, Tuple
18+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
19+
20+
from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn
21+
from sagemaker.session import Session
22+
from sagemaker.utils import aws_partition
23+
24+
25+
def get_model_id_version_from_endpoint(
    endpoint_name: str,
    inference_component_name: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[str, str, Optional[str]]:
    """Infer the JumpStart model ID and version of an endpoint from resource tags.

    JumpStart adds tags automatically to endpoints, models, endpoint configs, and
    inference components launched in SageMaker Studio and programmatically with the
    SageMaker Python SDK; those tags are what this lookup relies on.

    Args:
        endpoint_name (str): Name of the endpoint to inspect.
        inference_component_name (Optional[str]): Name of a specific inference
            component. When given, the tags are read from that inference component
            and the endpoint type is not queried. (Default: None).
        sagemaker_session (Session): Session used for the SageMaker calls.
            (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        Tuple[str, str, Optional[str]]: The model ID, the model version, and the
        inference component name that was supplied or inferred. For a model-based
        endpoint the third element is the ``inference_component_name`` argument
        unchanged (``None`` unless the caller passed another falsy value).

    Raises:
        ValueError: If model ID and version cannot be inferred from the endpoint.
    """
    # Short aliases keep the dispatch below readable despite the long helper names.
    resolve_from_named_component = (
        _get_model_id_version_from_inference_component_endpoint_with_inference_component_name
    )
    resolve_from_sole_component = (
        _get_model_id_version_from_inference_component_endpoint_without_inference_component_name
    )

    # Case 1: the caller named the inference component explicitly.
    if inference_component_name:
        model_id, model_version = resolve_from_named_component(
            inference_component_name, sagemaker_session
        )
        return model_id, model_version, inference_component_name

    # Case 2: inference-component-based endpoint, component name must be discovered.
    if sagemaker_session.is_inference_component_based_endpoint(endpoint_name):
        model_id, model_version, inference_component_name = resolve_from_sole_component(
            endpoint_name, sagemaker_session
        )
        return model_id, model_version, inference_component_name

    # Case 3: model-based endpoint; the tags live on the endpoint itself.
    model_id, model_version = _get_model_id_version_from_model_based_endpoint(
        endpoint_name, inference_component_name, sagemaker_session
    )
    return model_id, model_version, inference_component_name
67+
68+
69+
def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name(
    endpoint_name: str, sagemaker_session: Session
) -> Tuple[str, str, str]:
    """Derive model ID, version, and inference component name from an endpoint name.

    This function assumes the endpoint corresponds to an inference-component-based
    endpoint. An endpoint is inference-component-based if and only if the associated
    endpoint config has a role associated with it and no production variants with a
    ``ModelName`` field.

    Args:
        endpoint_name (str): Name of the inference-component-based endpoint.
        sagemaker_session (Session): Session used to list the inference components
            associated with the endpoint.

    Returns:
        Tuple[str, str, str]: The model ID and version inferred from the single
        inference component, followed by that inference component's name.

    Raises:
        ValueError: If there is not a single inference component associated with the endpoint.
    """
    list_component_names = (
        sagemaker_session.list_and_paginate_inference_component_names_associated_with_endpoint
    )
    component_names = list_component_names(endpoint_name=endpoint_name)
    component_count = len(component_names)

    if component_count == 0:
        raise ValueError(
            f"No inference component found for the following endpoint: {endpoint_name}. "
            "Use ``SageMaker.CreateInferenceComponent`` to add inference components to "
            "your endpoint."
        )

    # Without a name from the caller, the choice is only unambiguous for one component.
    if component_count > 1:
        raise ValueError(
            f"Multiple inference components found for the following endpoint: {endpoint_name}. "
            "Provide an 'inference_component_name' to retrieve the model ID and version "
            "associated with a particular inference component."
        )

    sole_component_name = component_names[0]
    inferred_id_and_version = (
        _get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
            sole_component_name, sagemaker_session
        )
    )
    return (*inferred_id_and_version, sole_component_name)
106+
107+
108+
def _get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
    inference_component_name: str, sagemaker_session: Session
):
    """Infer the model ID and version from the tags of a SageMaker inference component.

    The inference component ARN is assembled from the session's region, the partition
    for that region, and the session's account ID; the tags of that ARN are then
    searched for the JumpStart model ID and version.

    Args:
        inference_component_name (str): Name of the inference component to inspect.
        sagemaker_session (Session): Session providing region, account ID, and tag access.

    Returns:
        The ``(model_id, model_version)`` pair read from the inference component tags.

    Raises:
        ValueError: If the inference component does not have tags from which the model ID
            and version can be inferred.
    """
    region: str = sagemaker_session.boto_region_name
    partition: str = aws_partition(region)
    account_id: str = sagemaker_session.account_id()

    arn_prefix = f"arn:{partition}:sagemaker:{region}:{account_id}:"
    component_arn = f"{arn_prefix}inference-component/{inference_component_name}"

    model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
        component_arn, sagemaker_session
    )

    if model_id:
        return model_id, model_version

    raise ValueError(
        "Cannot infer JumpStart model ID from inference component "
        f"'{inference_component_name}'. Please specify JumpStart `model_id` "
        "when retrieving default predictor for this inference component."
    )
138+
139+
140+
def _get_model_id_version_from_model_based_endpoint(
    endpoint_name: str,
    inference_component_name: Optional[str],
    sagemaker_session: Session,
) -> Tuple[str, str]:
    """Infer the model ID and version from the tags of a model-based endpoint.

    Args:
        endpoint_name (str): Name of the model-based endpoint. It is lower-cased
            before the endpoint ARN is built.
        inference_component_name (Optional[str]): Must be empty or ``None``; inference
            components do not apply to model-based endpoints.
        sagemaker_session (Session): Session providing region, account ID, and tag access.

    Returns:
        Tuple[str, str]: The ``(model_id, model_version)`` pair read from the endpoint tags.

    Raises:
        ValueError: If an inference component name is supplied, or if the endpoint does
            not have tags from which the model ID and version can be inferred.
    """
    if inference_component_name:
        raise ValueError("Cannot specify inference component name for model-based endpoints.")

    region: str = sagemaker_session.boto_region_name
    partition: str = aws_partition(region)
    account_id: str = sagemaker_session.account_id()

    # SageMaker Tagging requires endpoint names to be lower cased
    endpoint_name = endpoint_name.lower()

    arn_prefix = f"arn:{partition}:sagemaker:{region}:{account_id}:"
    endpoint_arn = f"{arn_prefix}endpoint/{endpoint_name}"

    model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
        endpoint_arn, sagemaker_session
    )

    if model_id:
        return model_id, model_version

    # The message reports the lower-cased endpoint name that was actually looked up.
    raise ValueError(
        f"Cannot infer JumpStart model ID from endpoint '{endpoint_name}'. "
        "Please specify JumpStart `model_id` when retrieving default "
        "predictor for this endpoint."
    )
176+
177+
178+
def get_model_id_version_from_training_job(
    training_job_name: str,
    sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[str, Optional[str]]:
    """Infer the JumpStart model ID and version from the tags of a training job.

    JumpStart adds tags automatically to training jobs launched in SageMaker Studio
    and programmatically with the SageMaker Python SDK; those tags are what this
    lookup relies on.

    Args:
        training_job_name (str): Name of the training job to inspect.
        sagemaker_session (Optional[Session]): Session providing region, account ID,
            and tag access. ``None`` selects ``DEFAULT_JUMPSTART_SAGEMAKER_SESSION``.
            (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        Tuple[str, Optional[str]]: The model ID, and the model version or ``None``
        when no version could be read from the tags.

    Raises:
        ValueError: If the training job does not have tags from which the model ID
            can be inferred.
    """
    # The signature admits ``None``; without this fallback the attribute accesses
    # below would fail with an AttributeError instead of using the default session.
    if sagemaker_session is None:
        sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION

    region: str = sagemaker_session.boto_region_name
    partition: str = aws_partition(region)
    account_id: str = sagemaker_session.account_id()

    training_job_arn = (
        f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}"
    )

    model_id, inferred_model_version = get_jumpstart_model_id_version_from_resource_arn(
        training_job_arn, sagemaker_session
    )

    if not model_id:
        raise ValueError(
            f"Cannot infer JumpStart model ID from training job '{training_job_name}'. "
            "Please specify JumpStart `model_id` when retrieving Estimator "
            "for this training job."
        )

    # Normalize any falsy version (e.g. an empty tag value) to ``None`` so callers
    # can apply their own default with a simple ``or``.
    model_version = inferred_model_version or None

    return model_id, model_version

src/sagemaker/jumpstart/utils.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515
import logging
1616
import os
17-
from typing import Any, Dict, List, Optional, Union
17+
from typing import Any, Dict, List, Optional, Tuple, Union
1818
from urllib.parse import urlparse
1919
import boto3
2020
from packaging.version import Version
@@ -762,3 +762,51 @@ def is_valid_model_id(
762762
if script == enums.JumpStartScriptScope.TRAINING:
763763
return model_id in model_id_set
764764
raise ValueError(f"Unsupported script: {script}")
765+
766+
767+
def get_jumpstart_model_id_version_from_resource_arn(
    resource_arn: str,
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[Optional[str], Optional[str]]:
    """Return the JumpStart model ID and version found in a resource's tags.

    The primary JumpStart tag key and the extra tag keys declared in ``constants``
    are all consulted. If two keys yield different values for the same field, a
    warning is logged and that field is reported as ``None``.

    Args:
        resource_arn (str): ARN of the resource whose tags are listed.
        sagemaker_session (Session): Session used to list the tags.
            (Default: constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        Tuple[Optional[str], Optional[str]]: ``(model_id, model_version)``; either
        element is ``None`` if it cannot be inferred from the tags.
    """
    resource_tags = sagemaker_session.list_tags(resource_arn)

    id_tag_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS]
    version_tag_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS]

    def _single_value_for(candidate_keys, conflict_message: str) -> Optional[str]:
        """Return the one value shared by all present keys, or None on absence/conflict."""
        agreed_value: Optional[str] = None
        for tag_key in candidate_keys:
            try:
                tag_value = get_tag_value(tag_key, resource_tags)
            except KeyError:
                # A KeyError from the lookup means this key yields no usable value.
                continue
            if tag_value is None:
                continue
            if agreed_value is not None and tag_value != agreed_value:
                # Disagreeing tags make the field ambiguous; report it as unknown.
                constants.JUMPSTART_LOGGER.warning(conflict_message, resource_arn)
                return None
            agreed_value = tag_value
        return agreed_value

    model_id = _single_value_for(
        id_tag_keys, "Found multiple model ID tags on the following resource: %s"
    )
    model_version = _single_value_for(
        version_tag_keys, "Found multiple model version tags on the following resource: %s"
    )

    return model_id, model_version

src/sagemaker/predictor.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
1818

1919
from sagemaker.jumpstart.factory.model import get_default_predictor
20-
from sagemaker.jumpstart.utils import is_jumpstart_model_input
20+
from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint
21+
2122

2223
from sagemaker.session import Session
2324

@@ -33,6 +34,7 @@
3334

3435
def retrieve_default(
3536
endpoint_name: str,
37+
inference_component_name: Optional[str] = None,
3638
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3739
region: Optional[str] = None,
3840
model_id: Optional[str] = None,
@@ -44,7 +46,9 @@ def retrieve_default(
4446
4547
Args:
4648
endpoint_name (str): Endpoint name for which to create a predictor.
47-
sagemaker_session (Session): The SageMaker Session to attach to the Predictor.
49+
inference_component_name (str): Name of the Amazon SageMaker inference component
50+
from which to optionally create a predictor. (Default: None).
51+
sagemaker_session (Session): The SageMaker Session to attach to the predictor.
4852
(Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
4953
region (str): The AWS Region for which to retrieve the default predictor.
5054
(Default: None).
@@ -63,16 +67,37 @@ def retrieve_default(
6367
Predictor: The default predictor to use for the model.
6468
6569
Raises:
66-
ValueError: If the combination of arguments specified is not supported.
70+
ValueError: If the combination of arguments specified is not supported, or if a model ID or
71+
version cannot be inferred from the endpoint.
6772
"""
6873

69-
if not is_jumpstart_model_input(model_id, model_version):
70-
raise ValueError(
71-
"Must specify JumpStart `model_id` and `model_version` "
72-
"when retrieving default predictor."
74+
if model_id is None:
75+
(
76+
inferred_model_id,
77+
inferred_model_version,
78+
inferred_inference_component_name,
79+
) = get_model_id_version_from_endpoint(
80+
endpoint_name, inference_component_name, sagemaker_session
7381
)
7482

75-
predictor = Predictor(endpoint_name=endpoint_name, sagemaker_session=sagemaker_session)
83+
if not inferred_model_id:
84+
raise ValueError(
85+
f"Cannot infer JumpStart model ID from endpoint '{endpoint_name}'. "
86+
"Please specify JumpStart `model_id` when retrieving default "
87+
"predictor for this endpoint."
88+
)
89+
90+
model_id = inferred_model_id
91+
model_version = model_version or inferred_model_version or "*"
92+
inference_component_name = inference_component_name or inferred_inference_component_name
93+
else:
94+
model_version = model_version or "*"
95+
96+
predictor = Predictor(
97+
endpoint_name=endpoint_name,
98+
component_name=inference_component_name,
99+
sagemaker_session=sagemaker_session,
100+
)
76101

77102
return get_default_predictor(
78103
predictor=predictor,

0 commit comments

Comments
 (0)