From 49c828ac5777b00049d7ddb28ed3251a4c50dd78 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Tue, 23 Apr 2024 21:29:07 +0000 Subject: [PATCH 01/10] tag config name --- src/sagemaker/jumpstart/enums.py | 1 + src/sagemaker/jumpstart/estimator.py | 9 +-- src/sagemaker/jumpstart/factory/estimator.py | 2 +- src/sagemaker/jumpstart/factory/model.py | 2 +- src/sagemaker/jumpstart/session_utils.py | 32 ++++----- src/sagemaker/jumpstart/types.py | 25 +++---- src/sagemaker/jumpstart/utils.py | 34 ++++++++-- src/sagemaker/predictor.py | 4 ++ .../jumpstart/estimator/test_estimator.py | 3 + .../sagemaker/jumpstart/model/test_model.py | 3 + .../sagemaker/jumpstart/test_predictor.py | 2 + .../sagemaker/jumpstart/test_session_utils.py | 65 +++++++++++++++++-- tests/unit/sagemaker/jumpstart/test_utils.py | 65 +++++++++++++++++-- 13 files changed, 198 insertions(+), 49 deletions(-) diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index ca49fd41a3..0c192084ec 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -92,6 +92,7 @@ class JumpStartTag(str, Enum): MODEL_ID = "sagemaker-sdk:jumpstart-model-id" MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" + MODEL_CONFIG_NAME = "sagemaker-sdk:jumpstart-model-config-name" class SerializerType(str, Enum): diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index cf9b720607..8ae9a11172 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -730,13 +730,13 @@ def attach( ValueError: if the model ID or version cannot be inferred from the training job. """ - + config_name = None if model_id is None: - model_id, model_version = get_model_id_version_from_training_job( + model_id, model_version, config_name = get_model_id_version_from_training_job( training_job_name=training_job_name, sagemaker_session=sagemaker_session ) - + model_version = model_version or "*" additional_kwargs = {"model_id": model_id, "model_version": model_version} @@ -749,6 +749,7 @@ def attach( tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated tolerate_vulnerable_model=True, # model is already trained, so tolerate if vulnerable sagemaker_session=sagemaker_session, + config_name=config_name, ) # eula was already accepted if the model was successfully trained @@ -1102,7 +1103,7 @@ def deploy( tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, - # config_name=self.config_name, + config_name=self.config_name, ) # If a predictor class was passed, do not mutate predictor diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 926f313b68..6f277e33d1 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -478,7 +478,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima if kwargs.sagemaker_session.settings.include_jumpstart_tags: kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version + kwargs.tags, kwargs.model_id, full_model_version, config_name=kwargs.config_name, ) return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 25a1d63215..b4f6d70583 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -496,7 +496,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: if kwargs.sagemaker_session.settings.include_jumpstart_tags: kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type + kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type, kwargs.config_name ) return kwargs diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index e511a052d1..697fe63183 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -26,8 +26,8 @@ def get_model_id_version_from_endpoint( endpoint_name: str, inference_component_name: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[str, str, Optional[str]]: - """Given an endpoint and optionally inference component names, return the model ID and version. +) -> Tuple[str, str, Optional[str], Optional[str]]: + """Given an endpoint and optionally inference component names, return the model ID, version and config name. Infers the model ID and version based on the resource tags. Returns a tuple of the model ID and version. A third string element is included in the tuple for any inferred inference @@ -46,6 +46,7 @@ def get_model_id_version_from_endpoint( ( model_id, model_version, + config_name, ) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 inference_component_name, sagemaker_session ) @@ -55,21 +56,22 @@ def get_model_id_version_from_endpoint( model_id, model_version, inference_component_name, + config_name, ) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 endpoint_name, sagemaker_session ) else: - model_id, model_version = _get_model_id_version_from_model_based_endpoint( + model_id, model_version, config_name = _get_model_id_version_from_model_based_endpoint( endpoint_name, inference_component_name, sagemaker_session ) - return model_id, model_version, inference_component_name + return model_id, model_version, inference_component_name, config_name def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( endpoint_name: str, sagemaker_session: Session ) -> Tuple[str, str, str]: - """Given an endpoint name, derives the model ID, version, and inferred inference component name. + """Given an endpoint name, derives the model ID, version, config name and inferred inference component name. This function assumes the endpoint corresponds to an inference-component-based endpoint. An endpoint is inference-component-based if and only if the associated endpoint config @@ -123,7 +125,7 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo f"inference-component/{inference_component_name}" ) - model_id, model_version = get_jumpstart_model_id_version_from_resource_arn( + model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn( inference_component_arn, sagemaker_session ) @@ -134,15 +136,15 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo "when retrieving default predictor for this inference component." ) - return model_id, model_version + return model_id, model_version, config_name def _get_model_id_version_from_model_based_endpoint( endpoint_name: str, inference_component_name: Optional[str], sagemaker_session: Session, -) -> Tuple[str, str]: - """Returns the model ID and version inferred from a model-based endpoint. +) -> Tuple[str, str, Optional[str]]: + """Returns the model ID, version and config name inferred from a model-based endpoint. Raises: ValueError: If an inference component name is supplied, or if the endpoint does @@ -161,7 +163,7 @@ def _get_model_id_version_from_model_based_endpoint( endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}" - model_id, model_version = get_jumpstart_model_id_version_from_resource_arn( + model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn( endpoint_arn, sagemaker_session ) @@ -172,14 +174,14 @@ def _get_model_id_version_from_model_based_endpoint( "predictor for this endpoint." ) - return model_id, model_version + return model_id, model_version, config_name def get_model_id_version_from_training_job( training_job_name: str, sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[str, str]: - """Returns the model ID and version inferred from a training job. +) -> Tuple[str, str, Optional[str]]: + """Returns the model ID and version and config name inferred from a training job. Raises: ValueError: If the training job does not have tags from which the model ID @@ -194,7 +196,7 @@ def get_model_id_version_from_training_job( f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}" ) - model_id, inferred_model_version = get_jumpstart_model_id_version_from_resource_arn( + model_id, inferred_model_version, config_name = get_jumpstart_model_id_version_from_resource_arn( training_job_arn, sagemaker_session ) @@ -207,4 +209,4 @@ def get_model_id_version_from_training_job( "for this training job." ) - return model_id, model_version + return model_id, model_version, config_name diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 1de0f662da..68f70d4536 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1064,9 +1064,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Dictionary representation of the config component. """ for field in json_obj.keys(): - if field not in self.__slots__: - raise ValueError(f"Invalid component field: {field}") - setattr(self, field, json_obj[field]) + if field in self.__slots__: + setattr(self, field, json_obj[field]) class JumpStartMetadataConfig(JumpStartDataHolderType): @@ -1163,7 +1162,9 @@ def get_top_config_from_ranking( instance_type: Optional[str] = None, ) -> Optional[JumpStartMetadataConfig]: """Gets the best the config based on config ranking. - + + Fallback to use the ordering in config names if + ranking is not available. Args: ranking_name (str): The ranking name that config priority is based on. @@ -1171,13 +1172,9 @@ def get_top_config_from_ranking( The instance type which the config selection is based on. Raises: - ValueError: If the config exists but missing config ranking. NotImplementedError: If the scope is unrecognized. """ - if self.configs and ( - not self.config_rankings or not self.config_rankings.get(ranking_name) - ): - raise ValueError(f"Config exists but missing config ranking {ranking_name}.") + if self.scope == JumpStartScriptScope.INFERENCE: instance_type_attribute = "supported_inference_instance_types" @@ -1186,8 +1183,14 @@ def get_top_config_from_ranking( else: raise NotImplementedError(f"Unknown script scope {self.scope}") - rankings = self.config_rankings.get(ranking_name) - for config_name in rankings.rankings: + if self.configs and ( + not self.config_rankings or not self.config_rankings.get(ranking_name) + ): + ranked_config_names = list(self.configs.keys()) + else: + rankings = self.config_rankings.get(ranking_name) + ranked_config_names = rankings.rankings + for config_name in ranked_config_names: resolved_config = self.configs[config_name].resolved_config if instance_type and instance_type not in getattr( resolved_config, instance_type_attribute diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 1459594faa..cdc1c41630 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -162,6 +162,7 @@ def get_jumpstart_content_bucket( for info_log in info_logs: constants.JUMPSTART_LOGGER.info(info_log) return bucket_to_return + # return "jumpstart-cache-alpha-us-west-2" def get_formatted_manifest( @@ -318,6 +319,7 @@ def add_single_jumpstart_tag( tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags) + or tag_key_in_array(enums.JumpStartTag.MODEL_CONFIG_NAME, curr_tags) ) if is_uri else False @@ -353,6 +355,7 @@ def add_jumpstart_model_id_version_tags( model_id: str, model_version: str, model_type: Optional[enums.JumpStartModelType] = None, + config_name: Optional[str] = None, ) -> List[TagsDict]: """Add custom model ID and version tags to JumpStart related resources.""" if model_id is None or model_version is None: @@ -376,6 +379,13 @@ def add_jumpstart_model_id_version_tags( tags, is_uri=False, ) + if config_name: + tags = add_single_jumpstart_tag( + config_name, + enums.JumpStartTag.MODEL_CONFIG_NAME, + tags, + is_uri=False, + ) return tags @@ -803,19 +813,21 @@ def validate_model_id_and_get_type( def get_jumpstart_model_id_version_from_resource_arn( resource_arn: str, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[Optional[str], Optional[str]]: - """Returns the JumpStart model ID and version if in resource tags. +) -> Tuple[Optional[str], Optional[str], Optional[str]]: + """Returns the JumpStart model ID, version and config name if in resource tags. - Returns 'None' if model ID or version cannot be inferred from tags. + Returns 'None' if model ID or version or config name cannot be inferred from tags. """ list_tags_result = sagemaker_session.list_tags(resource_arn) model_id: Optional[str] = None model_version: Optional[str] = None + config_name: Optional[str] = None model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS] model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS] + model_config_name_keys = [enums.JumpStartTag.MODEL_CONFIG_NAME] for model_id_key in model_id_keys: try: @@ -845,7 +857,21 @@ def get_jumpstart_model_id_version_from_resource_arn( break model_version = model_version_from_tag - return model_id, model_version + for config_name_key in model_config_name_keys: + try: + config_name_key_from_tag = get_tag_value(config_name_key, list_tags_result) + except KeyError: + continue + if config_name_key_from_tag is not None: + if config_name is not None and config_name_key != config_name: + constants.JUMPSTART_LOGGER.warning( + "Found multiple model config names tags on the following resource: %s", resource_arn + ) + config_name = None + break + config_name = config_name_key_from_tag + + return model_id, model_version, config_name def get_region_fallback( diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 6f846bba65..277d8de830 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -78,6 +78,7 @@ def retrieve_default( inferred_model_id, inferred_model_version, inferred_inference_component_name, + inferred_config_name, ) = get_model_id_version_from_endpoint( endpoint_name, inference_component_name, sagemaker_session ) @@ -92,8 +93,10 @@ def retrieve_default( model_id = inferred_model_id model_version = model_version or inferred_model_version or "*" inference_component_name = inference_component_name or inferred_inference_component_name + config_name = inferred_config_name or None else: model_version = model_version or "*" + config_name = None predictor = Predictor( endpoint_name=endpoint_name, @@ -110,4 +113,5 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index ae79bb8b55..34478bae85 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1033,6 +1033,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( get_model_id_version_from_training_job.return_value = ( "js-trainable-model-prepacked", "1.0.0", + None ) mock_get_model_specs.side_effect = get_special_model_spec @@ -1212,6 +1213,7 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=estimator.sagemaker_session, + config_name=None, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @@ -1894,6 +1896,7 @@ def test_estimator_initialization_with_config_name( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-training"}, ], enable_network_isolation=False, ) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index cb7b602fbf..3f7a17a748 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1449,6 +1449,7 @@ def test_model_initialization_with_config_name( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, @@ -1504,6 +1505,7 @@ def test_model_set_deployment_config( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, @@ -1541,6 +1543,7 @@ def test_model_unset_deployment_config( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 52f28f2da1..7cf049ead2 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -109,6 +109,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( "predictor-specs-model", "1.2.3", None, + None, ) mock_session = Mock() @@ -128,6 +129,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( tolerate_vulnerable_model=False, sagemaker_session=mock_session, model_type=JumpStartModelType.OPEN_WEIGHTS, + config_name=None, ) diff --git a/tests/unit/sagemaker/jumpstart/test_session_utils.py b/tests/unit/sagemaker/jumpstart/test_session_utils.py index 76ad50f31c..c37998cecd 100644 --- a/tests/unit/sagemaker/jumpstart/test_session_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_session_utils.py @@ -23,11 +23,35 @@ def test_get_model_id_version_from_training_job_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( "model_id", "model_version", + None ) retval = get_model_id_version_from_training_job("bLaH", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None) + + mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + "arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session + ) + + +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +def test_get_model_id_version_from_training_job_config_name( + mock_get_jumpstart_model_id_version_from_resource_arn, +): + mock_sm_session = Mock() + mock_sm_session.boto_region_name = "us-west-2" + mock_sm_session.account_id = Mock(return_value="123456789012") + + mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + "model_id", + "model_version", + "config_name" + ) + + retval = get_model_id_version_from_training_job("bLaH", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", "config_name") mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session @@ -62,13 +86,14 @@ def test_get_model_id_version_from_model_based_endpoint_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( "model_id", "model_version", + None, ) retval = _get_model_id_version_from_model_based_endpoint( "bLaH", inference_component_name=None, sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None) mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:endpoint/blah", mock_sm_session @@ -86,6 +111,7 @@ def test_get_model_id_version_from_model_based_endpoint_inference_component_supp mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( "model_id", "model_version", + None, ) with pytest.raises(ValueError): @@ -124,13 +150,14 @@ def test_get_model_id_version_from_inference_component_endpoint_with_inference_c mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( "model_id", "model_version", + None, ) retval = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( "bLaH", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None) mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:inference-component/bLaH", mock_sm_session @@ -148,6 +175,7 @@ def test_get_model_id_version_from_inference_component_endpoint_with_inference_c mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( None, None, + None, ) with pytest.raises(ValueError): @@ -245,11 +273,12 @@ def test_get_model_id_version_from_endpoint_non_inference_component_endpoint( mock_get_model_id_version_from_model_based_endpoint.return_value = ( "model_id", "model_version", + None, ) retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", None) + assert retval == ("model_id", "model_version", None, None) mock_get_model_id_version_from_model_based_endpoint.assert_called_once_with( "blah", None, mock_sm_session ) @@ -268,13 +297,14 @@ def test_get_model_id_version_from_endpoint_inference_component_endpoint_with_in mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", + None ) retval = get_model_id_version_from_endpoint( "blah", inference_component_name="icname", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version", "icname") + assert retval == ("model_id", "model_version", "icname", None) mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( "icname", mock_sm_session ) @@ -294,9 +324,32 @@ def test_get_model_id_version_from_endpoint_inference_component_endpoint_without "model_id", "model_version", "inferred-icname", + None, ) retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", "inferred-icname") + assert retval == ("model_id", "model_version", "inferred-icname", None) mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + + +@patch( + "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_component_" + "endpoint_without_inference_component_name" +) +def test_get_model_id_version_from_endpoint_inference_component_endpoint_with_config_name( + mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name, +): + mock_sm_session = Mock() + mock_sm_session.is_inference_component_based_endpoint.return_value = True + mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.return_value = ( + "model_id", + "model_version", + "inferred-icname", + "config_name", + ) + + retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", "inferred-icname", "config_name") + mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.assert_called_once() \ No newline at end of file diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index c1ea8abcb8..a0b5103e1a 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1322,7 +1322,7 @@ def test_no_model_id_no_version_found(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - (None, None), + (None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1339,7 +1339,7 @@ def test_model_id_no_version_found(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - ("model_id", None), + ("model_id", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1356,7 +1356,38 @@ def test_no_model_id_version_found(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - (None, "model_version"), + (None, "model_version", None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_no_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}] + + self.assertEquals( + utils.get_jumpstart_model_id_version_from_resource_arn( + "some-arn", mock_sagemaker_session + ), + (None, None, None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name"} + ] + + self.assertEquals( + utils.get_jumpstart_model_id_version_from_resource_arn( + "some-arn", mock_sagemaker_session + ), + (None, None, "config_name"), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1374,7 +1405,7 @@ def test_model_id_version_found(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - ("model_id", "model_version"), + ("model_id", "model_version", None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1394,7 +1425,7 @@ def test_multiple_model_id_versions_found(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - (None, None), + (None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1414,7 +1445,7 @@ def test_multiple_model_id_versions_found_aliases_consistent(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - ("model_id_1", "model_version_1"), + ("model_id_1", "model_version_1", None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1434,7 +1465,27 @@ def test_multiple_model_id_versions_found_aliases_inconsistent(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - (None, None), + (None, None, None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_multiple_config_names_found_aliases_inconsistent(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.MODEL_ID, "Value": "model_id_1"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version_1"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name_1"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name_2"} + ] + + self.assertEquals( + utils.get_jumpstart_model_id_version_from_resource_arn( + "some-arn", mock_sagemaker_session + ), + ("model_id_1", "model_version_1", None), ) mock_list_tags.assert_called_once_with("some-arn") From 342f2144fd87466fdcf25b9d94182c346f629507 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Tue, 23 Apr 2024 21:44:31 +0000 Subject: [PATCH 02/10] format --- src/sagemaker/jumpstart/estimator.py | 2 +- src/sagemaker/jumpstart/factory/estimator.py | 5 ++++- src/sagemaker/jumpstart/session_utils.py | 8 +++++--- src/sagemaker/jumpstart/types.py | 4 ++-- src/sagemaker/jumpstart/utils.py | 3 ++- .../sagemaker/jumpstart/estimator/test_estimator.py | 2 +- tests/unit/sagemaker/jumpstart/test_session_utils.py | 8 ++++---- tests/unit/sagemaker/jumpstart/test_utils.py | 10 +++++----- 8 files changed, 24 insertions(+), 18 deletions(-) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 8ae9a11172..463a68b3d3 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -736,7 +736,7 @@ def attach( model_id, model_version, config_name = get_model_id_version_from_training_job( training_job_name=training_job_name, sagemaker_session=sagemaker_session ) - + model_version = model_version or "*" additional_kwargs = {"model_id": model_id, "model_version": model_version} diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 6f277e33d1..2d5b29b52f 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -478,7 +478,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima if kwargs.sagemaker_session.settings.include_jumpstart_tags: kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version, config_name=kwargs.config_name, + kwargs.tags, + kwargs.model_id, + full_model_version, + config_name=kwargs.config_name, ) return kwargs diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index 697fe63183..5e305a4635 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -196,9 +196,11 @@ def get_model_id_version_from_training_job( f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}" ) - model_id, inferred_model_version, config_name = get_jumpstart_model_id_version_from_resource_arn( - training_job_arn, sagemaker_session - ) + ( + model_id, + inferred_model_version, + config_name, + ) = get_jumpstart_model_id_version_from_resource_arn(training_job_arn, sagemaker_session) model_version = inferred_model_version or None diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 68f70d4536..a7f25d20c9 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1162,7 +1162,7 @@ def get_top_config_from_ranking( instance_type: Optional[str] = None, ) -> Optional[JumpStartMetadataConfig]: """Gets the best the config based on config ranking. - + Fallback to use the ordering in config names if ranking is not available. Args: @@ -1174,7 +1174,7 @@ def get_top_config_from_ranking( Raises: NotImplementedError: If the scope is unrecognized. """ - + if self.scope == JumpStartScriptScope.INFERENCE: instance_type_attribute = "supported_inference_instance_types" diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index cdc1c41630..6772870782 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -865,7 +865,8 @@ def get_jumpstart_model_id_version_from_resource_arn( if config_name_key_from_tag is not None: if config_name is not None and config_name_key != config_name: constants.JUMPSTART_LOGGER.warning( - "Found multiple model config names tags on the following resource: %s", resource_arn + "Found multiple model config names tags on the following resource: %s", + resource_arn ) config_name = None break diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 34478bae85..4d90857913 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1033,7 +1033,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( get_model_id_version_from_training_job.return_value = ( "js-trainable-model-prepacked", "1.0.0", - None + None, ) mock_get_model_specs.side_effect = get_special_model_spec diff --git a/tests/unit/sagemaker/jumpstart/test_session_utils.py b/tests/unit/sagemaker/jumpstart/test_session_utils.py index c37998cecd..295e28570b 100644 --- a/tests/unit/sagemaker/jumpstart/test_session_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_session_utils.py @@ -23,7 +23,7 @@ def test_get_model_id_version_from_training_job_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( "model_id", "model_version", - None + None, ) retval = get_model_id_version_from_training_job("bLaH", sagemaker_session=mock_sm_session) @@ -46,7 +46,7 @@ def test_get_model_id_version_from_training_job_config_name( mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( "model_id", "model_version", - "config_name" + "config_name", ) retval = get_model_id_version_from_training_job("bLaH", sagemaker_session=mock_sm_session) @@ -297,7 +297,7 @@ def test_get_model_id_version_from_endpoint_inference_component_endpoint_with_in mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", - None + None, ) retval = get_model_id_version_from_endpoint( @@ -352,4 +352,4 @@ def test_get_model_id_version_from_endpoint_inference_component_endpoint_with_co retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) assert retval == ("model_id", "model_version", "inferred-icname", "config_name") - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.assert_called_once() \ No newline at end of file + mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.assert_called_once() diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index a0b5103e1a..a7293c497f 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1359,7 +1359,7 @@ def test_no_model_id_version_found(self): (None, "model_version", None), ) mock_list_tags.assert_called_once_with("some-arn") - + def test_no_config_name_found(self): mock_list_tags = Mock() mock_sagemaker_session = Mock() @@ -1373,14 +1373,14 @@ def test_no_config_name_found(self): (None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") - + def test_config_name_found(self): mock_list_tags = Mock() mock_sagemaker_session = Mock() mock_sagemaker_session.list_tags = mock_list_tags mock_list_tags.return_value = [ {"Key": "blah", "Value": "blah1"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name"} + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name"}, ] self.assertEquals( @@ -1468,7 +1468,7 @@ def test_multiple_model_id_versions_found_aliases_inconsistent(self): (None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") - + def test_multiple_config_names_found_aliases_inconsistent(self): mock_list_tags = Mock() mock_sagemaker_session = Mock() @@ -1478,7 +1478,7 @@ def test_multiple_config_names_found_aliases_inconsistent(self): {"Key": JumpStartTag.MODEL_ID, "Value": "model_id_1"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version_1"}, {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name_1"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name_2"} + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name_2"}, ] self.assertEquals( From 9c4b5c72a9e1a0153316cb203ef4b029d4b0e722 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Wed, 24 Apr 2024 15:27:34 +0000 Subject: [PATCH 03/10] resolving comments --- src/sagemaker/jumpstart/estimator.py | 4 +- src/sagemaker/jumpstart/session_utils.py | 22 ++-- src/sagemaker/jumpstart/types.py | 2 +- src/sagemaker/jumpstart/utils.py | 89 ++++++------- src/sagemaker/predictor.py | 4 +- .../jumpstart/estimator/test_estimator.py | 16 +-- .../sagemaker/jumpstart/test_predictor.py | 14 +- .../sagemaker/jumpstart/test_session_utils.py | 120 +++++++++--------- 8 files changed, 136 insertions(+), 135 deletions(-) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 463a68b3d3..06d1b6581c 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -33,7 +33,7 @@ from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs from sagemaker.jumpstart.factory.model import get_default_predictor -from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job +from sagemaker.jumpstart.session_utils import get_model_generic_info_from_training_job from sagemaker.jumpstart.types import JumpStartMetadataConfig from sagemaker.jumpstart.utils import ( get_jumpstart_configs, @@ -733,7 +733,7 @@ def attach( config_name = None if model_id is None: - model_id, model_version, config_name = get_model_id_version_from_training_job( + model_id, model_version, config_name = get_model_generic_info_from_training_job( training_job_name=training_job_name, sagemaker_session=sagemaker_session ) diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index 5e305a4635..c921ebbcc0 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -22,12 +22,12 @@ from sagemaker.utils import aws_partition -def get_model_id_version_from_endpoint( +def get_model_generic_info_from_endpoint( endpoint_name: str, inference_component_name: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Tuple[str, str, Optional[str], Optional[str]]: - """Given an endpoint and optionally inference component names, return the model ID, version and config name. + """Optionally inference component names, return the model ID, version and config name. Infers the model ID and version based on the resource tags. Returns a tuple of the model ID and version. A third string element is included in the tuple for any inferred inference @@ -47,7 +47,7 @@ def get_model_id_version_from_endpoint( model_id, model_version, config_name, - ) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 + ) = _get_model_generic_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 inference_component_name, sagemaker_session ) @@ -57,21 +57,21 @@ def get_model_id_version_from_endpoint( model_version, inference_component_name, config_name, - ) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 + ) = _get_model_generic_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 endpoint_name, sagemaker_session ) else: - model_id, model_version, config_name = _get_model_id_version_from_model_based_endpoint( + model_id, model_version, config_name = _get_model_generic_info_from_model_based_endpoint( endpoint_name, inference_component_name, sagemaker_session ) return model_id, model_version, inference_component_name, config_name -def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( +def _get_model_generic_info_from_inference_component_endpoint_without_inference_component_name( endpoint_name: str, sagemaker_session: Session ) -> Tuple[str, str, str]: - """Given an endpoint name, derives the model ID, version, config name and inferred inference component name. + """Derives the model ID, version, config name and inferred inference component name. This function assumes the endpoint corresponds to an inference-component-based endpoint. An endpoint is inference-component-based if and only if the associated endpoint config @@ -100,14 +100,14 @@ def _get_model_id_version_from_inference_component_endpoint_without_inference_co ) inference_component_name = inference_component_names[0] return ( - *_get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + *_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name( inference_component_name, sagemaker_session ), inference_component_name, ) -def _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( +def _get_model_generic_info_from_inference_component_endpoint_with_inference_component_name( inference_component_name: str, sagemaker_session: Session ): """Returns the model ID and version inferred from a SageMaker inference component. @@ -139,7 +139,7 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo return model_id, model_version, config_name -def _get_model_id_version_from_model_based_endpoint( +def _get_model_generic_info_from_model_based_endpoint( endpoint_name: str, inference_component_name: Optional[str], sagemaker_session: Session, @@ -177,7 +177,7 @@ def _get_model_id_version_from_model_based_endpoint( return model_id, model_version, config_name -def get_model_id_version_from_training_job( +def get_model_generic_info_from_training_job( training_job_name: str, sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Tuple[str, str, Optional[str]]: diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index a7f25d20c9..db279e7c9c 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1186,7 +1186,7 @@ def get_top_config_from_ranking( if self.configs and ( not self.config_rankings or not self.config_rankings.get(ranking_name) ): - ranked_config_names = list(self.configs.keys()) + ranked_config_names = sorted(list(self.configs.keys())) else: rankings = self.config_rankings.get(ranking_name) ranked_config_names = rankings.rankings diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 6772870782..9c9c48680c 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -810,6 +810,33 @@ def validate_model_id_and_get_type( return None +def _extract_value_from_list_of_tags( + tag_keys: List[str], + list_tags_result: List[str], + resource_name: str, + resource_arn: str, +): + """Extracts value from list of tags with check of duplicate tags. + + Returns None if no value is found. + """ + resolved_value = None + for tag_key in tag_keys: + try: + value_from_tag = get_tag_value(tag_key, list_tags_result) + except KeyError: + continue + if value_from_tag is not None: + if resolved_value is not None and value_from_tag != resolved_value: + constants.JUMPSTART_LOGGER.warning( + f"Found multiple {resource_name} tags on the following resource: {resource_arn}" + ) + resolved_value = None + break + resolved_value = value_from_tag + return resolved_value + + def get_jumpstart_model_id_version_from_resource_arn( resource_arn: str, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -821,56 +848,30 @@ def get_jumpstart_model_id_version_from_resource_arn( list_tags_result = sagemaker_session.list_tags(resource_arn) - model_id: Optional[str] = None - model_version: Optional[str] = None - config_name: Optional[str] = None - model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS] model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS] model_config_name_keys = [enums.JumpStartTag.MODEL_CONFIG_NAME] - for model_id_key in model_id_keys: - try: - model_id_from_tag = get_tag_value(model_id_key, list_tags_result) - except KeyError: - continue - if model_id_from_tag is not None: - if model_id is not None and model_id_from_tag != model_id: - constants.JUMPSTART_LOGGER.warning( - "Found multiple model ID tags on the following resource: %s", resource_arn - ) - model_id = None - break - model_id = model_id_from_tag + model_id: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=model_id_keys, + list_tags_result=list_tags_result, + resource_name="model ID", + resource_arn=resource_arn, + ) - for model_version_key in model_version_keys: - try: - model_version_from_tag = get_tag_value(model_version_key, list_tags_result) - except KeyError: - continue - if model_version_from_tag is not None: - if model_version is not None and model_version_from_tag != model_version: - constants.JUMPSTART_LOGGER.warning( - "Found multiple model version tags on the following resource: %s", resource_arn - ) - model_version = None - break - model_version = model_version_from_tag + model_version: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=model_version_keys, + list_tags_result=list_tags_result, + resource_name="model version", + resource_arn=resource_arn, + ) - for config_name_key in model_config_name_keys: - try: - config_name_key_from_tag = get_tag_value(config_name_key, list_tags_result) - except KeyError: - continue - if config_name_key_from_tag is not None: - if config_name is not None and config_name_key != config_name: - constants.JUMPSTART_LOGGER.warning( - "Found multiple model config names tags on the following resource: %s", - resource_arn - ) - config_name = None - break - config_name = config_name_key_from_tag + config_name: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=model_config_name_keys, + list_tags_result=list_tags_result, + resource_name="model config name", + resource_arn=resource_arn, + ) return model_id, model_version, config_name diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 277d8de830..72e51e583a 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -18,7 +18,7 @@ from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.factory.model import get_default_predictor -from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint +from sagemaker.jumpstart.session_utils import get_model_generic_info_from_endpoint from sagemaker.session import Session @@ -79,7 +79,7 @@ def retrieve_default( inferred_model_version, inferred_inference_component_name, inferred_config_name, - ) = get_model_id_version_from_endpoint( + ) = get_model_generic_info_from_endpoint( endpoint_name, inference_component_name, sagemaker_session ) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 4d90857913..313d312863 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1016,7 +1016,7 @@ def test_jumpstart_estimator_attach_eula_model( ) @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") - @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") + @mock.patch("sagemaker.jumpstart.estimator.get_model_generic_info_from_training_job") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1024,13 +1024,13 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, - get_model_id_version_from_training_job: mock.Mock, + get_model_generic_info_from_training_job: mock.Mock, mock_attach: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - get_model_id_version_from_training_job.return_value = ( + get_model_generic_info_from_training_job.return_value = ( "js-trainable-model-prepacked", "1.0.0", None, @@ -1044,7 +1044,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( training_job_name="some-training-job-name", sagemaker_session=mock_session ) - get_model_id_version_from_training_job.assert_called_once_with( + get_model_generic_info_from_training_job.assert_called_once_with( training_job_name="some-training-job-name", sagemaker_session=mock_session, ) @@ -1060,7 +1060,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( ) @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") - @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") + @mock.patch("sagemaker.jumpstart.estimator.get_model_generic_info_from_training_job") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1068,13 +1068,13 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, - get_model_id_version_from_training_job: mock.Mock, + get_model_generic_info_from_training_job: mock.Mock, mock_attach: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - get_model_id_version_from_training_job.side_effect = ValueError() + get_model_generic_info_from_training_job.side_effect = ValueError() mock_get_model_specs.side_effect = get_special_model_spec @@ -1085,7 +1085,7 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case( training_job_name="some-training-job-name", sagemaker_session=mock_session ) - get_model_id_version_from_training_job.assert_called_once_with( + get_model_generic_info_from_training_job.assert_called_once_with( training_job_name="some-training-job-name", sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 7cf049ead2..fe22fb6d3a 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -18,7 +18,7 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec, get_spec_from_base_spec -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_generic_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support( @@ -52,7 +52,7 @@ def test_jumpstart_predictor_support( assert js_predictor.accept == MIMEType.JSON -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_generic_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_proprietary_predictor_support( @@ -91,7 +91,7 @@ def test_proprietary_predictor_support( @patch("sagemaker.predictor.Predictor") @patch("sagemaker.predictor.get_default_predictor") -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_generic_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( @@ -134,7 +134,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( @patch("sagemaker.predictor.get_default_predictor") -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_generic_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( @@ -161,7 +161,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( patched_get_default_predictor.assert_not_called() -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_generic_info_from_endpoint") @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") @patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @@ -171,7 +171,7 @@ def test_jumpstart_serializable_payload_with_predictor( patched_verify_model_region_and_return_specs, patched_validate_model_id_and_get_type, patched_get_object_cached, - patched_get_model_id_version_from_endpoint, + patched_get_model_generic_info_from_endpoint, ): patched_get_object_cached.return_value = base64.b64decode("encodedimage") @@ -181,7 +181,7 @@ def test_jumpstart_serializable_payload_with_predictor( patched_get_model_specs.side_effect = get_special_model_spec model_id, model_version = "default_payloads", "*" - patched_get_model_id_version_from_endpoint.return_value = model_id, model_version, None + patched_get_model_generic_info_from_endpoint.return_value = model_id, model_version, None js_predictor = predictor.retrieve_default( endpoint_name="blah", model_id=model_id, model_version=model_version diff --git a/tests/unit/sagemaker/jumpstart/test_session_utils.py b/tests/unit/sagemaker/jumpstart/test_session_utils.py index 295e28570b..6f37c64321 100644 --- a/tests/unit/sagemaker/jumpstart/test_session_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_session_utils.py @@ -4,16 +4,16 @@ import pytest from sagemaker.jumpstart.session_utils import ( - _get_model_id_version_from_inference_component_endpoint_with_inference_component_name, - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name, - _get_model_id_version_from_model_based_endpoint, - get_model_id_version_from_endpoint, - get_model_id_version_from_training_job, + _get_model_generic_info_from_inference_component_endpoint_with_inference_component_name, + _get_model_generic_info_from_inference_component_endpoint_without_inference_component_name, + _get_model_generic_info_from_model_based_endpoint, + get_model_generic_info_from_endpoint, + get_model_generic_info_from_training_job, ) @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_training_job_happy_case( +def test_get_model_generic_info_from_training_job_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -26,7 +26,7 @@ def test_get_model_id_version_from_training_job_happy_case( None, ) - retval = get_model_id_version_from_training_job("bLaH", sagemaker_session=mock_sm_session) + retval = get_model_generic_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) assert retval == ("model_id", "model_version", None) @@ -36,7 +36,7 @@ def test_get_model_id_version_from_training_job_happy_case( @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_training_job_config_name( +def test_get_model_generic_info_from_training_job_config_name( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -49,7 +49,7 @@ def test_get_model_id_version_from_training_job_config_name( "config_name", ) - retval = get_model_id_version_from_training_job("bLaH", sagemaker_session=mock_sm_session) + retval = get_model_generic_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) assert retval == ("model_id", "model_version", "config_name") @@ -59,7 +59,7 @@ def test_get_model_id_version_from_training_job_config_name( @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_training_job_no_model_id_inferred( +def test_get_model_generic_info_from_training_job_no_model_id_inferred( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -72,11 +72,11 @@ def test_get_model_id_version_from_training_job_no_model_id_inferred( ) with pytest.raises(ValueError): - get_model_id_version_from_training_job("blah", sagemaker_session=mock_sm_session) + get_model_generic_info_from_training_job("blah", sagemaker_session=mock_sm_session) @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_happy_case( +def test_get_model_generic_info_from_model_based_endpoint_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -89,7 +89,7 @@ def test_get_model_id_version_from_model_based_endpoint_happy_case( None, ) - retval = _get_model_id_version_from_model_based_endpoint( + retval = _get_model_generic_info_from_model_based_endpoint( "bLaH", inference_component_name=None, sagemaker_session=mock_sm_session ) @@ -101,7 +101,7 @@ def test_get_model_id_version_from_model_based_endpoint_happy_case( @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_inference_component_supplied( +def test_get_model_generic_info_from_model_based_endpoint_inference_component_supplied( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -115,13 +115,13 @@ def test_get_model_id_version_from_model_based_endpoint_inference_component_supp ) with pytest.raises(ValueError): - _get_model_id_version_from_model_based_endpoint( + _get_model_generic_info_from_model_based_endpoint( "blah", inference_component_name="some-name", sagemaker_session=mock_sm_session ) @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_no_model_id_inferred( +def test_get_model_generic_info_from_model_based_endpoint_no_model_id_inferred( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -134,13 +134,13 @@ def test_get_model_id_version_from_model_based_endpoint_no_model_id_inferred( ) with pytest.raises(ValueError): - _get_model_id_version_from_model_based_endpoint( + _get_model_generic_info_from_model_based_endpoint( "blah", inference_component_name="some-name", sagemaker_session=mock_sm_session ) @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_inference_component_endpoint_with_inference_component_name_happy_case( +def test_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -153,7 +153,7 @@ def test_get_model_id_version_from_inference_component_endpoint_with_inference_c None, ) - retval = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + retval = _get_model_generic_info_from_inference_component_endpoint_with_inference_component_name( "bLaH", sagemaker_session=mock_sm_session ) @@ -165,7 +165,7 @@ def test_get_model_id_version_from_inference_component_endpoint_with_inference_c @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( +def test_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -179,20 +179,20 @@ def test_get_model_id_version_from_inference_component_endpoint_with_inference_c ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + _get_model_generic_info_from_inference_component_endpoint_with_inference_component_name( "blah", sagemaker_session=mock_sm_session ) @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_generic_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_id_version_from_inference_component_endpoint_without_inference_component_name_happy_case( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_generic_info_from_inference_component_endpoint_without_inference_component_name_happy_case( + mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -201,7 +201,7 @@ def test_get_model_id_version_from_inference_component_endpoint_without_inferenc ) retval = ( - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( + _get_model_generic_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) ) @@ -213,14 +213,14 @@ def test_get_model_id_version_from_inference_component_endpoint_without_inferenc @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_generic_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_generic_info_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint( + mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -228,7 +228,7 @@ def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_ return_value=[] ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( + _get_model_generic_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) @@ -238,14 +238,14 @@ def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_ @patch( - "sagemaker.jumpstart.session_utils._get_model_id" - "_version_from_inference_component_endpoint_with_inference_component_name" + "sagemaker.jumpstart.session_utils._get_model" + "_generic_info_from_inference_component_endpoint_with_inference_component_name" ) def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_multiple_ics_for_endpoint( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, + mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -255,7 +255,7 @@ def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_ ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( + _get_model_generic_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) @@ -264,92 +264,92 @@ def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_ ) -@patch("sagemaker.jumpstart.session_utils._get_model_id_version_from_model_based_endpoint") -def test_get_model_id_version_from_endpoint_non_inference_component_endpoint( - mock_get_model_id_version_from_model_based_endpoint, +@patch("sagemaker.jumpstart.session_utils._get_model_generic_info_from_model_based_endpoint") +def test_get_model_generic_info_from_endpoint_non_inference_component_endpoint( + mock_get_model_generic_info_from_model_based_endpoint, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = False - mock_get_model_id_version_from_model_based_endpoint.return_value = ( + mock_get_model_generic_info_from_model_based_endpoint.return_value = ( "model_id", "model_version", None, ) - retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_generic_info_from_endpoint("blah", sagemaker_session=mock_sm_session) assert retval == ("model_id", "model_version", None, None) - mock_get_model_id_version_from_model_based_endpoint.assert_called_once_with( + mock_get_model_generic_info_from_model_based_endpoint.assert_called_once_with( "blah", None, mock_sm_session ) mock_sm_session.is_inference_component_based_endpoint.assert_called_once_with("blah") @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_generic_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_id_version_from_endpoint_inference_component_endpoint_with_inference_component_name( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_generic_info_from_endpoint_inference_component_endpoint_with_inference_component_name( + mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", None, ) - retval = get_model_id_version_from_endpoint( + retval = get_model_generic_info_from_endpoint( "blah", inference_component_name="icname", sagemaker_session=mock_sm_session ) assert retval == ("model_id", "model_version", "icname", None) - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( + mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( "icname", mock_sm_session ) mock_sm_session.is_inference_component_based_endpoint.assert_not_called() @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_component_" + "sagemaker.jumpstart.session_utils._get_model_generic_info_from_inference_component_" "endpoint_without_inference_component_name" ) -def test_get_model_id_version_from_endpoint_inference_component_endpoint_without_inference_component_name( - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name, +def test_get_model_generic_info_from_endpoint_inference_component_endpoint_without_inference_component_name( + mock_get_model_generic_info_from_inference_component_endpoint_without_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.return_value = ( + mock_get_model_generic_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( "model_id", "model_version", "inferred-icname", None, ) - retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_generic_info_from_endpoint("blah", sagemaker_session=mock_sm_session) assert retval == ("model_id", "model_version", "inferred-icname", None) - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + mock_get_model_generic_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_component_" + "sagemaker.jumpstart.session_utils._get_model_generic_info_from_inference_component_" "endpoint_without_inference_component_name" ) -def test_get_model_id_version_from_endpoint_inference_component_endpoint_with_config_name( - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name, +def test_get_model_generic_info_from_endpoint_inference_component_endpoint_with_config_name( + mock_get_model_generic_info_from_inference_component_endpoint_without_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.return_value = ( + mock_get_model_generic_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( "model_id", "model_version", "inferred-icname", "config_name", ) - retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_generic_info_from_endpoint("blah", sagemaker_session=mock_sm_session) assert retval == ("model_id", "model_version", "inferred-icname", "config_name") - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + mock_get_model_generic_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() From 7d6b8fadb1ec6316c16a712e93e9b681fa2d4b84 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Wed, 24 Apr 2024 16:23:18 +0000 Subject: [PATCH 04/10] format --- src/sagemaker/jumpstart/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 9c9c48680c..839191470c 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -829,7 +829,9 @@ def _extract_value_from_list_of_tags( if value_from_tag is not None: if resolved_value is not None and value_from_tag != resolved_value: constants.JUMPSTART_LOGGER.warning( - f"Found multiple {resource_name} tags on the following resource: {resource_arn}" + "Found multiple %s tags on the following resource: %s", + resource_name, + resource_arn, ) resolved_value = None break From c9201a09f8355c27e19ba67a3be143d7c36511ab Mon Sep 17 00:00:00 2001 From: Haotian An Date: Wed, 24 Apr 2024 19:07:04 +0000 Subject: [PATCH 05/10] format --- tests/unit/sagemaker/jumpstart/test_predictor.py | 4 +--- tests/unit/sagemaker/jumpstart/test_session_utils.py | 6 ++++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 8481d28ba8..9dc7b0fb5d 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -160,9 +160,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( ) patched_get_default_predictor.assert_not_called() -@patch( - "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} -) +@patch("sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}) @patch("sagemaker.predictor.get_model_generic_info_from_endpoint") @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") @patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") diff --git a/tests/unit/sagemaker/jumpstart/test_session_utils.py b/tests/unit/sagemaker/jumpstart/test_session_utils.py index 6f37c64321..a7edcd9f10 100644 --- a/tests/unit/sagemaker/jumpstart/test_session_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_session_utils.py @@ -153,8 +153,10 @@ def test_get_model_generic_info_from_inference_component_endpoint_with_inference None, ) - retval = _get_model_generic_info_from_inference_component_endpoint_with_inference_component_name( - "bLaH", sagemaker_session=mock_sm_session + retval = ( + _get_model_generic_info_from_inference_component_endpoint_with_inference_component_name( + "bLaH", sagemaker_session=mock_sm_session + ) ) assert retval == ("model_id", "model_version", None) From e1a8edb6732e53a40f1a53677ee9e73a6da7e139 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Wed, 24 Apr 2024 19:46:02 +0000 Subject: [PATCH 06/10] update --- src/sagemaker/jumpstart/estimator.py | 4 +- src/sagemaker/jumpstart/session_utils.py | 18 +-- src/sagemaker/jumpstart/types.py | 1 - src/sagemaker/predictor.py | 4 +- .../jumpstart/estimator/test_estimator.py | 16 +-- .../sagemaker/jumpstart/test_predictor.py | 15 +-- .../sagemaker/jumpstart/test_session_utils.py | 118 +++++++++--------- 7 files changed, 88 insertions(+), 88 deletions(-) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 06d1b6581c..33bb73a83c 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -33,7 +33,7 @@ from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs from sagemaker.jumpstart.factory.model import get_default_predictor -from sagemaker.jumpstart.session_utils import get_model_generic_info_from_training_job +from sagemaker.jumpstart.session_utils import get_model_info_from_training_job from sagemaker.jumpstart.types import JumpStartMetadataConfig from sagemaker.jumpstart.utils import ( get_jumpstart_configs, @@ -733,7 +733,7 @@ def attach( config_name = None if model_id is None: - model_id, model_version, config_name = get_model_generic_info_from_training_job( + model_id, model_version, config_name = get_model_info_from_training_job( training_job_name=training_job_name, sagemaker_session=sagemaker_session ) diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index c921ebbcc0..87c510edaa 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -22,7 +22,7 @@ from sagemaker.utils import aws_partition -def get_model_generic_info_from_endpoint( +def get_model_info_from_endpoint( endpoint_name: str, inference_component_name: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -47,7 +47,7 @@ def get_model_generic_info_from_endpoint( model_id, model_version, config_name, - ) = _get_model_generic_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 + ) = _get_model_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 inference_component_name, sagemaker_session ) @@ -57,18 +57,18 @@ def get_model_generic_info_from_endpoint( model_version, inference_component_name, config_name, - ) = _get_model_generic_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 + ) = _get_model_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 endpoint_name, sagemaker_session ) else: - model_id, model_version, config_name = _get_model_generic_info_from_model_based_endpoint( + model_id, model_version, config_name = _get_model_info_from_model_based_endpoint( endpoint_name, inference_component_name, sagemaker_session ) return model_id, model_version, inference_component_name, config_name -def _get_model_generic_info_from_inference_component_endpoint_without_inference_component_name( +def _get_model_info_from_inference_component_endpoint_without_inference_component_name( endpoint_name: str, sagemaker_session: Session ) -> Tuple[str, str, str]: """Derives the model ID, version, config name and inferred inference component name. @@ -100,14 +100,14 @@ def _get_model_generic_info_from_inference_component_endpoint_without_inference_ ) inference_component_name = inference_component_names[0] return ( - *_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name( + *_get_model_info_from_inference_component_endpoint_with_inference_component_name( inference_component_name, sagemaker_session ), inference_component_name, ) -def _get_model_generic_info_from_inference_component_endpoint_with_inference_component_name( +def _get_model_info_from_inference_component_endpoint_with_inference_component_name( inference_component_name: str, sagemaker_session: Session ): """Returns the model ID and version inferred from a SageMaker inference component. @@ -139,7 +139,7 @@ def _get_model_generic_info_from_inference_component_endpoint_with_inference_com return model_id, model_version, config_name -def _get_model_generic_info_from_model_based_endpoint( +def _get_model_info_from_model_based_endpoint( endpoint_name: str, inference_component_name: Optional[str], sagemaker_session: Session, @@ -177,7 +177,7 @@ def _get_model_generic_info_from_model_based_endpoint( return model_id, model_version, config_name -def get_model_generic_info_from_training_job( +def get_model_info_from_training_job( training_job_name: str, sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Tuple[str, str, Optional[str]]: diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 8f0009a720..bf0a84319b 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1175,7 +1175,6 @@ def get_top_config_from_ranking( NotImplementedError: If the scope is unrecognized. """ - if self.scope == JumpStartScriptScope.INFERENCE: instance_type_attribute = "supported_inference_instance_types" elif self.scope == JumpStartScriptScope.TRAINING: diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 72e51e583a..4a696d8b86 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -18,7 +18,7 @@ from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.factory.model import get_default_predictor -from sagemaker.jumpstart.session_utils import get_model_generic_info_from_endpoint +from sagemaker.jumpstart.session_utils import get_model_info_from_endpoint from sagemaker.session import Session @@ -79,7 +79,7 @@ def retrieve_default( inferred_model_version, inferred_inference_component_name, inferred_config_name, - ) = get_model_generic_info_from_endpoint( + ) = get_model_info_from_endpoint( endpoint_name, inference_component_name, sagemaker_session ) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 313d312863..1ed860fc05 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1016,7 +1016,7 @@ def test_jumpstart_estimator_attach_eula_model( ) @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") - @mock.patch("sagemaker.jumpstart.estimator.get_model_generic_info_from_training_job") + @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1024,13 +1024,13 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, - get_model_generic_info_from_training_job: mock.Mock, + get_model_info_from_training_job: mock.Mock, mock_attach: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - get_model_generic_info_from_training_job.return_value = ( + get_model_info_from_training_job.return_value = ( "js-trainable-model-prepacked", "1.0.0", None, @@ -1044,7 +1044,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( training_job_name="some-training-job-name", sagemaker_session=mock_session ) - get_model_generic_info_from_training_job.assert_called_once_with( + get_model_info_from_training_job.assert_called_once_with( training_job_name="some-training-job-name", sagemaker_session=mock_session, ) @@ -1060,7 +1060,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( ) @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") - @mock.patch("sagemaker.jumpstart.estimator.get_model_generic_info_from_training_job") + @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1068,13 +1068,13 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, - get_model_generic_info_from_training_job: mock.Mock, + get_model_info_from_training_job: mock.Mock, mock_attach: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - get_model_generic_info_from_training_job.side_effect = ValueError() + get_model_info_from_training_job.side_effect = ValueError() mock_get_model_specs.side_effect = get_special_model_spec @@ -1085,7 +1085,7 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case( training_job_name="some-training-job-name", sagemaker_session=mock_session ) - get_model_generic_info_from_training_job.assert_called_once_with( + get_model_info_from_training_job.assert_called_once_with( training_job_name="some-training-job-name", sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 9dc7b0fb5d..1cc8f292f0 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -18,7 +18,7 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec, get_spec_from_base_spec -@patch("sagemaker.predictor.get_model_generic_info_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support( @@ -52,7 +52,7 @@ def test_jumpstart_predictor_support( assert js_predictor.accept == MIMEType.JSON -@patch("sagemaker.predictor.get_model_generic_info_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_proprietary_predictor_support( @@ -91,7 +91,7 @@ def test_proprietary_predictor_support( @patch("sagemaker.predictor.Predictor") @patch("sagemaker.predictor.get_default_predictor") -@patch("sagemaker.predictor.get_model_generic_info_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( @@ -134,7 +134,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( @patch("sagemaker.predictor.get_default_predictor") -@patch("sagemaker.predictor.get_model_generic_info_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( @@ -160,8 +160,9 @@ def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( ) patched_get_default_predictor.assert_not_called() + @patch("sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}) -@patch("sagemaker.predictor.get_model_generic_info_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") @patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @@ -171,7 +172,7 @@ def test_jumpstart_serializable_payload_with_predictor( patched_verify_model_region_and_return_specs, patched_validate_model_id_and_get_type, patched_get_object_cached, - patched_get_model_generic_info_from_endpoint, + patched_get_model_info_from_endpoint, patched_get_jumpstart_configs, ): @@ -182,7 +183,7 @@ def test_jumpstart_serializable_payload_with_predictor( patched_get_model_specs.side_effect = get_special_model_spec model_id, model_version = "default_payloads", "*" - patched_get_model_generic_info_from_endpoint.return_value = model_id, model_version, None + patched_get_model_info_from_endpoint.return_value = model_id, model_version, None js_predictor = predictor.retrieve_default( endpoint_name="blah", model_id=model_id, model_version=model_version diff --git a/tests/unit/sagemaker/jumpstart/test_session_utils.py b/tests/unit/sagemaker/jumpstart/test_session_utils.py index a7edcd9f10..e034bc6ca6 100644 --- a/tests/unit/sagemaker/jumpstart/test_session_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_session_utils.py @@ -4,16 +4,16 @@ import pytest from sagemaker.jumpstart.session_utils import ( - _get_model_generic_info_from_inference_component_endpoint_with_inference_component_name, - _get_model_generic_info_from_inference_component_endpoint_without_inference_component_name, - _get_model_generic_info_from_model_based_endpoint, - get_model_generic_info_from_endpoint, - get_model_generic_info_from_training_job, + _get_model_info_from_inference_component_endpoint_with_inference_component_name, + _get_model_info_from_inference_component_endpoint_without_inference_component_name, + _get_model_info_from_model_based_endpoint, + get_model_info_from_endpoint, + get_model_info_from_training_job, ) @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_generic_info_from_training_job_happy_case( +def test_get_model_info_from_training_job_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -26,7 +26,7 @@ def test_get_model_generic_info_from_training_job_happy_case( None, ) - retval = get_model_generic_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) + retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) assert retval == ("model_id", "model_version", None) @@ -36,7 +36,7 @@ def test_get_model_generic_info_from_training_job_happy_case( @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_generic_info_from_training_job_config_name( +def test_get_model_info_from_training_job_config_name( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -49,7 +49,7 @@ def test_get_model_generic_info_from_training_job_config_name( "config_name", ) - retval = get_model_generic_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) + retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) assert retval == ("model_id", "model_version", "config_name") @@ -59,7 +59,7 @@ def test_get_model_generic_info_from_training_job_config_name( @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_generic_info_from_training_job_no_model_id_inferred( +def test_get_model_info_from_training_job_no_model_id_inferred( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -72,11 +72,11 @@ def test_get_model_generic_info_from_training_job_no_model_id_inferred( ) with pytest.raises(ValueError): - get_model_generic_info_from_training_job("blah", sagemaker_session=mock_sm_session) + get_model_info_from_training_job("blah", sagemaker_session=mock_sm_session) @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_generic_info_from_model_based_endpoint_happy_case( +def test_get_model_info_from_model_based_endpoint_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -89,7 +89,7 @@ def test_get_model_generic_info_from_model_based_endpoint_happy_case( None, ) - retval = _get_model_generic_info_from_model_based_endpoint( + retval = _get_model_info_from_model_based_endpoint( "bLaH", inference_component_name=None, sagemaker_session=mock_sm_session ) @@ -101,7 +101,7 @@ def test_get_model_generic_info_from_model_based_endpoint_happy_case( @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_generic_info_from_model_based_endpoint_inference_component_supplied( +def test_get_model_info_from_model_based_endpoint_inference_component_supplied( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -115,13 +115,13 @@ def test_get_model_generic_info_from_model_based_endpoint_inference_component_su ) with pytest.raises(ValueError): - _get_model_generic_info_from_model_based_endpoint( + _get_model_info_from_model_based_endpoint( "blah", inference_component_name="some-name", sagemaker_session=mock_sm_session ) @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_generic_info_from_model_based_endpoint_no_model_id_inferred( +def test_get_model_info_from_model_based_endpoint_no_model_id_inferred( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -134,13 +134,13 @@ def test_get_model_generic_info_from_model_based_endpoint_no_model_id_inferred( ) with pytest.raises(ValueError): - _get_model_generic_info_from_model_based_endpoint( + _get_model_info_from_model_based_endpoint( "blah", inference_component_name="some-name", sagemaker_session=mock_sm_session ) @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name_happy_case( +def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -154,7 +154,7 @@ def test_get_model_generic_info_from_inference_component_endpoint_with_inference ) retval = ( - _get_model_generic_info_from_inference_component_endpoint_with_inference_component_name( + _get_model_info_from_inference_component_endpoint_with_inference_component_name( "bLaH", sagemaker_session=mock_sm_session ) ) @@ -167,7 +167,7 @@ def test_get_model_generic_info_from_inference_component_endpoint_with_inference @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( +def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -181,20 +181,20 @@ def test_get_model_generic_info_from_inference_component_endpoint_with_inference ) with pytest.raises(ValueError): - _get_model_generic_info_from_inference_component_endpoint_with_inference_component_name( + _get_model_info_from_inference_component_endpoint_with_inference_component_name( "blah", sagemaker_session=mock_sm_session ) @patch( - "sagemaker.jumpstart.session_utils._get_model_generic_info_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_generic_info_from_inference_component_endpoint_without_inference_component_name_happy_case( - mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_inference_component_endpoint_without_inference_component_name_happy_case( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -203,7 +203,7 @@ def test_get_model_generic_info_from_inference_component_endpoint_without_infere ) retval = ( - _get_model_generic_info_from_inference_component_endpoint_without_inference_component_name( + _get_model_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) ) @@ -215,14 +215,14 @@ def test_get_model_generic_info_from_inference_component_endpoint_without_infere @patch( - "sagemaker.jumpstart.session_utils._get_model_generic_info_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_generic_info_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint( - mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -230,7 +230,7 @@ def test_get_model_generic_info_from_inference_component_endpoint_without_ic_nam return_value=[] ) with pytest.raises(ValueError): - _get_model_generic_info_from_inference_component_endpoint_without_inference_component_name( + _get_model_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) @@ -241,13 +241,13 @@ def test_get_model_generic_info_from_inference_component_endpoint_without_ic_nam @patch( "sagemaker.jumpstart.session_utils._get_model" - "_generic_info_from_inference_component_endpoint_with_inference_component_name" + "_info_from_inference_component_endpoint_with_inference_component_name" ) def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_multiple_ics_for_endpoint( - mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name, + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -257,7 +257,7 @@ def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_ ) with pytest.raises(ValueError): - _get_model_generic_info_from_inference_component_endpoint_without_inference_component_name( + _get_model_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) @@ -266,92 +266,92 @@ def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_ ) -@patch("sagemaker.jumpstart.session_utils._get_model_generic_info_from_model_based_endpoint") -def test_get_model_generic_info_from_endpoint_non_inference_component_endpoint( - mock_get_model_generic_info_from_model_based_endpoint, +@patch("sagemaker.jumpstart.session_utils._get_model_info_from_model_based_endpoint") +def test_get_model_info_from_endpoint_non_inference_component_endpoint( + mock_get_model_info_from_model_based_endpoint, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = False - mock_get_model_generic_info_from_model_based_endpoint.return_value = ( + mock_get_model_info_from_model_based_endpoint.return_value = ( "model_id", "model_version", None, ) - retval = get_model_generic_info_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) assert retval == ("model_id", "model_version", None, None) - mock_get_model_generic_info_from_model_based_endpoint.assert_called_once_with( + mock_get_model_info_from_model_based_endpoint.assert_called_once_with( "blah", None, mock_sm_session ) mock_sm_session.is_inference_component_based_endpoint.assert_called_once_with("blah") @patch( - "sagemaker.jumpstart.session_utils._get_model_generic_info_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_generic_info_from_endpoint_inference_component_endpoint_with_inference_component_name( - mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_endpoint_inference_component_endpoint_with_inference_component_name( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", None, ) - retval = get_model_generic_info_from_endpoint( + retval = get_model_info_from_endpoint( "blah", inference_component_name="icname", sagemaker_session=mock_sm_session ) assert retval == ("model_id", "model_version", "icname", None) - mock_get_model_generic_info_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( "icname", mock_sm_session ) mock_sm_session.is_inference_component_based_endpoint.assert_not_called() @patch( - "sagemaker.jumpstart.session_utils._get_model_generic_info_from_inference_component_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" "endpoint_without_inference_component_name" ) -def test_get_model_generic_info_from_endpoint_inference_component_endpoint_without_inference_component_name( - mock_get_model_generic_info_from_inference_component_endpoint_without_inference_component_name, +def test_get_model_info_from_endpoint_inference_component_endpoint_without_inference_component_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_generic_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( "model_id", "model_version", "inferred-icname", None, ) - retval = get_model_generic_info_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) assert retval == ("model_id", "model_version", "inferred-icname", None) - mock_get_model_generic_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() @patch( - "sagemaker.jumpstart.session_utils._get_model_generic_info_from_inference_component_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" "endpoint_without_inference_component_name" ) -def test_get_model_generic_info_from_endpoint_inference_component_endpoint_with_config_name( - mock_get_model_generic_info_from_inference_component_endpoint_without_inference_component_name, +def test_get_model_info_from_endpoint_inference_component_endpoint_with_config_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_generic_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( "model_id", "model_version", "inferred-icname", "config_name", ) - retval = get_model_generic_info_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) assert retval == ("model_id", "model_version", "inferred-icname", "config_name") - mock_get_model_generic_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() From cf0081ccbea36b30c5b918d433e3e1de603e8f66 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Wed, 24 Apr 2024 20:09:15 +0000 Subject: [PATCH 07/10] fix --- src/sagemaker/jumpstart/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index ce869dabc8..59bf11b415 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -162,7 +162,6 @@ def get_jumpstart_content_bucket( for info_log in info_logs: constants.JUMPSTART_LOGGER.info(info_log) return bucket_to_return - # return "jumpstart-cache-alpha-us-west-2" def get_formatted_manifest( From 6a365dec5a9cae39069ab64cab469cad132ac36c Mon Sep 17 00:00:00 2001 From: Haotian An Date: Wed, 24 Apr 2024 20:21:13 +0000 Subject: [PATCH 08/10] format --- src/sagemaker/predictor.py | 4 +--- tests/unit/sagemaker/jumpstart/test_session_utils.py | 12 ++++-------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 4a696d8b86..14e2ae278b 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -79,9 +79,7 @@ def retrieve_default( inferred_model_version, inferred_inference_component_name, inferred_config_name, - ) = get_model_info_from_endpoint( - endpoint_name, inference_component_name, sagemaker_session - ) + ) = get_model_info_from_endpoint(endpoint_name, inference_component_name, sagemaker_session) if not inferred_model_id: raise ValueError( diff --git a/tests/unit/sagemaker/jumpstart/test_session_utils.py b/tests/unit/sagemaker/jumpstart/test_session_utils.py index e034bc6ca6..47ee4e839f 100644 --- a/tests/unit/sagemaker/jumpstart/test_session_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_session_utils.py @@ -153,10 +153,8 @@ def test_get_model_info_from_inference_component_endpoint_with_inference_compone None, ) - retval = ( - _get_model_info_from_inference_component_endpoint_with_inference_component_name( - "bLaH", sagemaker_session=mock_sm_session - ) + retval = _get_model_info_from_inference_component_endpoint_with_inference_component_name( + "bLaH", sagemaker_session=mock_sm_session ) assert retval == ("model_id", "model_version", None) @@ -202,10 +200,8 @@ def test_get_model_info_from_inference_component_endpoint_without_inference_comp return_value=["icname"] ) - retval = ( - _get_model_info_from_inference_component_endpoint_without_inference_component_name( - "blahblah", mock_sm_session - ) + retval = _get_model_info_from_inference_component_endpoint_without_inference_component_name( + "blahblah", mock_sm_session ) assert retval == ("model_id", "model_version", "icname") From 8c9c240dbb1889a978616c5351009f5edb6493b9 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Thu, 25 Apr 2024 02:03:44 +0000 Subject: [PATCH 09/10] updates inference component config name --- src/sagemaker/jumpstart/session_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index 87c510edaa..0fa7722f91 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -55,8 +55,8 @@ def get_model_info_from_endpoint( ( model_id, model_version, - inference_component_name, config_name, + inference_component_name, ) = _get_model_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 endpoint_name, sagemaker_session ) @@ -70,7 +70,7 @@ def get_model_info_from_endpoint( def _get_model_info_from_inference_component_endpoint_without_inference_component_name( endpoint_name: str, sagemaker_session: Session -) -> Tuple[str, str, str]: +) -> Tuple[str, str, str, str]: """Derives the model ID, version, config name and inferred inference component name. This function assumes the endpoint corresponds to an inference-component-based endpoint. From 8ecccf83946d45552626f5ec2d2166de5180358b Mon Sep 17 00:00:00 2001 From: Haotian An Date: Thu, 25 Apr 2024 03:03:23 +0000 Subject: [PATCH 10/10] fix: tests --- tests/unit/sagemaker/jumpstart/test_session_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/sagemaker/jumpstart/test_session_utils.py b/tests/unit/sagemaker/jumpstart/test_session_utils.py index 47ee4e839f..9dc8acb32a 100644 --- a/tests/unit/sagemaker/jumpstart/test_session_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_session_utils.py @@ -321,8 +321,8 @@ def test_get_model_info_from_endpoint_inference_component_endpoint_without_infer mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( "model_id", "model_version", - "inferred-icname", None, + "inferred-icname", ) retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) @@ -343,8 +343,8 @@ def test_get_model_info_from_endpoint_inference_component_endpoint_with_config_n mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( "model_id", "model_version", - "inferred-icname", "config_name", + "inferred-icname", ) retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session)