diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index ca49fd41a3..0c192084ec 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -92,6 +92,7 @@ class JumpStartTag(str, Enum): MODEL_ID = "sagemaker-sdk:jumpstart-model-id" MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" + MODEL_CONFIG_NAME = "sagemaker-sdk:jumpstart-model-config-name" class SerializerType(str, Enum): diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index cf9b720607..33bb73a83c 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -33,7 +33,7 @@ from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs from sagemaker.jumpstart.factory.model import get_default_predictor -from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job +from sagemaker.jumpstart.session_utils import get_model_info_from_training_job from sagemaker.jumpstart.types import JumpStartMetadataConfig from sagemaker.jumpstart.utils import ( get_jumpstart_configs, @@ -730,10 +730,10 @@ def attach( ValueError: if the model ID or version cannot be inferred from the training job. """ - + config_name = None if model_id is None: - model_id, model_version = get_model_id_version_from_training_job( + model_id, model_version, config_name = get_model_info_from_training_job( training_job_name=training_job_name, sagemaker_session=sagemaker_session ) @@ -749,6 +749,7 @@ def attach( tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated tolerate_vulnerable_model=True, # model is already trained, so tolerate if vulnerable sagemaker_session=sagemaker_session, + config_name=config_name, ) # eula was already accepted if the model was successfully trained @@ -1102,7 +1103,7 @@ def deploy( tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, - # config_name=self.config_name, + config_name=self.config_name, ) # If a predictor class was passed, do not mutate predictor diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 926f313b68..2d5b29b52f 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -478,7 +478,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima if kwargs.sagemaker_session.settings.include_jumpstart_tags: kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version + kwargs.tags, + kwargs.model_id, + full_model_version, + config_name=kwargs.config_name, ) return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 25a1d63215..b4f6d70583 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -496,7 +496,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: if kwargs.sagemaker_session.settings.include_jumpstart_tags: kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type + kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type, kwargs.config_name ) return kwargs diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index e511a052d1..0fa7722f91 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -22,12 +22,12 @@ from sagemaker.utils import aws_partition -def get_model_id_version_from_endpoint( +def get_model_info_from_endpoint( endpoint_name: str, inference_component_name: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[str, str, Optional[str]]: - """Given an endpoint and optionally inference component names, return the model ID and version. +) -> Tuple[str, str, Optional[str], Optional[str]]: + """Optionally inference component names, return the model ID, version and config name. Infers the model ID and version based on the resource tags. Returns a tuple of the model ID and version. A third string element is included in the tuple for any inferred inference @@ -46,7 +46,8 @@ def get_model_id_version_from_endpoint( ( model_id, model_version, - ) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 + config_name, + ) = _get_model_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 inference_component_name, sagemaker_session ) @@ -54,22 +55,23 @@ def get_model_id_version_from_endpoint( ( model_id, model_version, + config_name, inference_component_name, - ) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 + ) = _get_model_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 endpoint_name, sagemaker_session ) else: - model_id, model_version = _get_model_id_version_from_model_based_endpoint( + model_id, model_version, config_name = _get_model_info_from_model_based_endpoint( endpoint_name, inference_component_name, sagemaker_session ) - return model_id, model_version, inference_component_name + return model_id, model_version, inference_component_name, config_name -def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( +def _get_model_info_from_inference_component_endpoint_without_inference_component_name( endpoint_name: str, sagemaker_session: Session -) -> Tuple[str, str, str]: - """Given an endpoint name, derives the model ID, version, and inferred inference component name. +) -> Tuple[str, str, str, str]: + """Derives the model ID, version, config name and inferred inference component name. This function assumes the endpoint corresponds to an inference-component-based endpoint. An endpoint is inference-component-based if and only if the associated endpoint config @@ -98,14 +100,14 @@ def _get_model_id_version_from_inference_component_endpoint_without_inference_co ) inference_component_name = inference_component_names[0] return ( - *_get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + *_get_model_info_from_inference_component_endpoint_with_inference_component_name( inference_component_name, sagemaker_session ), inference_component_name, ) -def _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( +def _get_model_info_from_inference_component_endpoint_with_inference_component_name( inference_component_name: str, sagemaker_session: Session ): """Returns the model ID and version inferred from a SageMaker inference component. @@ -123,7 +125,7 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo f"inference-component/{inference_component_name}" ) - model_id, model_version = get_jumpstart_model_id_version_from_resource_arn( + model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn( inference_component_arn, sagemaker_session ) @@ -134,15 +136,15 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo "when retrieving default predictor for this inference component." ) - return model_id, model_version + return model_id, model_version, config_name -def _get_model_id_version_from_model_based_endpoint( +def _get_model_info_from_model_based_endpoint( endpoint_name: str, inference_component_name: Optional[str], sagemaker_session: Session, -) -> Tuple[str, str]: - """Returns the model ID and version inferred from a model-based endpoint. +) -> Tuple[str, str, Optional[str]]: + """Returns the model ID, version and config name inferred from a model-based endpoint. Raises: ValueError: If an inference component name is supplied, or if the endpoint does @@ -161,7 +163,7 @@ def _get_model_id_version_from_model_based_endpoint( endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}" - model_id, model_version = get_jumpstart_model_id_version_from_resource_arn( + model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn( endpoint_arn, sagemaker_session ) @@ -172,14 +174,14 @@ def _get_model_id_version_from_model_based_endpoint( "predictor for this endpoint." ) - return model_id, model_version + return model_id, model_version, config_name -def get_model_id_version_from_training_job( +def get_model_info_from_training_job( training_job_name: str, sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[str, str]: - """Returns the model ID and version inferred from a training job. +) -> Tuple[str, str, Optional[str]]: + """Returns the model ID and version and config name inferred from a training job. Raises: ValueError: If the training job does not have tags from which the model ID @@ -194,9 +196,11 @@ def get_model_id_version_from_training_job( f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}" ) - model_id, inferred_model_version = get_jumpstart_model_id_version_from_resource_arn( - training_job_arn, sagemaker_session - ) + ( + model_id, + inferred_model_version, + config_name, + ) = get_jumpstart_model_id_version_from_resource_arn(training_job_arn, sagemaker_session) model_version = inferred_model_version or None @@ -207,4 +211,4 @@ def get_model_id_version_from_training_job( "for this training job." ) - return model_id, model_version + return model_id, model_version, config_name diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 07bd769054..bf0a84319b 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1064,9 +1064,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Dictionary representation of the config component. """ for field in json_obj.keys(): - if field not in self.__slots__: - raise ValueError(f"Invalid component field: {field}") - setattr(self, field, json_obj[field]) + if field in self.__slots__: + setattr(self, field, json_obj[field]) class JumpStartMetadataConfig(JumpStartDataHolderType): @@ -1164,6 +1163,8 @@ def get_top_config_from_ranking( ) -> Optional[JumpStartMetadataConfig]: """Gets the best the config based on config ranking. + Fallback to use the ordering in config names if + ranking is not available. Args: ranking_name (str): The ranking name that config priority is based on. @@ -1171,13 +1172,8 @@ def get_top_config_from_ranking( The instance type which the config selection is based on. Raises: - ValueError: If the config exists but missing config ranking. NotImplementedError: If the scope is unrecognized. """ - if self.configs and ( - not self.config_rankings or not self.config_rankings.get(ranking_name) - ): - raise ValueError(f"Config exists but missing config ranking {ranking_name}.") if self.scope == JumpStartScriptScope.INFERENCE: instance_type_attribute = "supported_inference_instance_types" @@ -1186,8 +1182,14 @@ def get_top_config_from_ranking( else: raise NotImplementedError(f"Unknown script scope {self.scope}") - rankings = self.config_rankings.get(ranking_name) - for config_name in rankings.rankings: + if self.configs and ( + not self.config_rankings or not self.config_rankings.get(ranking_name) + ): + ranked_config_names = sorted(list(self.configs.keys())) + else: + rankings = self.config_rankings.get(ranking_name) + ranked_config_names = rankings.rankings + for config_name in ranked_config_names: resolved_config = self.configs[config_name].resolved_config if instance_type and instance_type not in getattr( resolved_config, instance_type_attribute diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 905f2a18d5..59bf11b415 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -318,6 +318,7 @@ def add_single_jumpstart_tag( tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags) + or tag_key_in_array(enums.JumpStartTag.MODEL_CONFIG_NAME, curr_tags) ) if is_uri else False @@ -353,6 +354,7 @@ def add_jumpstart_model_id_version_tags( model_id: str, model_version: str, model_type: Optional[enums.JumpStartModelType] = None, + config_name: Optional[str] = None, ) -> List[TagsDict]: """Add custom model ID and version tags to JumpStart related resources.""" if model_id is None or model_version is None: @@ -376,6 +378,13 @@ def add_jumpstart_model_id_version_tags( tags, is_uri=False, ) + if config_name: + tags = add_single_jumpstart_tag( + config_name, + enums.JumpStartTag.MODEL_CONFIG_NAME, + tags, + is_uri=False, + ) return tags @@ -800,52 +809,72 @@ def validate_model_id_and_get_type( return None +def _extract_value_from_list_of_tags( + tag_keys: List[str], + list_tags_result: List[str], + resource_name: str, + resource_arn: str, +): + """Extracts value from list of tags with check of duplicate tags. + + Returns None if no value is found. + """ + resolved_value = None + for tag_key in tag_keys: + try: + value_from_tag = get_tag_value(tag_key, list_tags_result) + except KeyError: + continue + if value_from_tag is not None: + if resolved_value is not None and value_from_tag != resolved_value: + constants.JUMPSTART_LOGGER.warning( + "Found multiple %s tags on the following resource: %s", + resource_name, + resource_arn, + ) + resolved_value = None + break + resolved_value = value_from_tag + return resolved_value + + def get_jumpstart_model_id_version_from_resource_arn( resource_arn: str, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[Optional[str], Optional[str]]: - """Returns the JumpStart model ID and version if in resource tags. +) -> Tuple[Optional[str], Optional[str], Optional[str]]: + """Returns the JumpStart model ID, version and config name if in resource tags. - Returns 'None' if model ID or version cannot be inferred from tags. + Returns 'None' if model ID or version or config name cannot be inferred from tags. """ list_tags_result = sagemaker_session.list_tags(resource_arn) - model_id: Optional[str] = None - model_version: Optional[str] = None - model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS] model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS] + model_config_name_keys = [enums.JumpStartTag.MODEL_CONFIG_NAME] - for model_id_key in model_id_keys: - try: - model_id_from_tag = get_tag_value(model_id_key, list_tags_result) - except KeyError: - continue - if model_id_from_tag is not None: - if model_id is not None and model_id_from_tag != model_id: - constants.JUMPSTART_LOGGER.warning( - "Found multiple model ID tags on the following resource: %s", resource_arn - ) - model_id = None - break - model_id = model_id_from_tag + model_id: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=model_id_keys, + list_tags_result=list_tags_result, + resource_name="model ID", + resource_arn=resource_arn, + ) - for model_version_key in model_version_keys: - try: - model_version_from_tag = get_tag_value(model_version_key, list_tags_result) - except KeyError: - continue - if model_version_from_tag is not None: - if model_version is not None and model_version_from_tag != model_version: - constants.JUMPSTART_LOGGER.warning( - "Found multiple model version tags on the following resource: %s", resource_arn - ) - model_version = None - break - model_version = model_version_from_tag + model_version: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=model_version_keys, + list_tags_result=list_tags_result, + resource_name="model version", + resource_arn=resource_arn, + ) + + config_name: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=model_config_name_keys, + list_tags_result=list_tags_result, + resource_name="model config name", + resource_arn=resource_arn, + ) - return model_id, model_version + return model_id, model_version, config_name def get_region_fallback( diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 6f846bba65..14e2ae278b 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -18,7 +18,7 @@ from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.factory.model import get_default_predictor -from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint +from sagemaker.jumpstart.session_utils import get_model_info_from_endpoint from sagemaker.session import Session @@ -78,9 +78,8 @@ def retrieve_default( inferred_model_id, inferred_model_version, inferred_inference_component_name, - ) = get_model_id_version_from_endpoint( - endpoint_name, inference_component_name, sagemaker_session - ) + inferred_config_name, + ) = get_model_info_from_endpoint(endpoint_name, inference_component_name, sagemaker_session) if not inferred_model_id: raise ValueError( @@ -92,8 +91,10 @@ def retrieve_default( model_id = inferred_model_id model_version = model_version or inferred_model_version or "*" inference_component_name = inference_component_name or inferred_inference_component_name + config_name = inferred_config_name or None else: model_version = model_version or "*" + config_name = None predictor = Predictor( endpoint_name=endpoint_name, @@ -110,4 +111,5 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index ae79bb8b55..1ed860fc05 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1016,7 +1016,7 @@ def test_jumpstart_estimator_attach_eula_model( ) @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") - @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") + @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1024,15 +1024,16 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, - get_model_id_version_from_training_job: mock.Mock, + get_model_info_from_training_job: mock.Mock, mock_attach: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - get_model_id_version_from_training_job.return_value = ( + get_model_info_from_training_job.return_value = ( "js-trainable-model-prepacked", "1.0.0", + None, ) mock_get_model_specs.side_effect = get_special_model_spec @@ -1043,7 +1044,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( training_job_name="some-training-job-name", sagemaker_session=mock_session ) - get_model_id_version_from_training_job.assert_called_once_with( + get_model_info_from_training_job.assert_called_once_with( training_job_name="some-training-job-name", sagemaker_session=mock_session, ) @@ -1059,7 +1060,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( ) @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") - @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") + @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1067,13 +1068,13 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, - get_model_id_version_from_training_job: mock.Mock, + get_model_info_from_training_job: mock.Mock, mock_attach: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - get_model_id_version_from_training_job.side_effect = ValueError() + get_model_info_from_training_job.side_effect = ValueError() mock_get_model_specs.side_effect = get_special_model_spec @@ -1084,7 +1085,7 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case( training_job_name="some-training-job-name", sagemaker_session=mock_session ) - get_model_id_version_from_training_job.assert_called_once_with( + get_model_info_from_training_job.assert_called_once_with( training_job_name="some-training-job-name", sagemaker_session=mock_session, ) @@ -1212,6 +1213,7 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=estimator.sagemaker_session, + config_name=None, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @@ -1894,6 +1896,7 @@ def test_estimator_initialization_with_config_name( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-training"}, ], enable_network_isolation=False, ) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 2df904dce2..476002457b 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1559,6 +1559,7 @@ def test_model_initialization_with_config_name( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, @@ -1618,6 +1619,7 @@ def test_model_set_deployment_config( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, @@ -1659,6 +1661,7 @@ def test_model_unset_deployment_config( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index abce2bf687..1cc8f292f0 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -18,7 +18,7 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec, get_spec_from_base_spec -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support( @@ -52,7 +52,7 @@ def test_jumpstart_predictor_support( assert js_predictor.accept == MIMEType.JSON -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_proprietary_predictor_support( @@ -91,7 +91,7 @@ def test_proprietary_predictor_support( @patch("sagemaker.predictor.Predictor") @patch("sagemaker.predictor.get_default_predictor") -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( @@ -109,6 +109,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( "predictor-specs-model", "1.2.3", None, + None, ) mock_session = Mock() @@ -128,11 +129,12 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( tolerate_vulnerable_model=False, sagemaker_session=mock_session, model_type=JumpStartModelType.OPEN_WEIGHTS, + config_name=None, ) @patch("sagemaker.predictor.get_default_predictor") -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( @@ -159,10 +161,8 @@ def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( patched_get_default_predictor.assert_not_called() -@mock.patch( - "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} -) -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}) +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") @patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @@ -172,7 +172,7 @@ def test_jumpstart_serializable_payload_with_predictor( patched_verify_model_region_and_return_specs, patched_validate_model_id_and_get_type, patched_get_object_cached, - patched_get_model_id_version_from_endpoint, + patched_get_model_info_from_endpoint, patched_get_jumpstart_configs, ): @@ -183,7 +183,7 @@ def test_jumpstart_serializable_payload_with_predictor( patched_get_model_specs.side_effect = get_special_model_spec model_id, model_version = "default_payloads", "*" - patched_get_model_id_version_from_endpoint.return_value = model_id, model_version, None + patched_get_model_info_from_endpoint.return_value = model_id, model_version, None js_predictor = predictor.retrieve_default( endpoint_name="blah", model_id=model_id, model_version=model_version diff --git a/tests/unit/sagemaker/jumpstart/test_session_utils.py b/tests/unit/sagemaker/jumpstart/test_session_utils.py index 76ad50f31c..9dc8acb32a 100644 --- a/tests/unit/sagemaker/jumpstart/test_session_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_session_utils.py @@ -4,16 +4,16 @@ import pytest from sagemaker.jumpstart.session_utils import ( - _get_model_id_version_from_inference_component_endpoint_with_inference_component_name, - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name, - _get_model_id_version_from_model_based_endpoint, - get_model_id_version_from_endpoint, - get_model_id_version_from_training_job, + _get_model_info_from_inference_component_endpoint_with_inference_component_name, + _get_model_info_from_inference_component_endpoint_without_inference_component_name, + _get_model_info_from_model_based_endpoint, + get_model_info_from_endpoint, + get_model_info_from_training_job, ) @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_training_job_happy_case( +def test_get_model_info_from_training_job_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -23,11 +23,35 @@ def test_get_model_id_version_from_training_job_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( "model_id", "model_version", + None, + ) + + retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", None) + + mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + "arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session + ) + + +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +def test_get_model_info_from_training_job_config_name( + mock_get_jumpstart_model_id_version_from_resource_arn, +): + mock_sm_session = Mock() + mock_sm_session.boto_region_name = "us-west-2" + mock_sm_session.account_id = Mock(return_value="123456789012") + + mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + "model_id", + "model_version", + "config_name", ) - retval = get_model_id_version_from_training_job("bLaH", sagemaker_session=mock_sm_session) + retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", "config_name") mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session @@ -35,7 +59,7 @@ def test_get_model_id_version_from_training_job_happy_case( @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_training_job_no_model_id_inferred( +def test_get_model_info_from_training_job_no_model_id_inferred( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -48,11 +72,11 @@ def test_get_model_id_version_from_training_job_no_model_id_inferred( ) with pytest.raises(ValueError): - get_model_id_version_from_training_job("blah", sagemaker_session=mock_sm_session) + get_model_info_from_training_job("blah", sagemaker_session=mock_sm_session) @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_happy_case( +def test_get_model_info_from_model_based_endpoint_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -62,13 +86,14 @@ def test_get_model_id_version_from_model_based_endpoint_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( "model_id", "model_version", + None, ) - retval = _get_model_id_version_from_model_based_endpoint( + retval = _get_model_info_from_model_based_endpoint( "bLaH", inference_component_name=None, sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None) mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:endpoint/blah", mock_sm_session @@ -76,7 +101,7 @@ def test_get_model_id_version_from_model_based_endpoint_happy_case( @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_inference_component_supplied( +def test_get_model_info_from_model_based_endpoint_inference_component_supplied( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -86,16 +111,17 @@ def test_get_model_id_version_from_model_based_endpoint_inference_component_supp mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( "model_id", "model_version", + None, ) with pytest.raises(ValueError): - _get_model_id_version_from_model_based_endpoint( + _get_model_info_from_model_based_endpoint( "blah", inference_component_name="some-name", sagemaker_session=mock_sm_session ) @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_no_model_id_inferred( +def test_get_model_info_from_model_based_endpoint_no_model_id_inferred( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -108,13 +134,13 @@ def test_get_model_id_version_from_model_based_endpoint_no_model_id_inferred( ) with pytest.raises(ValueError): - _get_model_id_version_from_model_based_endpoint( + _get_model_info_from_model_based_endpoint( "blah", inference_component_name="some-name", sagemaker_session=mock_sm_session ) @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_inference_component_endpoint_with_inference_component_name_happy_case( +def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -124,13 +150,14 @@ def test_get_model_id_version_from_inference_component_endpoint_with_inference_c mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( "model_id", "model_version", + None, ) - retval = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + retval = _get_model_info_from_inference_component_endpoint_with_inference_component_name( "bLaH", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None) mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:inference-component/bLaH", mock_sm_session @@ -138,7 +165,7 @@ def test_get_model_id_version_from_inference_component_endpoint_with_inference_c @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( +def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -148,23 +175,24 @@ def test_get_model_id_version_from_inference_component_endpoint_with_inference_c mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( None, None, + None, ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + _get_model_info_from_inference_component_endpoint_with_inference_component_name( "blah", sagemaker_session=mock_sm_session ) @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_id_version_from_inference_component_endpoint_without_inference_component_name_happy_case( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_inference_component_endpoint_without_inference_component_name_happy_case( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -172,10 +200,8 @@ def test_get_model_id_version_from_inference_component_endpoint_without_inferenc return_value=["icname"] ) - retval = ( - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( - "blahblah", mock_sm_session - ) + retval = _get_model_info_from_inference_component_endpoint_without_inference_component_name( + "blahblah", mock_sm_session ) assert retval == ("model_id", "model_version", "icname") @@ -185,14 +211,14 @@ def test_get_model_id_version_from_inference_component_endpoint_without_inferenc @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -200,7 +226,7 @@ def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_ return_value=[] ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( + _get_model_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) @@ -210,14 +236,14 @@ def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_ @patch( - "sagemaker.jumpstart.session_utils._get_model_id" - "_version_from_inference_component_endpoint_with_inference_component_name" + "sagemaker.jumpstart.session_utils._get_model" + "_info_from_inference_component_endpoint_with_inference_component_name" ) def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_multiple_ics_for_endpoint( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -227,7 +253,7 @@ def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_ ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( + _get_model_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) @@ -236,67 +262,92 @@ def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_ ) -@patch("sagemaker.jumpstart.session_utils._get_model_id_version_from_model_based_endpoint") -def test_get_model_id_version_from_endpoint_non_inference_component_endpoint( - mock_get_model_id_version_from_model_based_endpoint, +@patch("sagemaker.jumpstart.session_utils._get_model_info_from_model_based_endpoint") +def test_get_model_info_from_endpoint_non_inference_component_endpoint( + mock_get_model_info_from_model_based_endpoint, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = False - mock_get_model_id_version_from_model_based_endpoint.return_value = ( + mock_get_model_info_from_model_based_endpoint.return_value = ( "model_id", "model_version", + None, ) - retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", None) - mock_get_model_id_version_from_model_based_endpoint.assert_called_once_with( + assert retval == ("model_id", "model_version", None, None) + mock_get_model_info_from_model_based_endpoint.assert_called_once_with( "blah", None, mock_sm_session ) mock_sm_session.is_inference_component_based_endpoint.assert_called_once_with("blah") @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_id_version_from_endpoint_inference_component_endpoint_with_inference_component_name( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_endpoint_inference_component_endpoint_with_inference_component_name( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", + None, ) - retval = get_model_id_version_from_endpoint( + retval = get_model_info_from_endpoint( "blah", inference_component_name="icname", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version", "icname") - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( + assert retval == ("model_id", "model_version", "icname", None) + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( "icname", mock_sm_session ) mock_sm_session.is_inference_component_based_endpoint.assert_not_called() @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_component_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" + "endpoint_without_inference_component_name" +) +def test_get_model_info_from_endpoint_inference_component_endpoint_without_inference_component_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, +): + mock_sm_session = Mock() + mock_sm_session.is_inference_component_based_endpoint.return_value = True + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( + "model_id", + "model_version", + None, + "inferred-icname", + ) + + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", "inferred-icname", None) + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + + +@patch( + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" "endpoint_without_inference_component_name" ) -def test_get_model_id_version_from_endpoint_inference_component_endpoint_without_inference_component_name( - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name, +def test_get_model_info_from_endpoint_inference_component_endpoint_with_config_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( "model_id", "model_version", + "config_name", "inferred-icname", ) - retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", "inferred-icname") - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + assert retval == ("model_id", "model_version", "inferred-icname", "config_name") + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 85911a2854..83724e5e8a 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1323,7 +1323,7 @@ def test_no_model_id_no_version_found(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - (None, None), + (None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1340,7 +1340,7 @@ def test_model_id_no_version_found(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - ("model_id", None), + ("model_id", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1357,7 +1357,38 @@ def test_no_model_id_version_found(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - (None, "model_version"), + (None, "model_version", None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_no_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}] + + self.assertEquals( + utils.get_jumpstart_model_id_version_from_resource_arn( + "some-arn", mock_sagemaker_session + ), + (None, None, None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name"}, + ] + + self.assertEquals( + utils.get_jumpstart_model_id_version_from_resource_arn( + "some-arn", mock_sagemaker_session + ), + (None, None, "config_name"), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1375,7 +1406,7 @@ def test_model_id_version_found(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - ("model_id", "model_version"), + ("model_id", "model_version", None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1395,7 +1426,7 @@ def test_multiple_model_id_versions_found(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - (None, None), + (None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1415,7 +1446,7 @@ def test_multiple_model_id_versions_found_aliases_consistent(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - ("model_id_1", "model_version_1"), + ("model_id_1", "model_version_1", None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1435,7 +1466,27 @@ def test_multiple_model_id_versions_found_aliases_inconsistent(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - (None, None), + (None, None, None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_multiple_config_names_found_aliases_inconsistent(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.MODEL_ID, "Value": "model_id_1"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version_1"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name_1"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name_2"}, + ] + + self.assertEquals( + utils.get_jumpstart_model_id_version_from_resource_arn( + "some-arn", mock_sagemaker_session + ), + ("model_id_1", "model_version_1", None), ) mock_list_tags.assert_called_once_with("some-arn")