feat: tag JumpStart resource with config names #4608


Merged: 11 commits (Apr 25, 2024)
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/enums.py
@@ -92,6 +92,7 @@ class JumpStartTag(str, Enum):
MODEL_ID = "sagemaker-sdk:jumpstart-model-id"
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type"
MODEL_CONFIG_NAME = "sagemaker-sdk:jumpstart-model-config-name"


class SerializerType(str, Enum):
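For orientation, here is a minimal sketch of how the new tag key could appear next to the existing JumpStart tags on a tagged resource. The model ID, version, and config-name values below are placeholders, not output from this PR.

```python
# Illustrative only: a possible tag list once a config name is propagated.
# The model ID, version, and config-name values are made up for the example.
from sagemaker.jumpstart.enums import JumpStartTag

example_tags = [
    {"Key": JumpStartTag.MODEL_ID.value, "Value": "meta-textgeneration-llama-2-7b"},
    {"Key": JumpStartTag.MODEL_VERSION.value, "Value": "3.0.0"},
    {"Key": JumpStartTag.MODEL_CONFIG_NAME.value, "Value": "tgi-inference"},
]
```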
9 changes: 5 additions & 4 deletions src/sagemaker/jumpstart/estimator.py
@@ -33,7 +33,7 @@

from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs
from sagemaker.jumpstart.factory.model import get_default_predictor
from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job
from sagemaker.jumpstart.session_utils import get_model_info_from_training_job
from sagemaker.jumpstart.types import JumpStartMetadataConfig
from sagemaker.jumpstart.utils import (
get_jumpstart_configs,
@@ -730,10 +730,10 @@ def attach(
ValueError: if the model ID or version cannot be inferred from the training job.

"""

config_name = None
if model_id is None:

model_id, model_version = get_model_id_version_from_training_job(
model_id, model_version, config_name = get_model_info_from_training_job(
training_job_name=training_job_name, sagemaker_session=sagemaker_session
)

@@ -749,6 +749,7 @@
tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated
tolerate_vulnerable_model=True, # model is already trained, so tolerate if vulnerable
sagemaker_session=sagemaker_session,
config_name=config_name,
)

# eula was already accepted if the model was successfully trained
@@ -1102,7 +1103,7 @@ def deploy(
tolerate_deprecated_model=self.tolerate_deprecated_model,
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
sagemaker_session=self.sagemaker_session,
# config_name=self.config_name,
config_name=self.config_name,
)

# If a predictor class was passed, do not mutate predictor
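Taken together, `attach` now recovers the config name from the training job's tags and `deploy` forwards `self.config_name`, so a re-attached estimator keeps its config through deployment. A hypothetical round trip, with a placeholder training job name:

```python
# Hypothetical usage (the training job name is a placeholder): attach() infers
# model ID, version, and config name from the training job's tags, and deploy()
# forwards config_name so the endpoint carries the matching config-name tag.
from sagemaker.jumpstart.estimator import JumpStartEstimator

estimator = JumpStartEstimator.attach(
    training_job_name="jumpstart-example-training-job",
)
predictor = estimator.deploy()
```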
5 changes: 4 additions & 1 deletion src/sagemaker/jumpstart/factory/estimator.py
@@ -478,7 +478,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima

if kwargs.sagemaker_session.settings.include_jumpstart_tags:
kwargs.tags = add_jumpstart_model_id_version_tags(
kwargs.tags, kwargs.model_id, full_model_version
kwargs.tags,
kwargs.model_id,
full_model_version,
config_name=kwargs.config_name,
)
return kwargs

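The tagging helper itself is not shown in this diff. As a rough, simplified sketch only (not the SDK's add_jumpstart_model_id_version_tags, which also respects existing tags, model type, and session settings), appending the config-name tag could look like this:

```python
# Simplified stand-in for the tagging step; illustrative, not the SDK implementation.
from typing import Dict, List, Optional


def add_config_name_tag_sketch(
    tags: Optional[List[Dict[str, str]]],
    model_id: str,
    model_version: str,
    config_name: Optional[str] = None,
) -> List[Dict[str, str]]:
    tags = list(tags or [])
    tags.append({"Key": "sagemaker-sdk:jumpstart-model-id", "Value": model_id})
    tags.append({"Key": "sagemaker-sdk:jumpstart-model-version", "Value": model_version})
    if config_name:
        tags.append({"Key": "sagemaker-sdk:jumpstart-model-config-name", "Value": config_name})
    return tags
```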
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/factory/model.py
@@ -496,7 +496,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:

if kwargs.sagemaker_session.settings.include_jumpstart_tags:
kwargs.tags = add_jumpstart_model_id_version_tags(
kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type
kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type, kwargs.config_name
)

return kwargs
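Once the deploy kwargs carry the tag, one way to confirm that the config-name tag landed on the endpoint is to list its tags directly; the ARN below is a placeholder.

```python
# Verification sketch with a placeholder endpoint ARN.
import boto3

sm_client = boto3.client("sagemaker")
tags = sm_client.list_tags(
    ResourceArn="arn:aws:sagemaker:us-west-2:123456789012:endpoint/my-endpoint"
)["Tags"]
config_tag = next(
    (t for t in tags if t["Key"] == "sagemaker-sdk:jumpstart-model-config-name"), None
)
print(config_tag)
```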
56 changes: 30 additions & 26 deletions src/sagemaker/jumpstart/session_utils.py
@@ -22,12 +22,12 @@
from sagemaker.utils import aws_partition


def get_model_id_version_from_endpoint(
def get_model_info_from_endpoint(
endpoint_name: str,
inference_component_name: Optional[str] = None,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[str, str, Optional[str]]:
"""Given an endpoint and optionally inference component names, return the model ID and version.
) -> Tuple[str, str, Optional[str], Optional[str]]:
"""Optionally inference component names, return the model ID, version and config name.

Infers the model ID and version based on the resource tags. Returns a tuple of the model ID
and version. A third string element is included in the tuple for any inferred inference
@@ -46,30 +46,32 @@ def get_model_id_version_from_endpoint(
(
model_id,
model_version,
) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301
config_name,
) = _get_model_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301
inference_component_name, sagemaker_session
)

else:
(
model_id,
model_version,
config_name,
inference_component_name,
) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301
) = _get_model_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301
endpoint_name, sagemaker_session
)

else:
model_id, model_version = _get_model_id_version_from_model_based_endpoint(
model_id, model_version, config_name = _get_model_info_from_model_based_endpoint(
endpoint_name, inference_component_name, sagemaker_session
)
return model_id, model_version, inference_component_name
return model_id, model_version, inference_component_name, config_name


def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name(
def _get_model_info_from_inference_component_endpoint_without_inference_component_name(
endpoint_name: str, sagemaker_session: Session
) -> Tuple[str, str, str]:
"""Given an endpoint name, derives the model ID, version, and inferred inference component name.
) -> Tuple[str, str, str, str]:
"""Derives the model ID, version, config name and inferred inference component name.

This function assumes the endpoint corresponds to an inference-component-based endpoint.
An endpoint is inference-component-based if and only if the associated endpoint config
@@ -98,14 +100,14 @@ def _get_model_id_version_from_inference_component_endpoint_without_inference_co
)
inference_component_name = inference_component_names[0]
return (
*_get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
*_get_model_info_from_inference_component_endpoint_with_inference_component_name(
inference_component_name, sagemaker_session
),
inference_component_name,
)


def _get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
def _get_model_info_from_inference_component_endpoint_with_inference_component_name(
inference_component_name: str, sagemaker_session: Session
):
"""Returns the model ID and version inferred from a SageMaker inference component.
@@ -123,7 +125,7 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo
f"inference-component/{inference_component_name}"
)

model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
inference_component_arn, sagemaker_session
)

@@ -134,15 +136,15 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo
"when retrieving default predictor for this inference component."
)

return model_id, model_version
return model_id, model_version, config_name


def _get_model_id_version_from_model_based_endpoint(
def _get_model_info_from_model_based_endpoint(
endpoint_name: str,
inference_component_name: Optional[str],
sagemaker_session: Session,
) -> Tuple[str, str]:
"""Returns the model ID and version inferred from a model-based endpoint.
) -> Tuple[str, str, Optional[str]]:
"""Returns the model ID, version and config name inferred from a model-based endpoint.

Raises:
ValueError: If an inference component name is supplied, or if the endpoint does
@@ -161,7 +163,7 @@ def _get_model_id_version_from_model_based_endpoint(

endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}"

model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
endpoint_arn, sagemaker_session
)

@@ -172,14 +174,14 @@ def _get_model_id_version_from_model_based_endpoint(
"predictor for this endpoint."
)

return model_id, model_version
return model_id, model_version, config_name


def get_model_id_version_from_training_job(
def get_model_info_from_training_job(
training_job_name: str,
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[str, str]:
"""Returns the model ID and version inferred from a training job.
) -> Tuple[str, str, Optional[str]]:
"""Returns the model ID and version and config name inferred from a training job.

Raises:
ValueError: If the training job does not have tags from which the model ID
@@ -194,9 +196,11 @@ def get_model_id_version_from_training_job(
f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}"
)

model_id, inferred_model_version = get_jumpstart_model_id_version_from_resource_arn(
training_job_arn, sagemaker_session
)
(
model_id,
inferred_model_version,
config_name,
) = get_jumpstart_model_id_version_from_resource_arn(training_job_arn, sagemaker_session)

model_version = inferred_model_version or None

@@ -207,4 +211,4 @@
"for this training job."
)

return model_id, model_version
return model_id, model_version, config_name
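The renamed helpers now surface the config name alongside the model ID and version. Hypothetical call sites, with placeholder resource names, showing the widened return tuples:

```python
# Placeholder resource names; illustrates the return values of the renamed helpers.
from sagemaker.jumpstart.session_utils import (
    get_model_info_from_endpoint,
    get_model_info_from_training_job,
)

model_id, model_version, inference_component_name, config_name = get_model_info_from_endpoint(
    endpoint_name="my-jumpstart-endpoint"
)

model_id, model_version, config_name = get_model_info_from_training_job(
    training_job_name="my-jumpstart-training-job"
)
```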
22 changes: 12 additions & 10 deletions src/sagemaker/jumpstart/types.py
@@ -1064,9 +1064,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
Dictionary representation of the config component.
"""
for field in json_obj.keys():
if field not in self.__slots__:
raise ValueError(f"Invalid component field: {field}")
setattr(self, field, json_obj[field])
if field in self.__slots__:
setattr(self, field, json_obj[field])


class JumpStartMetadataConfig(JumpStartDataHolderType):
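The from_json change above means unrecognized fields in a config component are now skipped rather than raising ValueError. A minimal stand-in showing the new behavior; the class and its slots are illustrative, not the SDK types:

```python
# Illustrative stand-in for the relaxed from_json behavior.
from typing import Any, Dict


class ConfigComponentSketch:
    __slots__ = ["supported_inference_instance_types"]

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        for field in json_obj.keys():
            if field in self.__slots__:  # unknown fields are now ignored
                setattr(self, field, json_obj[field])


component = ConfigComponentSketch()
component.from_json(
    {"supported_inference_instance_types": ["ml.g5.xlarge"], "future_field": 1}
)
# "future_field" is silently dropped instead of raising ValueError.
```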
@@ -1164,20 +1163,17 @@ def get_top_config_from_ranking(
) -> Optional[JumpStartMetadataConfig]:
"""Gets the best the config based on config ranking.

Falls back to the ordering of config names if the
ranking is not available.
Args:
ranking_name (str):
The ranking name that config priority is based on.
instance_type (Optional[str]):
The instance type which the config selection is based on.

Raises:
ValueError: If the config exists but missing config ranking.
NotImplementedError: If the scope is unrecognized.
"""
if self.configs and (
not self.config_rankings or not self.config_rankings.get(ranking_name)
):
raise ValueError(f"Config exists but missing config ranking {ranking_name}.")

if self.scope == JumpStartScriptScope.INFERENCE:
instance_type_attribute = "supported_inference_instance_types"
@@ -1186,8 +1182,14 @@
else:
raise NotImplementedError(f"Unknown script scope {self.scope}")

rankings = self.config_rankings.get(ranking_name)
for config_name in rankings.rankings:
if self.configs and (
not self.config_rankings or not self.config_rankings.get(ranking_name)
):
ranked_config_names = sorted(list(self.configs.keys()))
else:
rankings = self.config_rankings.get(ranking_name)
ranked_config_names = rankings.rankings
for config_name in ranked_config_names:
resolved_config = self.configs[config_name].resolved_config
if instance_type and instance_type not in getattr(
resolved_config, instance_type_attribute
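With the raise removed, config selection falls back to alphabetical config-name order when no ranking is present. A standalone sketch of that ordering logic, simplified to omit the instance-type filtering the real method performs:

```python
# Simplified ordering logic only; assumes configs maps config name -> config object.
from typing import Dict, List, Optional


def pick_top_config_name(
    configs: Dict[str, object], rankings: Optional[List[str]]
) -> Optional[str]:
    if not configs:
        return None
    # Fall back to sorted config names when no ranking is available.
    ranked_config_names = sorted(configs.keys()) if not rankings else rankings
    for config_name in ranked_config_names:
        if config_name in configs:
            return config_name
    return None
```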