Skip to content

feat: tag JumpStart resource with config names #4608

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class JumpStartTag(str, Enum):
MODEL_ID = "sagemaker-sdk:jumpstart-model-id"
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type"
MODEL_CONFIG_NAME = "sagemaker-sdk:jumpstart-model-config-name"


class SerializerType(str, Enum):
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,10 +730,10 @@ def attach(
ValueError: if the model ID or version cannot be inferred from the training job.

"""

config_name = None
if model_id is None:

model_id, model_version = get_model_id_version_from_training_job(
model_id, model_version, config_name = get_model_id_version_from_training_job(
training_job_name=training_job_name, sagemaker_session=sagemaker_session
)

Expand All @@ -749,6 +749,7 @@ def attach(
tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated
tolerate_vulnerable_model=True, # model is already trained, so tolerate if vulnerable
sagemaker_session=sagemaker_session,
config_name=config_name,
)

# eula was already accepted if the model was successfully trained
Expand Down Expand Up @@ -1102,7 +1103,7 @@ def deploy(
tolerate_deprecated_model=self.tolerate_deprecated_model,
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
sagemaker_session=self.sagemaker_session,
# config_name=self.config_name,
config_name=self.config_name,
)

# If a predictor class was passed, do not mutate predictor
Expand Down
5 changes: 4 additions & 1 deletion src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima

if kwargs.sagemaker_session.settings.include_jumpstart_tags:
kwargs.tags = add_jumpstart_model_id_version_tags(
kwargs.tags, kwargs.model_id, full_model_version
kwargs.tags,
kwargs.model_id,
full_model_version,
config_name=kwargs.config_name,
)
return kwargs

Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:

if kwargs.sagemaker_session.settings.include_jumpstart_tags:
kwargs.tags = add_jumpstart_model_id_version_tags(
kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type
kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type, kwargs.config_name
)

return kwargs
Expand Down
38 changes: 21 additions & 17 deletions src/sagemaker/jumpstart/session_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def get_model_id_version_from_endpoint(
endpoint_name: str,
inference_component_name: Optional[str] = None,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[str, str, Optional[str]]:
"""Given an endpoint and optionally inference component names, return the model ID and version.
) -> Tuple[str, str, Optional[str], Optional[str]]:
    """Given an endpoint and optionally inference component names, return the model ID, version, and config name.

    Infers the model ID, version, and config name based on the resource tags. Returns a tuple of the model ID
and version. A third string element is included in the tuple for any inferred inference
Expand All @@ -46,6 +46,7 @@ def get_model_id_version_from_endpoint(
(
model_id,
model_version,
config_name,
) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301
inference_component_name, sagemaker_session
)
Expand All @@ -55,21 +56,22 @@ def get_model_id_version_from_endpoint(
model_id,
model_version,
inference_component_name,
config_name,
) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301
endpoint_name, sagemaker_session
)

else:
model_id, model_version = _get_model_id_version_from_model_based_endpoint(
model_id, model_version, config_name = _get_model_id_version_from_model_based_endpoint(
endpoint_name, inference_component_name, sagemaker_session
)
return model_id, model_version, inference_component_name
return model_id, model_version, inference_component_name, config_name


def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name(
endpoint_name: str, sagemaker_session: Session
) -> Tuple[str, str, str]:
"""Given an endpoint name, derives the model ID, version, and inferred inference component name.
    """Given an endpoint name, derives the model ID, version, config name, and inferred inference component name.

This function assumes the endpoint corresponds to an inference-component-based endpoint.
An endpoint is inference-component-based if and only if the associated endpoint config
Expand Down Expand Up @@ -123,7 +125,7 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo
f"inference-component/{inference_component_name}"
)

model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
inference_component_arn, sagemaker_session
)

Expand All @@ -134,15 +136,15 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo
"when retrieving default predictor for this inference component."
)

return model_id, model_version
return model_id, model_version, config_name


def _get_model_id_version_from_model_based_endpoint(
endpoint_name: str,
inference_component_name: Optional[str],
sagemaker_session: Session,
) -> Tuple[str, str]:
"""Returns the model ID and version inferred from a model-based endpoint.
) -> Tuple[str, str, Optional[str]]:
"""Returns the model ID, version and config name inferred from a model-based endpoint.

Raises:
ValueError: If an inference component name is supplied, or if the endpoint does
Expand All @@ -161,7 +163,7 @@ def _get_model_id_version_from_model_based_endpoint(

endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}"

model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
endpoint_arn, sagemaker_session
)

Expand All @@ -172,14 +174,14 @@ def _get_model_id_version_from_model_based_endpoint(
"predictor for this endpoint."
)

return model_id, model_version
return model_id, model_version, config_name


def get_model_id_version_from_training_job(
training_job_name: str,
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[str, str]:
"""Returns the model ID and version inferred from a training job.
) -> Tuple[str, str, Optional[str]]:
    """Returns the model ID, version, and config name inferred from a training job.

Raises:
ValueError: If the training job does not have tags from which the model ID
Expand All @@ -194,9 +196,11 @@ def get_model_id_version_from_training_job(
f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}"
)

model_id, inferred_model_version = get_jumpstart_model_id_version_from_resource_arn(
training_job_arn, sagemaker_session
)
(
model_id,
inferred_model_version,
config_name,
) = get_jumpstart_model_id_version_from_resource_arn(training_job_arn, sagemaker_session)

model_version = inferred_model_version or None

Expand All @@ -207,4 +211,4 @@ def get_model_id_version_from_training_job(
"for this training job."
)

return model_id, model_version
return model_id, model_version, config_name
23 changes: 13 additions & 10 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,9 +1064,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
Dictionary representation of the config component.
"""
for field in json_obj.keys():
if field not in self.__slots__:
raise ValueError(f"Invalid component field: {field}")
setattr(self, field, json_obj[field])
if field in self.__slots__:
setattr(self, field, json_obj[field])


class JumpStartMetadataConfig(JumpStartDataHolderType):
Expand Down Expand Up @@ -1164,20 +1163,18 @@ def get_top_config_from_ranking(
) -> Optional[JumpStartMetadataConfig]:
"""Gets the best the config based on config ranking.

Fallback to use the ordering in config names if
ranking is not available.
Args:
ranking_name (str):
The ranking name that config priority is based on.
instance_type (Optional[str]):
The instance type which the config selection is based on.

Raises:
ValueError: If the config exists but missing config ranking.
NotImplementedError: If the scope is unrecognized.
"""
if self.configs and (
not self.config_rankings or not self.config_rankings.get(ranking_name)
):
raise ValueError(f"Config exists but missing config ranking {ranking_name}.")


if self.scope == JumpStartScriptScope.INFERENCE:
instance_type_attribute = "supported_inference_instance_types"
Expand All @@ -1186,8 +1183,14 @@ def get_top_config_from_ranking(
else:
raise NotImplementedError(f"Unknown script scope {self.scope}")

rankings = self.config_rankings.get(ranking_name)
for config_name in rankings.rankings:
if self.configs and (
not self.config_rankings or not self.config_rankings.get(ranking_name)
):
ranked_config_names = list(self.configs.keys())
else:
rankings = self.config_rankings.get(ranking_name)
ranked_config_names = rankings.rankings
for config_name in ranked_config_names:
resolved_config = self.configs[config_name].resolved_config
if instance_type and instance_type not in getattr(
resolved_config, instance_type_attribute
Expand Down
35 changes: 31 additions & 4 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def get_jumpstart_content_bucket(
for info_log in info_logs:
constants.JUMPSTART_LOGGER.info(info_log)
return bucket_to_return
# return "jumpstart-cache-alpha-us-west-2"


def get_formatted_manifest(
Expand Down Expand Up @@ -318,6 +319,7 @@ def add_single_jumpstart_tag(
tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags)
or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags)
or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags)
or tag_key_in_array(enums.JumpStartTag.MODEL_CONFIG_NAME, curr_tags)
)
if is_uri
else False
Expand Down Expand Up @@ -353,6 +355,7 @@ def add_jumpstart_model_id_version_tags(
model_id: str,
model_version: str,
model_type: Optional[enums.JumpStartModelType] = None,
config_name: Optional[str] = None,
) -> List[TagsDict]:
"""Add custom model ID and version tags to JumpStart related resources."""
if model_id is None or model_version is None:
Expand All @@ -376,6 +379,13 @@ def add_jumpstart_model_id_version_tags(
tags,
is_uri=False,
)
if config_name:
tags = add_single_jumpstart_tag(
config_name,
enums.JumpStartTag.MODEL_CONFIG_NAME,
tags,
is_uri=False,
)
return tags


Expand Down Expand Up @@ -803,19 +813,21 @@ def validate_model_id_and_get_type(
def get_jumpstart_model_id_version_from_resource_arn(
resource_arn: str,
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[Optional[str], Optional[str]]:
"""Returns the JumpStart model ID and version if in resource tags.
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""Returns the JumpStart model ID, version and config name if in resource tags.

Returns 'None' if model ID or version cannot be inferred from tags.
    Returns 'None' for the model ID, version, or config name if it cannot be inferred from tags.
"""

list_tags_result = sagemaker_session.list_tags(resource_arn)

model_id: Optional[str] = None
model_version: Optional[str] = None
config_name: Optional[str] = None

model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS]
model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS]
model_config_name_keys = [enums.JumpStartTag.MODEL_CONFIG_NAME]

for model_id_key in model_id_keys:
try:
Expand Down Expand Up @@ -845,7 +857,22 @@ def get_jumpstart_model_id_version_from_resource_arn(
break
model_version = model_version_from_tag

return model_id, model_version
for config_name_key in model_config_name_keys:
try:
config_name_key_from_tag = get_tag_value(config_name_key, list_tags_result)
except KeyError:
continue
if config_name_key_from_tag is not None:
if config_name is not None and config_name_key != config_name:
constants.JUMPSTART_LOGGER.warning(
"Found multiple model config names tags on the following resource: %s",
resource_arn
)
config_name = None
break
config_name = config_name_key_from_tag

return model_id, model_version, config_name


def get_region_fallback(
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def retrieve_default(
inferred_model_id,
inferred_model_version,
inferred_inference_component_name,
inferred_config_name,
) = get_model_id_version_from_endpoint(
endpoint_name, inference_component_name, sagemaker_session
)
Expand All @@ -92,8 +93,10 @@ def retrieve_default(
model_id = inferred_model_id
model_version = model_version or inferred_model_version or "*"
inference_component_name = inference_component_name or inferred_inference_component_name
config_name = inferred_config_name or None
else:
model_version = model_version or "*"
config_name = None

predictor = Predictor(
endpoint_name=endpoint_name,
Expand All @@ -110,4 +113,5 @@ def retrieve_default(
tolerate_vulnerable_model=tolerate_vulnerable_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
config_name=config_name,
)
3 changes: 3 additions & 0 deletions tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case(
get_model_id_version_from_training_job.return_value = (
"js-trainable-model-prepacked",
"1.0.0",
None,
)

mock_get_model_specs.side_effect = get_special_model_spec
Expand Down Expand Up @@ -1212,6 +1213,7 @@ def test_no_predictor_returns_default_predictor(
tolerate_deprecated_model=False,
tolerate_vulnerable_model=False,
sagemaker_session=estimator.sagemaker_session,
config_name=None,
)
self.assertEqual(type(predictor), Predictor)
self.assertEqual(predictor, default_predictor_with_presets)
Expand Down Expand Up @@ -1894,6 +1896,7 @@ def test_estimator_initialization_with_config_name(
tags=[
{"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
{"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
{"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-training"},
],
enable_network_isolation=False,
)
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/sagemaker/jumpstart/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,6 +1449,7 @@ def test_model_initialization_with_config_name(
tags=[
{"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
{"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
{"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"},
],
wait=True,
endpoint_logging=False,
Expand Down Expand Up @@ -1504,6 +1505,7 @@ def test_model_set_deployment_config(
tags=[
{"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
{"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
{"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"},
],
wait=True,
endpoint_logging=False,
Expand Down Expand Up @@ -1541,6 +1543,7 @@ def test_model_unset_deployment_config(
tags=[
{"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
{"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
{"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"},
],
wait=True,
endpoint_logging=False,
Expand Down
Loading
Loading