diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 732493ce3b..83613cd71b 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -329,9 +329,12 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin return sorted(list(model_id_version_dict.keys())) if not list_old_models: - model_id_version_dict = { - model_id: set([max(versions)]) for model_id, versions in model_id_version_dict.items() - } + for model_id, versions in model_id_version_dict.items(): + try: + model_id_version_dict.update({model_id: set([max(versions)])}) + except TypeError: + versions = [str(v) for v in versions] + model_id_version_dict.update({model_id: set([max(versions)])}) model_id_version_set: Set[Tuple[str, str]] = set() for model_id in model_id_version_dict: diff --git a/src/sagemaker/jumpstart/payload_utils.py b/src/sagemaker/jumpstart/payload_utils.py index 595f801598..e4d31e9c83 100644 --- a/src/sagemaker/jumpstart/payload_utils.py +++ b/src/sagemaker/jumpstart/payload_utils.py @@ -23,7 +23,7 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -from sagemaker.jumpstart.enums import MIMEType +from sagemaker.jumpstart.enums import JumpStartModelType, MIMEType from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, @@ -61,6 +61,7 @@ def _construct_payload( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[JumpStartSerializablePayload]: """Returns example payload from prompt. @@ -83,6 +84,8 @@ def _construct_payload( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + model_type (JumpStartModelType): The type of the model, can be open weights model or + proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: Optional[JumpStartSerializablePayload]: serializable payload with prompt, or None if this feature is unavailable for the specified model. @@ -94,6 +97,7 @@ def _construct_payload( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if payloads is None or len(payloads) == 0: return None diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index c00d271ef1..6544c59019 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -3,7 +3,6 @@ from unittest import TestCase from unittest.mock import Mock, patch -import datetime import pytest from sagemaker.jumpstart.constants import ( @@ -17,7 +16,6 @@ get_prototype_manifest, get_prototype_model_spec, ) -from tests.unit.sagemaker.jumpstart.constants import BASE_PROPRIETARY_MANIFEST from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.notebook_utils import ( _generate_jumpstart_model_versions, @@ -227,10 +225,6 @@ def test_list_jumpstart_models_simple_case( patched_get_manifest.assert_called() patched_get_model_specs.assert_not_called() - @pytest.mark.skipif( - datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1), - reason="Contact JumpStart team to fix flaky test.", - ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_script_filter( @@ -246,23 +240,25 @@ def test_list_jumpstart_models_script_filter( manifest_length = len(get_prototype_manifest()) vals = [True, False] for val in vals: - kwargs = {"filter": f"training_supported == {val}"} + kwargs = {"filter": And(f"training_supported == {val}", "model_type is open_weights")} list_jumpstart_models(**kwargs) assert patched_read_s3_file.call_count == manifest_length - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() - kwargs = {"filter": f"training_supported != {val}"} + kwargs = {"filter": And(f"training_supported != {val}", "model_type is open_weights")} list_jumpstart_models(**kwargs) assert patched_read_s3_file.call_count == manifest_length assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() - - kwargs = {"filter": f"training_supported in {vals}", "list_versions": True} + kwargs = { + "filter": And(f"training_supported != {val}", "model_type is open_weights"), + "list_versions": True, + } assert list_jumpstart_models(**kwargs) == [ ("catboost-classification-model", "1.0.0"), ("huggingface-spc-bert-base-cased", "1.0.0"), @@ -279,7 +275,7 @@ def test_list_jumpstart_models_script_filter( patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() - kwargs = {"filter": f"training_supported not in {vals}"} + kwargs = {"filter": And(f"training_supported not in {vals}", "model_type is open_weights")} models = list_jumpstart_models(**kwargs) assert [] == models assert patched_read_s3_file.call_count == manifest_length @@ -518,10 +514,6 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME): list_old_models=False, list_versions=True ) == list_jumpstart_models(list_versions=True) - @pytest.mark.skipif( - datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1), - reason="Contact JumpStart team to fix flaky test.", - ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_vulnerable_models( @@ -547,12 +539,15 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): patched_read_s3_file.side_effect = vulnerable_inference_model_spec num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) - num_prop_specs = len(BASE_PROPRIETARY_MANIFEST) assert [] == list_jumpstart_models( - And("inference_vulnerable is false", "training_vulnerable is false") + And( + "inference_vulnerable is false", + "training_vulnerable is false", + "model_type is open_weights", + ) ) - assert patched_read_s3_file.call_count == num_specs + num_prop_specs + assert patched_read_s3_file.call_count == num_specs assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -561,10 +556,14 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): patched_read_s3_file.side_effect = vulnerable_training_model_spec assert [] == list_jumpstart_models( - And("inference_vulnerable is false", "training_vulnerable is false") + And( + "inference_vulnerable is false", + "training_vulnerable is false", + "model_type is open_weights", + ) ) - assert patched_read_s3_file.call_count == num_specs + num_prop_specs + assert patched_read_s3_file.call_count == num_specs assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -574,10 +573,6 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): assert patched_read_s3_file.call_count == 0 - @pytest.mark.skipif( - datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1), - reason="Contact JumpStart team to fix flaky test.", - ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_deprecated_models( @@ -598,10 +593,11 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str: patched_read_s3_file.side_effect = deprecated_model_spec num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) - num_prop_specs = len(BASE_PROPRIETARY_MANIFEST) - assert [] == list_jumpstart_models("deprecated equals false") + assert [] == list_jumpstart_models( + And("deprecated equals false", "model_type is open_weights") + ) - assert patched_read_s3_file.call_count == num_specs + num_prop_specs + assert patched_read_s3_file.call_count == num_specs assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock()