Skip to content

Commit 92219f2

Browse files
Captainia authored and benieric committed
Use separate tags for inference and training configs (aws#4635)
* Use separate tags for inference and training
* format
* format
* format
* format
1 parent 7b5ef04 commit 92219f2

File tree

14 files changed: 417 additions and 180 deletions.

src/sagemaker/jumpstart/enums.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,9 @@ class JumpStartTag(str, Enum):
9292
MODEL_ID = "sagemaker-sdk:jumpstart-model-id"
9393
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
9494
MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type"
95-
MODEL_CONFIG_NAME = "sagemaker-sdk:jumpstart-model-config-name"
95+
96+
INFERENCE_CONFIG_NAME = "sagemaker-sdk:jumpstart-inference-config-name"
97+
TRAINING_CONFIG_NAME = "sagemaker-sdk:jumpstart-training-config-name"
9698

9799

98100
class SerializerType(str, Enum):

src/sagemaker/jumpstart/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -737,7 +737,7 @@ def attach(
737737
config_name = None
738738
if model_id is None:
739739

740-
model_id, model_version, config_name = get_model_info_from_training_job(
740+
model_id, model_version, _, config_name = get_model_info_from_training_job(
741741
training_job_name=training_job_name, sagemaker_session=sagemaker_session
742742
)
743743

@@ -1143,7 +1143,9 @@ def set_training_config(self, config_name: str) -> None:
11431143
Args:
11441144
config_name (str): The name of the config.
11451145
"""
1146-
self.__init__(**self.init_kwargs, config_name=config_name)
1146+
self.__init__(
1147+
model_id=self.model_id, model_version=self.model_version, config_name=config_name
1148+
)
11471149

11481150
def __str__(self) -> str:
11491151
"""Overriding str(*) method to make more human-readable."""

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@
6161
JumpStartModelInitKwargs,
6262
)
6363
from sagemaker.jumpstart.utils import (
64-
add_jumpstart_model_id_version_tags,
64+
add_jumpstart_model_info_tags,
6565
get_eula_message,
6666
update_dict_if_key_not_present,
6767
resolve_estimator_sagemaker_config_field,
@@ -479,11 +479,12 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima
479479
).version
480480

481481
if kwargs.sagemaker_session.settings.include_jumpstart_tags:
482-
kwargs.tags = add_jumpstart_model_id_version_tags(
482+
kwargs.tags = add_jumpstart_model_info_tags(
483483
kwargs.tags,
484484
kwargs.model_id,
485485
full_model_version,
486486
config_name=kwargs.config_name,
487+
scope=JumpStartScriptScope.TRAINING,
487488
)
488489
return kwargs
489490

src/sagemaker/jumpstart/factory/model.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@
4444
JumpStartModelRegisterKwargs,
4545
)
4646
from sagemaker.jumpstart.utils import (
47-
add_jumpstart_model_id_version_tags,
47+
add_jumpstart_model_info_tags,
4848
update_dict_if_key_not_present,
4949
resolve_model_sagemaker_config_field,
5050
verify_model_region_and_return_specs,
@@ -495,8 +495,13 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
495495
).version
496496

497497
if kwargs.sagemaker_session.settings.include_jumpstart_tags:
498-
kwargs.tags = add_jumpstart_model_id_version_tags(
499-
kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type, kwargs.config_name
498+
kwargs.tags = add_jumpstart_model_info_tags(
499+
kwargs.tags,
500+
kwargs.model_id,
501+
full_model_version,
502+
kwargs.model_type,
503+
config_name=kwargs.config_name,
504+
scope=JumpStartScriptScope.INFERENCE,
500505
)
501506

502507
return kwargs

src/sagemaker/jumpstart/session_utils.py

Lines changed: 39 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
from typing import Optional, Tuple
1818
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
1919

20-
from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn
20+
from sagemaker.jumpstart.utils import get_jumpstart_model_info_from_resource_arn
2121
from sagemaker.session import Session
2222
from sagemaker.utils import aws_partition
2323

@@ -26,7 +26,7 @@ def get_model_info_from_endpoint(
2626
endpoint_name: str,
2727
inference_component_name: Optional[str] = None,
2828
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
29-
) -> Tuple[str, str, Optional[str], Optional[str]]:
29+
) -> Tuple[str, str, Optional[str], Optional[str], Optional[str]]:
3030
"""Optionally inference component names, return the model ID, version and config name.
3131
3232
Infers the model ID and version based on the resource tags. Returns a tuple of the model ID
@@ -46,7 +46,8 @@ def get_model_info_from_endpoint(
4646
(
4747
model_id,
4848
model_version,
49-
config_name,
49+
inference_config_name,
50+
training_config_name,
5051
) = _get_model_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301
5152
inference_component_name, sagemaker_session
5253
)
@@ -55,17 +56,29 @@ def get_model_info_from_endpoint(
5556
(
5657
model_id,
5758
model_version,
58-
config_name,
59+
inference_config_name,
60+
training_config_name,
5961
inference_component_name,
6062
) = _get_model_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301
6163
endpoint_name, sagemaker_session
6264
)
6365

6466
else:
65-
model_id, model_version, config_name = _get_model_info_from_model_based_endpoint(
67+
(
68+
model_id,
69+
model_version,
70+
inference_config_name,
71+
training_config_name,
72+
) = _get_model_info_from_model_based_endpoint(
6673
endpoint_name, inference_component_name, sagemaker_session
6774
)
68-
return model_id, model_version, inference_component_name, config_name
75+
return (
76+
model_id,
77+
model_version,
78+
inference_component_name,
79+
inference_config_name,
80+
training_config_name,
81+
)
6982

7083

7184
def _get_model_info_from_inference_component_endpoint_without_inference_component_name(
@@ -125,9 +138,12 @@ def _get_model_info_from_inference_component_endpoint_with_inference_component_n
125138
f"inference-component/{inference_component_name}"
126139
)
127140

128-
model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
129-
inference_component_arn, sagemaker_session
130-
)
141+
(
142+
model_id,
143+
model_version,
144+
inference_config_name,
145+
training_config_name,
146+
) = get_jumpstart_model_info_from_resource_arn(inference_component_arn, sagemaker_session)
131147

132148
if not model_id:
133149
raise ValueError(
@@ -136,14 +152,14 @@ def _get_model_info_from_inference_component_endpoint_with_inference_component_n
136152
"when retrieving default predictor for this inference component."
137153
)
138154

139-
return model_id, model_version, config_name
155+
return model_id, model_version, inference_config_name, training_config_name
140156

141157

142158
def _get_model_info_from_model_based_endpoint(
143159
endpoint_name: str,
144160
inference_component_name: Optional[str],
145161
sagemaker_session: Session,
146-
) -> Tuple[str, str, Optional[str]]:
162+
) -> Tuple[str, str, Optional[str], Optional[str]]:
147163
"""Returns the model ID, version and config name inferred from a model-based endpoint.
148164
149165
Raises:
@@ -163,9 +179,12 @@ def _get_model_info_from_model_based_endpoint(
163179

164180
endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}"
165181

166-
model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
167-
endpoint_arn, sagemaker_session
168-
)
182+
(
183+
model_id,
184+
model_version,
185+
inference_config_name,
186+
training_config_name,
187+
) = get_jumpstart_model_info_from_resource_arn(endpoint_arn, sagemaker_session)
169188

170189
if not model_id:
171190
raise ValueError(
@@ -174,13 +193,13 @@ def _get_model_info_from_model_based_endpoint(
174193
"predictor for this endpoint."
175194
)
176195

177-
return model_id, model_version, config_name
196+
return model_id, model_version, inference_config_name, training_config_name
178197

179198

180199
def get_model_info_from_training_job(
181200
training_job_name: str,
182201
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
183-
) -> Tuple[str, str, Optional[str]]:
202+
) -> Tuple[str, str, Optional[str], Optional[str]]:
184203
"""Returns the model ID and version and config name inferred from a training job.
185204
186205
Raises:
@@ -199,8 +218,9 @@ def get_model_info_from_training_job(
199218
(
200219
model_id,
201220
inferred_model_version,
202-
config_name,
203-
) = get_jumpstart_model_id_version_from_resource_arn(training_job_arn, sagemaker_session)
221+
inference_config_name,
222+
trainig_config_name,
223+
) = get_jumpstart_model_info_from_resource_arn(training_job_arn, sagemaker_session)
204224

205225
model_version = inferred_model_version or None
206226

@@ -211,4 +231,4 @@ def get_model_info_from_training_job(
211231
"for this training job."
212232
)
213233

214-
return model_id, model_version, config_name
234+
return model_id, model_version, inference_config_name, trainig_config_name

src/sagemaker/jumpstart/utils.py

Lines changed: 28 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -320,7 +320,8 @@ def add_single_jumpstart_tag(
320320
tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags)
321321
or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags)
322322
or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags)
323-
or tag_key_in_array(enums.JumpStartTag.MODEL_CONFIG_NAME, curr_tags)
323+
or tag_key_in_array(enums.JumpStartTag.INFERENCE_CONFIG_NAME, curr_tags)
324+
or tag_key_in_array(enums.JumpStartTag.TRAINING_CONFIG_NAME, curr_tags)
324325
)
325326
if is_uri
326327
else False
@@ -351,12 +352,13 @@ def get_jumpstart_base_name_if_jumpstart_model(
351352
return None
352353

353354

354-
def add_jumpstart_model_id_version_tags(
355+
def add_jumpstart_model_info_tags(
355356
tags: Optional[List[TagsDict]],
356357
model_id: str,
357358
model_version: str,
358359
model_type: Optional[enums.JumpStartModelType] = None,
359360
config_name: Optional[str] = None,
361+
scope: enums.JumpStartScriptScope = None,
360362
) -> List[TagsDict]:
361363
"""Add custom model ID and version tags to JumpStart related resources."""
362364
if model_id is None or model_version is None:
@@ -380,10 +382,17 @@ def add_jumpstart_model_id_version_tags(
380382
tags,
381383
is_uri=False,
382384
)
383-
if config_name:
385+
if config_name and scope == enums.JumpStartScriptScope.INFERENCE:
384386
tags = add_single_jumpstart_tag(
385387
config_name,
386-
enums.JumpStartTag.MODEL_CONFIG_NAME,
388+
enums.JumpStartTag.INFERENCE_CONFIG_NAME,
389+
tags,
390+
is_uri=False,
391+
)
392+
if config_name and scope == enums.JumpStartScriptScope.TRAINING:
393+
tags = add_single_jumpstart_tag(
394+
config_name,
395+
enums.JumpStartTag.TRAINING_CONFIG_NAME,
387396
tags,
388397
is_uri=False,
389398
)
@@ -840,10 +849,10 @@ def _extract_value_from_list_of_tags(
840849
return resolved_value
841850

842851

843-
def get_jumpstart_model_id_version_from_resource_arn(
852+
def get_jumpstart_model_info_from_resource_arn(
844853
resource_arn: str,
845854
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
846-
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
855+
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
847856
"""Returns the JumpStart model ID, version and config name if in resource tags.
848857
849858
Returns 'None' if model ID or version or config name cannot be inferred from tags.
@@ -853,7 +862,8 @@ def get_jumpstart_model_id_version_from_resource_arn(
853862

854863
model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS]
855864
model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS]
856-
model_config_name_keys = [enums.JumpStartTag.MODEL_CONFIG_NAME]
865+
inference_config_name_keys = [enums.JumpStartTag.INFERENCE_CONFIG_NAME]
866+
training_config_name_keys = [enums.JumpStartTag.TRAINING_CONFIG_NAME]
857867

858868
model_id: Optional[str] = _extract_value_from_list_of_tags(
859869
tag_keys=model_id_keys,
@@ -869,14 +879,21 @@ def get_jumpstart_model_id_version_from_resource_arn(
869879
resource_arn=resource_arn,
870880
)
871881

872-
config_name: Optional[str] = _extract_value_from_list_of_tags(
873-
tag_keys=model_config_name_keys,
882+
inference_config_name: Optional[str] = _extract_value_from_list_of_tags(
883+
tag_keys=inference_config_name_keys,
884+
list_tags_result=list_tags_result,
885+
resource_name="inference config name",
886+
resource_arn=resource_arn,
887+
)
888+
889+
training_config_name: Optional[str] = _extract_value_from_list_of_tags(
890+
tag_keys=training_config_name_keys,
874891
list_tags_result=list_tags_result,
875-
resource_name="model config name",
892+
resource_name="training config name",
876893
resource_arn=resource_arn,
877894
)
878895

879-
return model_id, model_version, config_name
896+
return model_id, model_version, inference_config_name, training_config_name
880897

881898

882899
def get_region_fallback(

src/sagemaker/predictor.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,7 @@ def retrieve_default(
8282
inferred_model_version,
8383
inferred_inference_component_name,
8484
inferred_config_name,
85+
_,
8586
) = get_model_info_from_endpoint(endpoint_name, inference_component_name, sagemaker_session)
8687

8788
if not inferred_model_id:

0 commit comments

Comments
 (0)