Skip to content

Commit b92aa3c

Browse files
authored
feat: tag JumpStart resource with config names (#4608)
* tag config name * format * resolving comments * format * format * update * fix * format * updates inference component config name * fix: tests
1 parent d06c4e5 commit b92aa3c

File tree

13 files changed

+315
-165
lines changed

13 files changed

+315
-165
lines changed

src/sagemaker/jumpstart/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ class JumpStartTag(str, Enum):
9292
MODEL_ID = "sagemaker-sdk:jumpstart-model-id"
9393
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
9494
MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type"
95+
MODEL_CONFIG_NAME = "sagemaker-sdk:jumpstart-model-config-name"
9596

9697

9798
class SerializerType(str, Enum):

src/sagemaker/jumpstart/estimator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs
3535
from sagemaker.jumpstart.factory.model import get_default_predictor
36-
from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job
36+
from sagemaker.jumpstart.session_utils import get_model_info_from_training_job
3737
from sagemaker.jumpstart.types import JumpStartMetadataConfig
3838
from sagemaker.jumpstart.utils import (
3939
get_jumpstart_configs,
@@ -730,10 +730,10 @@ def attach(
730730
ValueError: if the model ID or version cannot be inferred from the training job.
731731
732732
"""
733-
733+
config_name = None
734734
if model_id is None:
735735

736-
model_id, model_version = get_model_id_version_from_training_job(
736+
model_id, model_version, config_name = get_model_info_from_training_job(
737737
training_job_name=training_job_name, sagemaker_session=sagemaker_session
738738
)
739739

@@ -749,6 +749,7 @@ def attach(
749749
tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated
750750
tolerate_vulnerable_model=True, # model is already trained, so tolerate if vulnerable
751751
sagemaker_session=sagemaker_session,
752+
config_name=config_name,
752753
)
753754

754755
# eula was already accepted if the model was successfully trained
@@ -1102,7 +1103,7 @@ def deploy(
11021103
tolerate_deprecated_model=self.tolerate_deprecated_model,
11031104
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
11041105
sagemaker_session=self.sagemaker_session,
1105-
# config_name=self.config_name,
1106+
config_name=self.config_name,
11061107
)
11071108

11081109
# If a predictor class was passed, do not mutate predictor

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima
478478

479479
if kwargs.sagemaker_session.settings.include_jumpstart_tags:
480480
kwargs.tags = add_jumpstart_model_id_version_tags(
481-
kwargs.tags, kwargs.model_id, full_model_version
481+
kwargs.tags,
482+
kwargs.model_id,
483+
full_model_version,
484+
config_name=kwargs.config_name,
482485
)
483486
return kwargs
484487

src/sagemaker/jumpstart/factory/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
496496

497497
if kwargs.sagemaker_session.settings.include_jumpstart_tags:
498498
kwargs.tags = add_jumpstart_model_id_version_tags(
499-
kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type
499+
kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type, kwargs.config_name
500500
)
501501

502502
return kwargs

src/sagemaker/jumpstart/session_utils.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222
from sagemaker.utils import aws_partition
2323

2424

25-
def get_model_id_version_from_endpoint(
25+
def get_model_info_from_endpoint(
2626
endpoint_name: str,
2727
inference_component_name: Optional[str] = None,
2828
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
29-
) -> Tuple[str, str, Optional[str]]:
30-
"""Given an endpoint and optionally inference component names, return the model ID and version.
29+
) -> Tuple[str, str, Optional[str], Optional[str]]:
30+
"""Optionally inference component names, return the model ID, version and config name.
3131
3232
Infers the model ID and version based on the resource tags. Returns a tuple of the model ID
3333
and version. A third string element is included in the tuple for any inferred inference
@@ -46,30 +46,32 @@ def get_model_id_version_from_endpoint(
4646
(
4747
model_id,
4848
model_version,
49-
) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301
49+
config_name,
50+
) = _get_model_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301
5051
inference_component_name, sagemaker_session
5152
)
5253

5354
else:
5455
(
5556
model_id,
5657
model_version,
58+
config_name,
5759
inference_component_name,
58-
) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301
60+
) = _get_model_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301
5961
endpoint_name, sagemaker_session
6062
)
6163

6264
else:
63-
model_id, model_version = _get_model_id_version_from_model_based_endpoint(
65+
model_id, model_version, config_name = _get_model_info_from_model_based_endpoint(
6466
endpoint_name, inference_component_name, sagemaker_session
6567
)
66-
return model_id, model_version, inference_component_name
68+
return model_id, model_version, inference_component_name, config_name
6769

6870

69-
def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name(
71+
def _get_model_info_from_inference_component_endpoint_without_inference_component_name(
7072
endpoint_name: str, sagemaker_session: Session
71-
) -> Tuple[str, str, str]:
72-
"""Given an endpoint name, derives the model ID, version, and inferred inference component name.
73+
) -> Tuple[str, str, str, str]:
74+
"""Derives the model ID, version, config name and inferred inference component name.
7375
7476
This function assumes the endpoint corresponds to an inference-component-based endpoint.
7577
An endpoint is inference-component-based if and only if the associated endpoint config
@@ -98,14 +100,14 @@ def _get_model_id_version_from_inference_component_endpoint_without_inference_co
98100
)
99101
inference_component_name = inference_component_names[0]
100102
return (
101-
*_get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
103+
*_get_model_info_from_inference_component_endpoint_with_inference_component_name(
102104
inference_component_name, sagemaker_session
103105
),
104106
inference_component_name,
105107
)
106108

107109

108-
def _get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
110+
def _get_model_info_from_inference_component_endpoint_with_inference_component_name(
109111
inference_component_name: str, sagemaker_session: Session
110112
):
111113
"""Returns the model ID and version inferred from a SageMaker inference component.
@@ -123,7 +125,7 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo
123125
f"inference-component/{inference_component_name}"
124126
)
125127

126-
model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
128+
model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
127129
inference_component_arn, sagemaker_session
128130
)
129131

@@ -134,15 +136,15 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo
134136
"when retrieving default predictor for this inference component."
135137
)
136138

137-
return model_id, model_version
139+
return model_id, model_version, config_name
138140

139141

140-
def _get_model_id_version_from_model_based_endpoint(
142+
def _get_model_info_from_model_based_endpoint(
141143
endpoint_name: str,
142144
inference_component_name: Optional[str],
143145
sagemaker_session: Session,
144-
) -> Tuple[str, str]:
145-
"""Returns the model ID and version inferred from a model-based endpoint.
146+
) -> Tuple[str, str, Optional[str]]:
147+
"""Returns the model ID, version and config name inferred from a model-based endpoint.
146148
147149
Raises:
148150
ValueError: If an inference component name is supplied, or if the endpoint does
@@ -161,7 +163,7 @@ def _get_model_id_version_from_model_based_endpoint(
161163

162164
endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}"
163165

164-
model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
166+
model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
165167
endpoint_arn, sagemaker_session
166168
)
167169

@@ -172,14 +174,14 @@ def _get_model_id_version_from_model_based_endpoint(
172174
"predictor for this endpoint."
173175
)
174176

175-
return model_id, model_version
177+
return model_id, model_version, config_name
176178

177179

178-
def get_model_id_version_from_training_job(
180+
def get_model_info_from_training_job(
179181
training_job_name: str,
180182
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
181-
) -> Tuple[str, str]:
182-
"""Returns the model ID and version inferred from a training job.
183+
) -> Tuple[str, str, Optional[str]]:
184+
"""Returns the model ID and version and config name inferred from a training job.
183185
184186
Raises:
185187
ValueError: If the training job does not have tags from which the model ID
@@ -194,9 +196,11 @@ def get_model_id_version_from_training_job(
194196
f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}"
195197
)
196198

197-
model_id, inferred_model_version = get_jumpstart_model_id_version_from_resource_arn(
198-
training_job_arn, sagemaker_session
199-
)
199+
(
200+
model_id,
201+
inferred_model_version,
202+
config_name,
203+
) = get_jumpstart_model_id_version_from_resource_arn(training_job_arn, sagemaker_session)
200204

201205
model_version = inferred_model_version or None
202206

@@ -207,4 +211,4 @@ def get_model_id_version_from_training_job(
207211
"for this training job."
208212
)
209213

210-
return model_id, model_version
214+
return model_id, model_version, config_name

src/sagemaker/jumpstart/types.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,9 +1064,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
10641064
Dictionary representation of the config component.
10651065
"""
10661066
for field in json_obj.keys():
1067-
if field not in self.__slots__:
1068-
raise ValueError(f"Invalid component field: {field}")
1069-
setattr(self, field, json_obj[field])
1067+
if field in self.__slots__:
1068+
setattr(self, field, json_obj[field])
10701069

10711070

10721071
class JumpStartMetadataConfig(JumpStartDataHolderType):
@@ -1164,20 +1163,17 @@ def get_top_config_from_ranking(
11641163
) -> Optional[JumpStartMetadataConfig]:
11651164
"""Gets the best the config based on config ranking.
11661165
1166+
Fallback to use the ordering in config names if
1167+
ranking is not available.
11671168
Args:
11681169
ranking_name (str):
11691170
The ranking name that config priority is based on.
11701171
instance_type (Optional[str]):
11711172
The instance type which the config selection is based on.
11721173
11731174
Raises:
1174-
ValueError: If the config exists but missing config ranking.
11751175
NotImplementedError: If the scope is unrecognized.
11761176
"""
1177-
if self.configs and (
1178-
not self.config_rankings or not self.config_rankings.get(ranking_name)
1179-
):
1180-
raise ValueError(f"Config exists but missing config ranking {ranking_name}.")
11811177

11821178
if self.scope == JumpStartScriptScope.INFERENCE:
11831179
instance_type_attribute = "supported_inference_instance_types"
@@ -1186,8 +1182,14 @@ def get_top_config_from_ranking(
11861182
else:
11871183
raise NotImplementedError(f"Unknown script scope {self.scope}")
11881184

1189-
rankings = self.config_rankings.get(ranking_name)
1190-
for config_name in rankings.rankings:
1185+
if self.configs and (
1186+
not self.config_rankings or not self.config_rankings.get(ranking_name)
1187+
):
1188+
ranked_config_names = sorted(list(self.configs.keys()))
1189+
else:
1190+
rankings = self.config_rankings.get(ranking_name)
1191+
ranked_config_names = rankings.rankings
1192+
for config_name in ranked_config_names:
11911193
resolved_config = self.configs[config_name].resolved_config
11921194
if instance_type and instance_type not in getattr(
11931195
resolved_config, instance_type_attribute

0 commit comments

Comments
 (0)