From 5aeaa359180e412c508a267f6edd3558be110b8c Mon Sep 17 00:00:00 2001 From: Haotian An Date: Thu, 22 Feb 2024 19:44:34 +0000 Subject: [PATCH 01/30] feat: add proprietary manifest/specs parsing add unittests for test_cache small refactoring address comments and more unittests fix linting and fix more tests fix: pylint feat: JumpStartModel class for prop models --- src/sagemaker/accept_types.py | 2 + src/sagemaker/base_predictor.py | 4 +- src/sagemaker/content_types.py | 2 + src/sagemaker/deserializers.py | 2 + src/sagemaker/jumpstart/accessors.py | 29 +- src/sagemaker/jumpstart/cache.py | 179 +++++++++--- src/sagemaker/jumpstart/constants.py | 14 +- src/sagemaker/jumpstart/enums.py | 11 + src/sagemaker/jumpstart/estimator.py | 10 +- src/sagemaker/jumpstart/factory/model.py | 18 +- src/sagemaker/jumpstart/model.py | 15 +- src/sagemaker/jumpstart/types.py | 39 ++- src/sagemaker/jumpstart/utils.py | 49 +++- src/sagemaker/serializers.py | 2 + .../jumpstart/model/test_jumpstart_model.py | 26 +- .../jumpstart/test_accept_types.py | 26 +- .../jumpstart/test_content_types.py | 25 +- .../jumpstart/test_deserializers.py | 25 +- .../jumpstart/test_default.py | 38 ++- .../hyperparameters/jumpstart/test_default.py | 10 +- .../jumpstart/test_validate.py | 33 ++- .../image_uris/jumpstart/test_common.py | 11 +- .../jumpstart/test_instance_types.py | 14 +- tests/unit/sagemaker/jumpstart/constants.py | 80 ++++++ .../jumpstart/estimator/test_estimator.py | 158 +++++------ .../estimator/test_sagemaker_config.py | 49 ++-- .../sagemaker/jumpstart/model/test_model.py | 216 +++++++++------ .../jumpstart/model/test_sagemaker_config.py | 49 ++-- .../sagemaker/jumpstart/test_accessors.py | 88 ++++++ .../sagemaker/jumpstart/test_artifacts.py | 14 +- tests/unit/sagemaker/jumpstart/test_cache.py | 258 ++++++++++++++++-- .../jumpstart/test_notebook_utils.py | 10 +- .../sagemaker/jumpstart/test_predictor.py | 46 +++- tests/unit/sagemaker/jumpstart/test_utils.py | 94 ++++--- tests/unit/sagemaker/jumpstart/utils.py | 49 +++- .../jumpstart/test_default.py | 19 +- .../model_uris/jumpstart/test_common.py | 11 +- .../jumpstart/test_resource_requirements.py | 15 +- .../script_uris/jumpstart/test_common.py | 16 +- .../serializers/jumpstart/test_serializers.py | 15 +- 40 files changed, 1360 insertions(+), 411 deletions(-) diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index bf081365ab..cce0b653f8 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -16,6 +16,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -72,6 +73,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, diff --git a/src/sagemaker/base_predictor.py b/src/sagemaker/base_predictor.py index 882cfafc39..76b83c25cd 100644 --- a/src/sagemaker/base_predictor.py +++ b/src/sagemaker/base_predictor.py @@ -58,7 +58,9 @@ from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME from sagemaker.lineage.context import EndpointContext -from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements +from sagemaker.compute_resource_requirements.resource_requirements import ( + ResourceRequirements, +) LOGGER = logging.getLogger("sagemaker") diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index e43e96be17..fa3d49fbba 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -16,6 +16,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -72,6 +73,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 706ae56bda..4cb596ca48 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -35,6 +35,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -95,6 +96,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index e03a13a7a3..482cfdeee7 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -18,6 +18,7 @@ from sagemaker.deprecations import deprecated from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import cache from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME @@ -197,7 +198,9 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None: @staticmethod def _get_manifest( - region: str = JUMPSTART_DEFAULT_REGION_NAME, s3_client: Optional[boto3.client] = None + region: str = JUMPSTART_DEFAULT_REGION_NAME, + s3_client: Optional[boto3.client] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest. @@ -215,13 +218,19 @@ def _get_manifest( additional_kwargs.update({"s3_client": s3_client}) cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( - {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}, region + {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}, + region, ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) - return JumpStartModelsAccessor._cache.get_manifest() # type: ignore + return JumpStartModelsAccessor._cache.get_manifest(model_type) # type: ignore @staticmethod - def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader: + def get_model_header( + region: str, + model_id: str, + version: str, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + ) -> JumpStartModelHeader: """Returns model header from JumpStart models cache. Args: @@ -234,12 +243,18 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) return JumpStartModelsAccessor._cache.get_header( # type: ignore - model_id=model_id, semantic_version_str=version + model_id=model_id, + semantic_version_str=version, + model_type=model_type, ) @staticmethod def get_model_specs( - region: str, model_id: str, version: str, s3_client: Optional[boto3.client] = None + region: str, + model_id: str, + version: str, + s3_client: Optional[boto3.client] = None, + model_type=JumpStartModelType.OPEN_SOURCE, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. @@ -260,7 +275,7 @@ def get_model_specs( ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) return JumpStartModelsAccessor._cache.get_specs( # type: ignore - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version, model_type=model_type ) @staticmethod diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index e26d588167..888b1d07d1 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -25,9 +25,12 @@ ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, MODEL_ID_LIST_WEB_URL, + MODEL_TYPE_TO_MANIFEST_MAP, + MODEL_TYPE_TO_SPECS_MAP, ) from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( @@ -44,6 +47,7 @@ JumpStartS3FileType, JumpStartVersionedModelId, ) +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import utils from sagemaker.utilities.cache import LRUCache @@ -68,6 +72,7 @@ def __init__( JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, manifest_file_s3_key: str = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + proprietary_manifest_s3_key: str = JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, s3_client: Optional[boto3.client] = None, @@ -100,14 +105,26 @@ def __init__( expiration_horizon=s3_cache_expiration_horizon, retrieval_function=self._retrieval_function, ) - self._model_id_semantic_version_manifest_key_cache = LRUCache[ + self._open_source_model_id_manifest_key_cache = LRUCache[ JumpStartVersionedModelId, JumpStartVersionedModelId ]( max_cache_items=max_semantic_version_cache_items, expiration_horizon=semantic_version_cache_expiration_horizon, - retrieval_function=self._get_manifest_key_from_model_id_semantic_version, + retrieval_function=self._get_open_source_manifest_key_from_model_id, + ) + self._proprietary_model_id_manifest_key_cache = LRUCache[ + JumpStartVersionedModelId, JumpStartVersionedModelId + ]( + max_cache_items=max_semantic_version_cache_items, + expiration_horizon=semantic_version_cache_expiration_horizon, + retrieval_function=self._get_proprietary_manifest_key_from_model_id, ) self._manifest_file_s3_key = manifest_file_s3_key + self._proprietary_manifest_s3_key = proprietary_manifest_s3_key + self._manifest_file_s3_map = { + JumpStartModelType.OPEN_SOURCE: self._manifest_file_s3_key, + JumpStartModelType.PROPRIETARY: self._proprietary_manifest_s3_key, + } self.s3_bucket_name = ( utils.get_jumpstart_content_bucket(self._region) if s3_bucket_name is None @@ -129,15 +146,44 @@ def get_region(self) -> str: """Return region for cache.""" return self._region - def set_manifest_file_s3_key(self, key: str) -> None: - """Set manifest file s3 key. Clears cache after new key is set.""" - if key != self._manifest_file_s3_key: - self._manifest_file_s3_key = key + def set_manifest_file_s3_key( + self, + key: str, + file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_SOURCE_MANIFEST, + ) -> None: + """Set manifest file s3 key. Clears cache after new key is set. + + Raises: + ValueError: if the file type is not recognized + """ + file_mapping = { + JumpStartS3FileType.OPEN_SOURCE_MANIFEST: self._manifest_file_s3_key, + JumpStartS3FileType.PROPRIETARY_MANIFEST: self._proprietary_manifest_s3_key, + } + property_name = file_mapping.get(file_type) + if not property_name: + raise ValueError( + f"Bad value when setting manifest '{file_type}': must be in" + f"{JumpStartS3FileType.OPEN_SOURCE_MANIFEST}" + f"{JumpStartS3FileType.PROPRIETARY_MANIFEST}" + ) + if key != property_name: + setattr(self, property_name, key) self.clear() - def get_manifest_file_s3_key(self) -> str: + def get_manifest_file_s3_key( + self, file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_SOURCE_MANIFEST + ) -> str: """Return manifest file s3 key for cache.""" - return self._manifest_file_s3_key + if file_type == JumpStartS3FileType.OPEN_SOURCE_MANIFEST: + return self._manifest_file_s3_key + if file_type == JumpStartS3FileType.PROPRIETARY_MANIFEST: + return self._proprietary_manifest_s3_key + raise ValueError( + f"Bad value when getting manifest '{file_type}':" + f"must be in {JumpStartS3FileType.OPEN_SOURCE_MANIFEST}" + f"{JumpStartS3FileType.PROPRIETARY_MANIFEST}" + ) def set_s3_bucket_name(self, s3_bucket_name: str) -> None: """Set s3 bucket used for cache.""" @@ -149,10 +195,11 @@ def get_bucket(self) -> str: """Return bucket used for cache.""" return self.s3_bucket_name - def _get_manifest_key_from_model_id_semantic_version( + def _model_id_retrieval_function( self, key: JumpStartVersionedModelId, value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 + model_type: JumpStartModelType ) -> JumpStartVersionedModelId: """Return model ID and version in manifest that matches semantic version/id. @@ -164,6 +211,8 @@ def _get_manifest_key_from_model_id_semantic_version( key (JumpStartVersionedModelId): Key for which to fetch versioned model ID. value (Optional[JumpStartVersionedModelId]): Unused variable for current value of old cached model ID/version. + UseSematicVersion (bool): boolean value to indicate whether the model versions follow + sematic versioning. Raises: KeyError: If the semantic version is not found in the manifest, or is found but @@ -172,14 +221,14 @@ def _get_manifest_key_from_model_id_semantic_version( model_id, version = key.model_id, key.version + sm_version = utils.get_sagemaker_version() manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + JumpStartCachedS3ContentKey( + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) )[0].formatted_content - sm_version = utils.get_sagemaker_version() - versions_compatible_with_sagemaker = [ - Version(header.version) + header.version for header in manifest.values() # type: ignore if header.model_id == model_id and Version(header.min_version) <= Version(sm_version) ] @@ -242,6 +291,26 @@ def _get_manifest_key_from_model_id_semantic_version( raise KeyError(error_msg) + def _get_open_source_manifest_key_from_model_id( + self, + key: JumpStartVersionedModelId, + value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 + ) -> JumpStartVersionedModelId: + """Get open source manifest key from model id.""" + return self._model_id_retrieval_function( + key, value, model_type=JumpStartModelType.OPEN_SOURCE + ) + + def _get_proprietary_manifest_key_from_model_id( + self, + key: JumpStartVersionedModelId, + value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 + ) -> JumpStartVersionedModelId: + """Get proprietary manifest key from model id.""" + return self._model_id_retrieval_function( + key, value, model_type=JumpStartModelType.PROPRIETARY + ) + def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], str]: """Returns json file from s3, along with its etag.""" response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=key) @@ -286,11 +355,11 @@ def _get_json_file_from_local_override( filetype: JumpStartS3FileType ) -> Union[dict, list]: """Reads json file from local filesystem and returns data.""" - if filetype == JumpStartS3FileType.MANIFEST: + if filetype == JumpStartS3FileType.OPEN_SOURCE_MANIFEST: metadata_local_root = ( os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE] ) - elif filetype == JumpStartS3FileType.SPECS: + elif filetype == JumpStartS3FileType.OPEN_SOURCE_SPECS: metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE] else: raise ValueError(f"Unsupported file type for local override: {filetype}") @@ -318,8 +387,10 @@ def _retrieval_function( """ file_type, s3_key = key.file_type, key.s3_key - - if file_type == JumpStartS3FileType.MANIFEST: + if file_type in { + JumpStartS3FileType.OPEN_SOURCE_MANIFEST, + JumpStartS3FileType.PROPRIETARY_MANIFEST, + }: if value is not None and not self._is_local_metadata_mode(): etag = self._get_json_md5_hash(s3_key) if etag == value.md5_hash: @@ -329,27 +400,38 @@ def _retrieval_function( formatted_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) - if file_type == JumpStartS3FileType.SPECS: + if file_type in { + JumpStartS3FileType.OPEN_SOURCE_SPECS, + JumpStartS3FileType.PROPRIETARY_SPECS, + }: formatted_body, _ = self._get_json_file(s3_key, file_type) model_specs = JumpStartModelSpecs(formatted_body) utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) - return JumpStartCachedS3ContentValue( - formatted_content=model_specs - ) + return JumpStartCachedS3ContentValue(formatted_content=model_specs) raise ValueError( - f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}" + f"Bad value for key '{key}': must be in" + f"{JumpStartS3FileType.OPEN_SOURCE_MANIFEST, JumpStartS3FileType.OPEN_SOURCE_SPECS}" + f"{JumpStartS3FileType.PROPRIETARY_SPECS, JumpStartS3FileType.PROPRIETARY_MANIFEST}" ) - def get_manifest(self) -> List[JumpStartModelHeader]: + def get_manifest( + self, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest.""" - manifest_dict = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + JumpStartCachedS3ContentKey( + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) )[0].formatted_content manifest = list(manifest_dict.values()) # type: ignore return manifest - def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader: + def get_header( + self, + model_id: str, + semantic_version_str: str, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + ) -> JumpStartModelHeader: """Return header for a given JumpStart model ID and semantic version. Args: @@ -357,8 +439,9 @@ def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModel semantic_version_str (str): The semantic version for which to get a header. """ - - return self._get_header_impl(model_id, semantic_version_str=semantic_version_str) + return self._get_header_impl( + model_id, semantic_version_str=semantic_version_str, model_type=model_type + ) def _select_version( self, @@ -391,6 +474,7 @@ def _get_header_impl( model_id: str, semantic_version_str: str, attempt: int = 0, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE ) -> JumpStartModelHeader: """Lower-level function to return header. @@ -402,14 +486,21 @@ def _get_header_impl( header. attempt (int): attempt number at retrieving a header. """ + if model_type == JumpStartModelType.OPEN_SOURCE: + versioned_model_id = self._open_source_model_id_manifest_key_cache.get( + JumpStartVersionedModelId(model_id, semantic_version_str) + )[0] - versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get( - JumpStartVersionedModelId(model_id, semantic_version_str) - )[0] + elif model_type == JumpStartModelType.PROPRIETARY: + versioned_model_id = self._proprietary_model_id_manifest_key_cache.get( + JumpStartVersionedModelId(model_id, semantic_version_str) + )[0] manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + JumpStartCachedS3ContentKey( + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) )[0].formatted_content + try: header = manifest[versioned_model_id] # type: ignore return header @@ -419,7 +510,12 @@ def _get_header_impl( self.clear() return self._get_header_impl(model_id, semantic_version_str, attempt + 1) - def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelSpecs: + def get_specs( + self, + model_id: str, + version_str: str, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE + ) -> JumpStartModelSpecs: """Return specs for a given JumpStart model ID and semantic version. Args: @@ -427,18 +523,18 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS semantic_version_str (str): The semantic version for which to get specs. """ - - header = self.get_header(model_id, semantic_version_str) + header = self.get_header(model_id, version_str, model_type) spec_key = header.spec_key specs, cache_hit = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key) + JumpStartCachedS3ContentKey( + MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key + ) ) - if not cache_hit and "*" in semantic_version_str: + + if not cache_hit and "*" in version_str: JUMPSTART_LOGGER.warning( get_wildcard_model_version_msg( - header.model_id, - semantic_version_str, - header.version + header.model_id, version_str, header.version ) ) return specs.formatted_content @@ -446,4 +542,5 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS def clear(self) -> None: """Clears the model ID/version and s3 cache.""" self._s3_cache.clear() - self._model_id_semantic_version_manifest_key_cache.clear() + self._open_source_model_id_manifest_key_cache.clear() + self._proprietary_model_id_manifest_key_cache.clear() diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 2e655ac285..114c54963b 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -22,8 +22,9 @@ SerializerType, DeserializerType, MIMEType, + JumpStartModelType, ) -from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo +from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo, JumpStartS3FileType from sagemaker.base_serializers import ( BaseSerializer, CSVSerializer, @@ -169,6 +170,7 @@ JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" +JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json" INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py" @@ -213,6 +215,16 @@ DeserializerType.JSON: JSONDeserializer, } +MODEL_TYPE_TO_MANIFEST_MAP: Dict[Type[JumpStartModelType], Type[JumpStartS3FileType]] = { + JumpStartModelType.OPEN_SOURCE: JumpStartS3FileType.OPEN_SOURCE_MANIFEST, + JumpStartModelType.PROPRIETARY: JumpStartS3FileType.PROPRIETARY_MANIFEST, +} + +MODEL_TYPE_TO_SPECS_MAP: Dict[Type[JumpStartModelType], Type[JumpStartS3FileType]] = { + JumpStartModelType.OPEN_SOURCE: JumpStartS3FileType.OPEN_SOURCE_SPECS, + JumpStartModelType.PROPRIETARY: JumpStartS3FileType.PROPRIETARY_SPECS, +} + MODEL_ID_LIST_WEB_URL = "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html" JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart") diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index e33daca046..2188294f06 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -34,6 +34,17 @@ class ModelFramework(str, Enum): SKLEARN = "sklearn" +class JumpStartModelType(str, Enum): + """Enum class for JumpStart model type. + + Open source model refers to JumpStart owned community models. + Proprietary model refers to external provider owned Marketplace models. + """ + + OPEN_SOURCE = "opensource" + PROPRIETARY = "proprietary" + + class VariableScope(str, Enum): """Possible value of the ``scope`` attribute for a hyperparameter or environment variable. diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 24105c4369..d245bf5aca 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -35,7 +35,7 @@ from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job from sagemaker.jumpstart.utils import ( - is_valid_model_id, + validate_model_id_and_get_type, resolve_model_sagemaker_config_field, ) from sagemaker.utils import stringify_object, format_tags, Tags @@ -504,8 +504,8 @@ def __init__( ValueError: If the model ID is not recognized by JumpStart. """ - def _is_valid_model_id_hook(): - return is_valid_model_id( + def _validate_model_id_and_get_type_hook(): + return validate_model_id_and_get_type( model_id=model_id, model_version=model_version, region=region, @@ -513,9 +513,9 @@ def _is_valid_model_id_hook(): sagemaker_session=sagemaker_session, ) - if not _is_valid_model_id_hook(): + if not _validate_model_id_and_get_type_hook(): JumpStartModelsAccessor.reset_cache() - if not _is_valid_model_id_hook(): + if not _validate_model_id_and_get_type_hook(): raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) estimator_init_kwargs = get_init_kwargs( diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 64e4727116..4fda2cbf4c 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -37,7 +37,7 @@ from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.types import ( JumpStartModelDeployKwargs, JumpStartModelInitKwargs, @@ -71,6 +71,7 @@ def get_default_predictor( tolerate_vulnerable_model: bool, tolerate_deprecated_model: bool, sagemaker_session: Session, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, ) -> Predictor: """Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one. @@ -92,6 +93,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) predictor.deserializer = deserializers.retrieve_default( model_id=model_id, @@ -100,6 +102,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) predictor.accept = accept_types.retrieve_default( model_id=model_id, @@ -108,6 +111,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) predictor.content_type = content_types.retrieve_default( model_id=model_id, @@ -116,6 +120,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) return predictor @@ -199,7 +204,14 @@ def _add_instance_type_to_kwargs( def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: - """Sets image uri based on default or override, returns full kwargs.""" + """ + Sets image uri based on default or override, returns full kwargs. + Uses placeholder image uri for JumpStart proprietary models that uses ModelPackages + """ + + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.image_uri = "" + return kwargs kwargs.image_uri = kwargs.image_uri or image_uris.retrieve( region=kwargs.region, @@ -654,6 +666,7 @@ def get_init_kwargs( model_id: str, model_from_estimator: bool = False, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, instance_type: Optional[str] = None, @@ -685,6 +698,7 @@ def get_init_kwargs( model_init_kwargs: JumpStartModelInitKwargs = JumpStartModelInitKwargs( model_id=model_id, model_version=model_version, + model_type=model_type, instance_type=instance_type, region=region, image_uri=image_uri, diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 1742f860e4..837291b813 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -30,7 +30,7 @@ get_register_kwargs, ) from sagemaker.jumpstart.types import JumpStartSerializablePayload -from sagemaker.jumpstart.utils import is_valid_model_id +from sagemaker.jumpstart.utils import validate_model_id_and_get_type from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model import ( Model, @@ -270,8 +270,8 @@ def __init__( ValueError: If the model ID is not recognized by JumpStart. """ - def _is_valid_model_id_hook(): - return is_valid_model_id( + def _validate_model_id_and_type(): + return validate_model_id_and_get_type( model_id=model_id, model_version=model_version, region=region, @@ -279,16 +279,18 @@ def _is_valid_model_id_hook(): sagemaker_session=sagemaker_session, ) - if not _is_valid_model_id_hook(): + self._model_type = _validate_model_id_and_type() + if not self._model_type: JumpStartModelsAccessor.reset_cache() - if not _is_valid_model_id_hook(): + self._model_type = _validate_model_id_and_type() + if not self._model_type: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) self._model_data_is_set = model_data is not None - model_init_kwargs = get_init_kwargs( model_id=model_id, model_from_estimator=False, + model_type=self._model_type, model_version=model_version, instance_type=instance_type, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -603,6 +605,7 @@ def deploy( tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, + model_type=self._model_type, ) # If a predictor class was passed, do not mutate predictor diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 49d3e295c5..21bdfa938a 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -19,6 +19,7 @@ from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session from sagemaker.workflow.entities import PipelineVariable @@ -102,8 +103,10 @@ def __repr__(self) -> str: class JumpStartS3FileType(str, Enum): """Type of files published in JumpStart S3 distribution buckets.""" - MANIFEST = "manifest" - SPECS = "specs" + OPEN_SOURCE_MANIFEST = "manifest" + OPEN_SOURCE_SPECS = "specs" + PROPRIETARY_MANIFEST = "proptietary_manifest" + PROPRIETARY_SPECS = "proprietary_specs" class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): @@ -782,29 +785,31 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of spec. """ self.model_id: str = json_obj["model_id"] - self.url: str = json_obj["url"] + self.url: str = json_obj.get("url", "") self.version: str = json_obj["version"] self.min_sdk_version: str = json_obj["min_sdk_version"] - self.incremental_training_supported: bool = bool(json_obj["incremental_training_supported"]) + self.incremental_training_supported: bool = bool( + json_obj.get("incremental_training_supported", False) + ) self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( JumpStartECRSpecs(json_obj["hosting_ecr_specs"]) if "hosting_ecr_specs" in json_obj else None ) - self.hosting_artifact_key: str = json_obj["hosting_artifact_key"] - self.hosting_script_key: str = json_obj["hosting_script_key"] - self.training_supported: bool = bool(json_obj["training_supported"]) + self.hosting_artifact_key: Optional[str] = json_obj.get("hosting_artifact_key") + self.hosting_script_key: Optional[str] = json_obj.get("hosting_script_key") + self.training_supported: Optional[bool] = bool(json_obj.get("training_supported", False)) self.inference_environment_variables = [ JumpStartEnvironmentVariable(env_variable) - for env_variable in json_obj["inference_environment_variables"] + for env_variable in json_obj.get("inference_environment_variables", []) ] - self.inference_vulnerable: bool = bool(json_obj["inference_vulnerable"]) - self.inference_dependencies: List[str] = json_obj["inference_dependencies"] - self.inference_vulnerabilities: List[str] = json_obj["inference_vulnerabilities"] - self.training_vulnerable: bool = bool(json_obj["training_vulnerable"]) - self.training_dependencies: List[str] = json_obj["training_dependencies"] - self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"] - self.deprecated: bool = bool(json_obj["deprecated"]) + self.inference_vulnerable: bool = bool(json_obj.get("inference_vulnerable", False)) + self.inference_dependencies: List[str] = json_obj.get("inference_dependencies", []) + self.inference_vulnerabilities: List[str] = json_obj.get("inference_vulnerabilities", []) + self.training_vulnerable: bool = bool(json_obj.get("training_vulnerable", False)) + self.training_dependencies: List[str] = json_obj.get("training_dependencies", []) + self.training_vulnerabilities: List[str] = json_obj.get("training_vulnerabilities", []) + self.deprecated: bool = bool(json_obj.get("deprecated", False)) self.deprecated_message: Optional[str] = json_obj.get("deprecated_message") self.deprecate_warn_message: Optional[str] = json_obj.get("deprecate_warn_message") self.usage_info_message: Optional[str] = json_obj.get("usage_info_message") @@ -1026,6 +1031,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "model_type", "instance_type", "tolerate_vulnerable_model", "tolerate_deprecated_model", @@ -1056,6 +1062,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "instance_type", "model_id", "model_version", + "model_type", "tolerate_vulnerable_model", "tolerate_deprecated_model", "region", @@ -1067,6 +1074,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, region: Optional[str] = None, instance_type: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, @@ -1096,6 +1104,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.model_type = model_type self.instance_type = instance_type self.region = region self.image_uri = image_uri diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 2621422811..6094b8cf53 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging import os -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Set, Optional, Tuple, Union from urllib.parse import urlparse import boto3 from packaging.version import Version @@ -26,6 +26,7 @@ TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, TRAINING_JOB_ROLE_ARN_PATH, ) + from sagemaker.jumpstart import constants, enums from sagemaker.jumpstart import accessors from sagemaker.s3 import parse_s3_url @@ -572,12 +573,16 @@ def verify_model_region_and_return_specs( "JumpStart models only support scopes: " f"{', '.join(constants.SUPPORTED_JUMPSTART_SCOPES)}." ) + model_type = validate_model_id_and_get_type( + model_id=model_id, region=region, model_version=version, script=scope + ) model_specs = accessors.JumpStartModelsAccessor.get_model_specs( # type: ignore region=region, model_id=model_id, version=version, s3_client=sagemaker_session.s3_client, + model_type=model_type, ) if ( @@ -732,36 +737,52 @@ def resolve_estimator_sagemaker_config_field( return field_val -def is_valid_model_id( +def validate_model_id_and_get_type( model_id: Optional[str], region: Optional[str] = None, model_version: Optional[str] = None, script: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> bool: - """Returns True if the model ID is supported for the given script. +) -> Optional[enums.JumpStartModelType]: + """Returns model type if the model ID is supported for the given script. Raises: ValueError: If the script is not supported by JumpStart. """ + + def _get_model_type( + model_id: str, + open_source_models: Set[str], + proprietary_models: Set[str], + script: enums.JumpStartScriptScope, + ) -> Optional[enums.JumpStartModelType]: + if model_id in open_source_models: + return enums.JumpStartModelType.OPEN_SOURCE + if model_id in proprietary_models: + if script == enums.JumpStartScriptScope.INFERENCE: + return enums.JumpStartModelType.PROPRIETARY + raise ValueError(f"Unsupported script for Marketplace models: {script}") + return None + if model_id in {None, ""}: - return False + return None if not isinstance(model_id, str): - return False + return None s3_client = sagemaker_session.s3_client if sagemaker_session else None region = region or constants.JUMPSTART_DEFAULT_REGION_NAME model_version = model_version or "*" - models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( - region=region, s3_client=s3_client + region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.OPEN_SOURCE + ) + open_source_model_id_set = {model.model_id for model in models_manifest_list} + + proprietary_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.PROPRIETARY ) - model_id_set = {model.model_id for model in models_manifest_list} - if script == enums.JumpStartScriptScope.INFERENCE: - return model_id in model_id_set - if script == enums.JumpStartScriptScope.TRAINING: - return model_id in model_id_set - raise ValueError(f"Unsupported script: {script}") + + proprietary_model_id_set = {model.model_id for model in proprietary_manifest_list} + return _get_model_type(model_id, open_source_model_id_set, proprietary_model_id_set, script) def get_jumpstart_model_id_version_from_resource_arn( diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index fc76c0fa76..b056ff593c 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -34,6 +34,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -91,6 +92,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 24050807cc..26f811322f 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -39,7 +39,7 @@ "us-east-2", } -MAX_INIT_TIME_SECONDS = 5 +MAX_INIT_TIME_SECONDS = 15 GATED_INFERENCE_MODEL_PACKAGE_SUPPORTED_REGIONS = { "us-west-2", @@ -237,8 +237,6 @@ def test_instatiating_model(mock_warning_logger, setup): assert elapsed_time <= MAX_INIT_TIME_SECONDS - mock_warning_logger.assert_called_once() - def test_jumpstart_model_register(setup): model_id = "huggingface-txt2img-conflictx-complex-lineart" @@ -262,3 +260,25 @@ def test_jumpstart_model_register(setup): response = predictor.predict("hello world!") assert response is not None + + +@pytest.mark.skipif( + True, + reason="Only enable if test account is subscribed to the proprietary model", +) +def test_proprietary_jumpstart_model(setup): + + model_id = "ai21-jurassic-2-light" + + model = JumpStartModel( + model_id=model_id, + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + predictor = model.deploy() + payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1} + + response = predictor.predict(payload) + + assert response is not None diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index 28211d06f1..4272684d35 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -17,21 +17,27 @@ from sagemaker import accept_types from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec + mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_accept_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -45,18 +51,26 @@ def test_jumpstart_default_accept_types( assert default_accept_type == "application/json" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_supported_accept_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -73,5 +87,9 @@ def test_jumpstart_supported_accept_types( ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index 4b2db7d7f4..50250d8f3b 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -17,6 +17,7 @@ from sagemaker import content_types from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec @@ -24,14 +25,18 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_content_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -45,18 +50,26 @@ def test_jumpstart_default_content_types( assert default_content_type == "application/x-text" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_supported_content_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -72,5 +85,9 @@ def test_jumpstart_supported_content_types( ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index 9d6e2f21de..3e917d55c7 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -18,6 +18,7 @@ from sagemaker import base_deserializers, deserializers from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec @@ -26,14 +27,18 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_deserializers( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -47,18 +52,26 @@ def test_jumpstart_default_deserializers( assert isinstance(default_deserializer, base_deserializers.JSONDeserializer) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_deserializer_options( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -79,5 +92,9 @@ def test_jumpstart_deserializer_options( ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index acd8d19923..f1a5176263 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -18,17 +18,23 @@ import pytest from sagemaker import environment_variables +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec + mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_default_environment_variables(patched_get_model_specs): +def test_jumpstart_default_environment_variables( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -48,7 +54,11 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client + region=region, + model_id=model_id, + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() @@ -68,7 +78,11 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() @@ -98,10 +112,14 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_sdk_environment_variables(patched_get_model_specs): +def test_jumpstart_sdk_environment_variables( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -122,7 +140,11 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client + region=region, + model_id=model_id, + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() @@ -143,7 +165,11 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index eebc079164..3ac10c9109 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -18,6 +18,7 @@ import pytest from sagemaker import hyperparameters +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec @@ -26,10 +27,14 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_default_hyperparameters(patched_get_model_specs): +def test_jumpstart_default_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -47,6 +52,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): model_id=model_id, version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() @@ -64,6 +70,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): model_id=model_id, version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() @@ -89,6 +96,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): model_id=model_id, version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 0054ed9dbd..cf7a321b79 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -17,7 +17,7 @@ import pytest import boto3 from sagemaker import hyperparameters -from sagemaker.jumpstart.enums import HyperparameterValidationMode +from sagemaker.jumpstart.enums import HyperparameterValidationMode, JumpStartModelType from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError from sagemaker.jumpstart.types import JumpStartHyperparameter @@ -27,8 +27,11 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_validate_provided_hyperparameters(patched_get_model_specs): +def test_jumpstart_validate_provided_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): def add_options_to_hyperparameter(*largs, **kwargs): spec = get_spec_from_base_spec(*largs, **kwargs) spec.hyperparameters.extend( @@ -109,6 +112,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): return spec patched_get_model_specs.side_effect = add_options_to_hyperparameter + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -140,6 +144,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() @@ -398,8 +403,11 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_validate_algorithm_hyperparameters(patched_get_model_specs): +def test_jumpstart_validate_algorithm_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): def add_options_to_hyperparameter(*largs, **kwargs): spec = get_spec_from_base_spec(*largs, **kwargs) spec.hyperparameters.append( @@ -416,6 +424,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): return spec patched_get_model_specs.side_effect = add_options_to_hyperparameter + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -437,7 +446,11 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() @@ -464,10 +477,14 @@ def add_options_to_hyperparameter(*largs, **kwargs): assert str(e.value) == "Cannot find algorithm hyperparameter for 'adam-learning-rate'." +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): +def test_jumpstart_validate_all_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -491,7 +508,11 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 8a41891280..912ef6fb49 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -18,19 +18,24 @@ from sagemaker import image_uris from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.image_uris.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_common_image_uri( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -49,6 +54,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -69,6 +75,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -89,6 +96,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -109,6 +117,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index bed2e50674..748cfca0e2 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -18,14 +18,17 @@ import pytest from sagemaker import instance_types +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_instance_types(patched_get_model_specs): +def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_model_id_and_get_type): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, model_version = "huggingface-eqa-bert-base-cased", "*" region = "us-west-2" @@ -47,6 +50,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() @@ -65,6 +69,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() @@ -89,6 +94,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() @@ -111,7 +117,11 @@ def test_jumpstart_instance_types(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index a3c4c747f7..0abe51749b 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -6289,3 +6289,83 @@ "imagenet-inception-v3-classification-4/specs_v3.0.0.json", }, ] + +BASE_PROPRIETARY_HEADER = { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], +} + +BASE_PROPRIETARY_MANIFEST = [ + { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "lighton-mini-instruct40b", + "version": "v1.0", + "min_version": "2.0.0", + "spec_key": "proprietary-models/lighton-mini-instruct40b/proprietary_specs_v1.0.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "ai21-paraphrase", + "version": "1.0.005", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", + "search_keywords": ["Text2Text", "Generation"], + }, +] + +BASE_PROPRIETARY_SPEC = { + "model_id": "ai21-jurassic-2-light", + "version": "2.0.004", + "min_sdk_version": "2.999.0", + "listing_id": "prodview-roz6zicyvi666", + "product_id": "1bd680a0-f29b-479d-91c3-9899743021cf", + "model_subscription_link": "https://aws.amazon.com/marketplace/ai/procurement?productId=1bd680a0", + "hosting_notebook_key": "pmm-notebooks/pmm-notebook-ai21-jurassic-2-light.ipynb", + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 600, + }, + "default_payloads": { + "Shakespeare": { + "content_type": "application/json", + "prompt_key": "prompt", + "output_keys": {"generated_text": "[0].completions[0].data.text"}, + "body": {"prompt": "To be, or", "maxTokens": 1, "temperature": 0}, + } + }, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "default_inference_instance_type": "ml.p4de.24xlarge", + "supported_inference_instance_types": ["ml.p4de.24xlarge"], + "hosting_model_package_arns": { + "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/j2-light-v2-0-004", + "us-east-2": "arn:aws:sagemaker:us-east-2:057799348421:model-package/j2-light-v2-0-004", + "us-west-1": "arn:aws:sagemaker:us-west-1:382657785993:model-package/j2-light-v2-0-004", + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/j2-light-v2-0-004", + "ca-central-1": "arn:aws:sagemaker:ca-central-1:470592106596:model-package/j2-light-v2-0-004", + "eu-central-1": "arn:aws:sagemaker:eu-central-1:446921602837:model-package/j2-light-v2-0-004", + "eu-west-1": "arn:aws:sagemaker:eu-west-1:985815980388:model-package/j2-light-v2-0-004", + "eu-west-2": "arn:aws:sagemaker:eu-west-2:856760150666:model-package/j2-light-v2-0-004", + "eu-west-3": "arn:aws:sagemaker:eu-west-3:843114510376:model-package/j2-light-v2-0-004", + "eu-north-1": "arn:aws:sagemaker:eu-north-1:136758871317:model-package/j2-light-v2-0-004", + "ap-southeast-1": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/j2-light-v2-0-004", + "ap-southeast-2": "arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/j2-light-v2-0-004", + "ap-northeast-2": "arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/j2-light-v2-0-004", + "ap-northeast-1": "arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/j2-light-v2-0-004", + "ap-south-1": "arn:aws:sagemaker:ap-south-1:077584701553:model-package/j2-light-v2-0-004", + "sa-east-1": "arn:aws:sagemaker:sa-east-1:270155090741:model-package/j2-light-v2-0-004", + }, +} diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 4dc35b65ca..c55f1779cd 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -31,7 +31,7 @@ _retrieve_default_training_metric_definitions, ) from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType from sagemaker.jumpstart.estimator import JumpStartEstimator @@ -60,9 +60,10 @@ class EstimatorTest(unittest.TestCase): @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -75,14 +76,15 @@ def test_non_prepacked( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, + mock_get_model_type: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_jumpstart_model_factory_logger: mock.Mock, mock_jumpstart_estimator_factory_logger: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE mock_sagemaker_timestamp.return_value = "9876" @@ -92,6 +94,8 @@ def test_non_prepacked( mock_get_model_specs.side_effect = get_special_model_spec + mock_get_model_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session @@ -182,7 +186,7 @@ def test_non_prepacked( ], ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -199,11 +203,11 @@ def test_prepacked( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model-prepacked", "*" @@ -282,7 +286,7 @@ def test_prepacked( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -299,14 +303,14 @@ def test_gated_model_s3_uri( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_timestamp.return_value = "8675309" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-gated-artifact-trainable-model", "*" @@ -418,7 +422,7 @@ def test_gated_model_s3_uri( "sagemaker.jumpstart.artifacts.environment_variables.get_jumpstart_gated_content_bucket" ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -435,7 +439,7 @@ def test_gated_model_non_model_package_s3_uri( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, mock_get_jumpstart_gated_content_bucket: mock.Mock, ): @@ -444,7 +448,7 @@ def test_gated_model_non_model_package_s3_uri( mock_get_jumpstart_gated_content_bucket.return_value = "top-secret-private-models-bucket" mock_timestamp.return_value = "8675309" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-gated-artifact-non-model-package-trainable-model", "*" @@ -566,7 +570,7 @@ def test_gated_model_non_model_package_s3_uri( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -583,15 +587,13 @@ def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_timestamp.return_value = "8675309" - mock_is_valid_model_id.return_value = True - model_id, _ = "js-gated-artifact-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -599,6 +601,8 @@ def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + with pytest.raises(ValueError) as e: JumpStartEstimator(model_id=model_id, region="eu-north-1") @@ -608,7 +612,7 @@ def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( "us-west-2, us-east-1, eu-west-1, ap-southeast-1." ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -621,10 +625,10 @@ def test_deprecated( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "deprecated_model", "*" @@ -642,7 +646,7 @@ def test_deprecated( JumpStartEstimator(model_id=model_id, tolerate_deprecated_model=True).fit(channels).deploy() - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -655,9 +659,9 @@ def test_vulnerable( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "vulnerable_model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -758,7 +762,7 @@ def test_estimator_use_kwargs(self): @mock.patch("sagemaker.jumpstart.factory.estimator.metric_definitions_utils.retrieve_default") @mock.patch("sagemaker.jumpstart.factory.estimator.environment_variables.retrieve_default") @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -775,7 +779,7 @@ def evaluate_estimator_workflow_with_kwargs( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, mock_retrieve_default_environment_variables: mock.Mock, mock_retrieve_metric_definitions: mock.Mock, @@ -806,7 +810,7 @@ def evaluate_estimator_workflow_with_kwargs( mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, model_version = "js-trainable-model", "*" @@ -908,16 +912,16 @@ def evaluate_estimator_workflow_with_kwargs( mock_estimator_deploy.assert_called_once_with(**expected_deploy_kwargs) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_tags_disabled( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model-prepacked", "*" @@ -947,16 +951,16 @@ def test_jumpstart_estimator_tags_disabled( [{"Key": "blah", "Value": "blahagain"}], ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_tags( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model-prepacked", "*" @@ -989,18 +993,18 @@ def test_jumpstart_estimator_tags( @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_attach_no_model_id_happy_case( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, get_model_id_version_from_training_job: mock.Mock, mock_attach: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE get_model_id_version_from_training_job.return_value = ( "js-trainable-model-prepacked", @@ -1032,18 +1036,18 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_attach_no_model_id_sad_case( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, get_model_id_version_from_training_job: mock.Mock, mock_attach: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE get_model_id_version_from_training_job.side_effect = ValueError() @@ -1115,22 +1119,22 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): @mock.patch("sagemaker.jumpstart.estimator.get_init_kwargs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") - def test_is_valid_model_id( + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") + def test_validate_model_id_and_get_type( self, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE JumpStartEstimator(model_id="valid_model_id") - mock_is_valid_model_id.return_value = False + mock_validate_model_id_and_get_type.return_value = False with pytest.raises(ValueError): JumpStartEstimator(model_id="invalid_model_id") @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1147,14 +1151,14 @@ def test_no_predictor_returns_default_predictor( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_get_default_predictor.return_value = default_predictor_with_presets - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model-prepacked", "*" @@ -1189,7 +1193,7 @@ def test_no_predictor_returns_default_predictor( self.assertEqual(predictor, default_predictor_with_presets) @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1206,14 +1210,14 @@ def test_no_predictor_yes_async_inference_config( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_get_default_predictor.return_value = default_predictor_with_presets - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model-prepacked", "*" @@ -1239,7 +1243,7 @@ def test_no_predictor_yes_async_inference_config( self.assertEqual(type(predictor), Predictor) @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1256,14 +1260,14 @@ def test_yes_predictor_returns_unmodified_predictor( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_get_default_predictor.return_value = default_predictor_with_presets - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model-prepacked", "*" @@ -1289,7 +1293,7 @@ def test_yes_predictor_returns_unmodified_predictor( self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -1310,9 +1314,9 @@ def test_incremental_training_with_unsupported_model_logs_warning( mock_session_model: mock.Mock, mock_logger_warning: mock.Mock, mock_supports_incremental_training: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE mock_estimator_deploy.return_value = default_predictor @@ -1343,7 +1347,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( sagemaker_session=sagemaker_session, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -1364,9 +1368,9 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( mock_session_model: mock.Mock, mock_logger_warning: mock.Mock, mock_supports_incremental_training: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE mock_estimator_deploy.return_value = default_predictor @@ -1395,7 +1399,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1412,10 +1416,10 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE mock_sagemaker_timestamp.return_value = "3456" @@ -1456,7 +1460,7 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.p4de.24xlarge" @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1473,10 +1477,10 @@ def test_training_passes_role_to_deploy( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE mock_sagemaker_timestamp.return_value = "3456" @@ -1533,7 +1537,7 @@ def test_training_passes_role_to_deploy( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.DEFAULT_JUMPSTART_SAGEMAKER_SESSION", sagemaker_session ) @@ -1553,10 +1557,10 @@ def test_training_passes_session_to_deploy( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE mock_sagemaker_timestamp.return_value = "3456" @@ -1611,7 +1615,7 @@ def test_training_passes_session_to_deploy( ], ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_init_kwargs") @@ -1627,11 +1631,11 @@ def test_model_id_not_found_refeshes_cache_training( mock_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.side_effect = [False, False] + mock_validate_model_id_and_get_type.side_effect = [False, False] model_id, _ = "js-trainable-model", "*" @@ -1647,7 +1651,7 @@ def test_model_id_not_found_refeshes_cache_training( ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -1666,16 +1670,16 @@ def test_model_id_not_found_refeshes_cache_training( ] ) - mock_is_valid_model_id.reset_mock() + mock_validate_model_id_and_get_type.reset_mock() mock_reset_cache.reset_mock() - mock_is_valid_model_id.side_effect = [False, True] + mock_validate_model_id_and_get_type.side_effect = [False, True] JumpStartEstimator( model_id=model_id, ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -1694,7 +1698,7 @@ def test_model_id_not_found_refeshes_cache_training( ] ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @@ -1704,10 +1708,10 @@ def test_model_artifact_variant_estimator( mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "model-artifact-variant-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py index d22e910a00..1535039570 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py @@ -25,6 +25,7 @@ from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.estimator import JumpStartEstimator +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session from sagemaker.utils import resolve_value_from_config @@ -78,7 +79,7 @@ def config_value_impl(sagemaker_session: Session, config_path: str, sagemaker_co class IntelligentDefaultsEstimatorTest(unittest.TestCase): - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -98,12 +99,12 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_retrieve_model_init_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" @@ -135,7 +136,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( assert "enable_network_isolation" not in mock_estimator_deploy.call_args[1] - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -155,12 +156,12 @@ def test_without_arg_overwrites_with_kwarg_collisions_with_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" @@ -208,7 +209,7 @@ def test_without_arg_overwrites_with_kwarg_collisions_with_config( config_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -228,12 +229,12 @@ def test_with_arg_overwrites_with_kwarg_collisions_with_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" @@ -290,7 +291,7 @@ def test_with_arg_overwrites_with_kwarg_collisions_with_config( override_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -310,12 +311,12 @@ def test_with_arg_overwrites_without_kwarg_collisions_with_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" @@ -365,7 +366,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_with_config( override_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.session.Session.get_caller_identity_arn") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @@ -387,12 +388,12 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, mock_get_caller_identity_arn: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" @@ -426,7 +427,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( assert "enable_network_isolation" not in mock_estimator_deploy.call_args[1] - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.session.Session.get_caller_identity_arn") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @@ -448,12 +449,12 @@ def test_without_arg_overwrites_with_kwarg_collisions_without_config( mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, mock_get_caller_identity_arn: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE mock_get_caller_identity_arn.return_value = execution_role model_id, _ = "js-trainable-model", "*" @@ -500,7 +501,7 @@ def test_without_arg_overwrites_with_kwarg_collisions_without_config( metadata_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -520,11 +521,11 @@ def test_with_arg_overwrites_with_kwarg_collisions_without_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" @@ -576,7 +577,7 @@ def test_with_arg_overwrites_with_kwarg_collisions_without_config( override_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_init_kwargs") @@ -594,11 +595,11 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index f45283935b..e0aea1bb2b 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -22,7 +22,7 @@ _retrieve_default_environment_variables, ) from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType from sagemaker.jumpstart.model import JumpStartModel from sagemaker.model import Model @@ -32,6 +32,7 @@ from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from tests.unit.sagemaker.jumpstart.utils import ( + get_spec_from_base_spec, get_special_model_spec, overwrite_dictionary, get_special_model_spec_for_inference_component_based_endpoint, @@ -53,7 +54,7 @@ class ModelTest(unittest.TestCase): @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -65,7 +66,7 @@ def test_non_prepacked( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_jumpstart_model_factory_logger: mock.Mock, ): @@ -73,7 +74,7 @@ def test_non_prepacked( mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -128,7 +129,7 @@ def test_non_prepacked( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -140,14 +141,14 @@ def test_non_prepacked_inference_component_based_endpoint( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): mock_model_deploy.return_value = default_predictor mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = ( @@ -208,7 +209,7 @@ def test_non_prepacked_inference_component_based_endpoint( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -220,14 +221,14 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): mock_model_deploy.return_value = default_predictor mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-model-class-model-prepacked", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -282,7 +283,7 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -294,11 +295,11 @@ def test_prepacked( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-model-class-model-prepacked", "*" @@ -345,7 +346,7 @@ def test_prepacked( @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @mock.patch("sagemaker.session.Session.create_model") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -353,7 +354,7 @@ def test_no_compiled_model_warning_log_js_models( self, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_create_model: mock.Mock, mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, @@ -362,7 +363,7 @@ def test_no_compiled_model_warning_log_js_models( mock_timestamp.return_value = "1234" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "gated_llama_neuron_model", "*" @@ -381,7 +382,7 @@ def test_no_compiled_model_warning_log_js_models( @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @mock.patch("sagemaker.session.Session.create_model") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -389,7 +390,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( self, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_create_model: mock.Mock, mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, @@ -397,7 +398,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( mock_timestamp.return_value = "1234" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "gated_variant-model", "*" @@ -441,7 +442,64 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( ], ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.__init__") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_proprietary_model_endpoint( + self, + mock_model_deploy: mock.Mock, + mock_model_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, + ): + mock_model_deploy.return_value = default_predictor + + mock_sagemaker_timestamp.return_value = "7777" + + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.PROPRIETARY + model_id, _ = "ai21-summarization", "*" + + mock_get_model_specs.side_effect = get_spec_from_base_spec + + mock_session.return_value = sagemaker_session + + model = JumpStartModel( + model_id=model_id, + ) + + mock_model_init.assert_called_once_with( + image_uri="", + model_data="s3://jumpstart-cache-prod-us-west-2/None", + source_dir="s3://jumpstart-cache-prod-us-west-2/None", + entry_point="inference.py", + predictor_cls=Predictor, + role=execution_role, + sagemaker_session=sagemaker_session, + enable_network_isolation=False, + ) + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p4de.24xlarge", + wait=True, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "ai21-summarization"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "*"}, + ], + endpoint_logging=False, + model_data_download_timeout=3600, + container_startup_health_check_timeout=600, + ) + + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -451,11 +509,11 @@ def test_deprecated( mock_model_deploy: mock.Mock, mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "deprecated_model", "*" @@ -468,7 +526,7 @@ def test_deprecated( JumpStartModel(model_id=model_id, tolerate_deprecated_model=True).deploy() - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -478,9 +536,9 @@ def test_vulnerable( mock_model_deploy: mock.Mock, mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE mock_model_deploy.return_value = default_predictor @@ -543,7 +601,7 @@ def test_model_use_kwargs(self): ) @mock.patch("sagemaker.jumpstart.factory.model.environment_variables.retrieve_default") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -555,7 +613,7 @@ def evaluate_model_workflow_with_kwargs( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_retrieve_environment_variables: mock.Mock, init_kwargs: Optional[dict] = None, deploy_kwargs: Optional[dict] = None, @@ -565,7 +623,7 @@ def evaluate_model_workflow_with_kwargs( mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE mock_session.return_value = sagemaker_session @@ -661,22 +719,22 @@ def test_jumpstart_model_kwargs_match_parent_class(self): @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") - def test_is_valid_model_id( + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") + def test_validate_model_id_and_get_type( self, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE JumpStartModel(model_id="valid_model_id") - mock_is_valid_model_id.return_value = False + mock_validate_model_id_and_get_type.return_value = False with pytest.raises(ValueError): JumpStartModel(model_id="invalid_model_id") @mock.patch("sagemaker.jumpstart.model.get_default_predictor") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -688,14 +746,14 @@ def test_no_predictor_returns_default_predictor( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-model-class-model-prepacked", "*" @@ -717,12 +775,13 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, + model_type=JumpStartModelType.OPEN_SOURCE, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -734,14 +793,14 @@ def test_no_predictor_yes_async_inference_config( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-model-class-model-prepacked", "*" @@ -758,7 +817,7 @@ def test_no_predictor_yes_async_inference_config( mock_get_default_predictor.assert_not_called() @mock.patch("sagemaker.jumpstart.model.get_default_predictor") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -770,14 +829,14 @@ def test_yes_predictor_returns_default_predictor( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-model-class-model-prepacked", "*" @@ -793,24 +852,24 @@ def test_yes_predictor_returns_default_predictor( self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.JumpStartModelsAccessor.reset_cache") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) - def test_model_id_not_found_refeshes_cach_inference( + def test_model_id_not_found_refeshes_cache_inference( self, mock_reset_cache: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.side_effect = [False, False] + mock_validate_model_id_and_get_type.side_effect = [False, False] model_id, _ = "js-trainable-model", "*" @@ -826,7 +885,7 @@ def test_model_id_not_found_refeshes_cach_inference( ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -845,16 +904,19 @@ def test_model_id_not_found_refeshes_cach_inference( ] ) - mock_is_valid_model_id.reset_mock() + mock_validate_model_id_and_get_type.reset_mock() mock_reset_cache.reset_mock() - mock_is_valid_model_id.side_effect = [False, True] + mock_validate_model_id_and_get_type.side_effect = [ + False, + JumpStartModelType.OPEN_SOURCE, + ] JumpStartModel( model_id=model_id, ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -873,16 +935,16 @@ def test_model_id_not_found_refeshes_cach_inference( ] ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_tags( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "env-var-variant-model", "*" @@ -909,16 +971,16 @@ def test_jumpstart_model_tags( [{"Key": "blah", "Value": "blahagain"}] + js_tags, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_tags_disabled( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "env-var-variant-model", "*" @@ -941,16 +1003,16 @@ def test_jumpstart_model_tags_disabled( [{"Key": "blah", "Value": "blahagain"}], ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_package_arn( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-model-package-arn", "*" @@ -975,16 +1037,16 @@ def test_jumpstart_model_package_arn( self.assertIn(tag, mock_session.create_model.call_args[1]["tags"]) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_package_arn_override( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE # arbitrary model without model packarn arn model_id, _ = "js-trainable-model", "*" @@ -1017,7 +1079,7 @@ def test_jumpstart_model_package_arn_override( }, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1025,10 +1087,10 @@ def test_jumpstart_model_package_arn_unsupported_region( self, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-model-package-arn", "*" @@ -1044,7 +1106,7 @@ def test_jumpstart_model_package_arn_unsupported_region( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -1058,14 +1120,14 @@ def test_model_data_s3_prefix_override( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): mock_model_deploy.return_value = default_predictor mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1109,7 +1171,7 @@ def test_model_data_s3_prefix_override( '"S3DataType": "S3Prefix", "CompressionType": "None"}}', ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -1123,11 +1185,11 @@ def test_model_data_s3_prefix_model( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "model_data_s3_prefix_model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1153,7 +1215,7 @@ def test_model_data_s3_prefix_model( mock_js_info_logger.assert_not_called() - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -1167,11 +1229,11 @@ def test_model_artifact_variant_model( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "model-artifact-variant-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1218,7 +1280,7 @@ def test_model_artifact_variant_model( enable_network_isolation=True, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -1230,11 +1292,11 @@ def test_model_registry_accept_and_response_types( mock_model_deploy: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "model_data_s3_prefix_model", "*" mock_get_model_specs.side_effect = get_special_model_spec diff --git a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py index 727f3060b3..02d170e54b 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py @@ -21,6 +21,7 @@ from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.model import JumpStartModel +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session from sagemaker.utils import resolve_value_from_config @@ -59,7 +60,7 @@ class IntelligentDefaultsModelTest(unittest.TestCase): region = "us-west-2" sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -75,10 +76,10 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" mock_retrieve_kwargs.return_value = {} @@ -100,7 +101,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( assert "enable_network_isolation" not in mock_model_init.call_args[1] - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -116,10 +117,10 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" @@ -146,7 +147,7 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( override_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -162,10 +163,10 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" @@ -192,7 +193,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( config_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -208,10 +209,10 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" @@ -240,7 +241,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( override_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -256,10 +257,10 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" @@ -286,7 +287,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( metadata_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -302,9 +303,9 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" @@ -333,7 +334,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( override_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -349,10 +350,10 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" @@ -374,7 +375,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( self.assertEquals(mock_model_init.call_args[1].get("role"), execution_role) assert "enable_network_isolation" not in mock_model_init.call_args[1] - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -390,10 +391,10 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, _ = "js-trainable-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index 97427be1ae..b89e28d344 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -18,6 +18,7 @@ import pytest from sagemaker.jumpstart import accessors +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.constants import BASE_MANIFEST from tests.unit.sagemaker.jumpstart.utils import ( get_header_from_base_header, @@ -63,6 +64,49 @@ def test_jumpstart_models_cache_get_fxs(mock_cache): reload(accessors) +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") +def test_jumpstart_proprietary_models_cache_get_fxs(mock_cache): + + mock_cache.get_manifest = Mock(return_value=BASE_MANIFEST) + mock_cache.get_header = Mock(side_effect=get_header_from_base_header) + mock_cache.get_specs = Mock(side_effect=get_spec_from_base_spec) + + assert get_header_from_base_header( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) == accessors.JumpStartModelsAccessor.get_model_header( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert get_spec_from_base_spec( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) == accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + assert ( + len( + accessors.JumpStartModelsAccessor._get_manifest( + model_type=JumpStartModelType.PROPRIETARY + ) + ) + > 0 + ) + + # necessary because accessors is a static module + reload(accessors) + + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache") def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): @@ -138,6 +182,50 @@ def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): reload(accessors) +@patch("sagemaker.jumpstart.cache.JumpStartModelsCache") +def test_jumpstart_proprietary_models_cache_set_reset_fxs(mock_model_cache: Mock): + + # test change of region resets cache + accessors.JumpStartModelsAccessor.get_model_header( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + mock_model_cache.assert_called_once() + mock_model_cache.reset_mock() + + accessors.JumpStartModelsAccessor.get_model_header( + region="us-east-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + mock_model_cache.assert_called_once() + mock_model_cache.reset_mock() + + accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-1", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + mock_model_cache.assert_called_once() + mock_model_cache.reset_mock() + + # necessary because accessors is a static module + reload(accessors) + + class TestS3Accessor(TestCase): bucket = "bucket" diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index 1a770f785f..173fa923a4 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -32,7 +32,7 @@ from sagemaker.jumpstart.artifacts.model_packages import _retrieve_model_package_arn from sagemaker.jumpstart.artifacts.model_uris import _retrieve_model_uri -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec from tests.unit.sagemaker.workflow.conftest import mock_client @@ -331,9 +331,13 @@ class RetrieveModelPackageArnTest(unittest.TestCase): mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - def test_retrieve_model_package_arn(self, patched_get_model_specs): + def test_retrieve_model_package_arn( + self, patched_get_model_specs: Mock, patched_validate_model_id_and_get_type: Mock + ): patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id = "variant-model" region = "us-west-2" @@ -437,9 +441,13 @@ class PrivateJumpStartBucketTest(unittest.TestCase): mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - def test_retrieve_uri_from_gated_bucket(self, patched_get_model_specs): + def test_retrieve_uri_from_gated_bucket( + self, patched_get_model_specs, patched_validate_model_id_and_get_type + ): patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id = "private-model" region = "us-west-2" diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 6633ecdc23..8b3f4be401 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -23,7 +23,11 @@ import pytest from mock import patch -from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache +from sagemaker.jumpstart.cache import ( + JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, + JumpStartModelsCache, +) from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, @@ -33,6 +37,7 @@ JumpStartModelSpecs, JumpStartVersionedModelId, ) +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import ( get_spec_from_base_spec, patched_retrieval_function, @@ -41,6 +46,8 @@ from tests.unit.sagemaker.jumpstart.constants import ( BASE_MANIFEST, BASE_SPEC, + BASE_PROPRIETARY_SPEC, + BASE_PROPRIETARY_MANIFEST, ) from sagemaker.jumpstart.utils import get_jumpstart_content_bucket @@ -152,6 +159,34 @@ def test_jumpstart_cache_get_header(): semantic_version_str="1.0.*", ) + assert JumpStartModelHeader( + { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], + } + ) == cache.get_header( + model_id="ai21-summarization", + semantic_version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + + assert JumpStartModelHeader( + { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], + } + ) == cache.get_header( + model_id="ai21-summarization", + semantic_version_str="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + with pytest.raises(KeyError) as e: cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4", @@ -194,6 +229,32 @@ def test_jumpstart_cache_get_header(): "v3-classification-4'?" ) in str(e.value) + with pytest.raises(KeyError) as e: + cache.get_header( + model_id="ai21-summarize", + semantic_version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert ( + "Unable to find model manifest for 'ai21-summarize' with version '1.1.003'. " + "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " + "for updated list of models. " + "Did you mean to use model ID 'ai21-summarization'?" + ) in str(e.value) + + with pytest.raises(KeyError) as e: + cache.get_header( + model_id="ai21-summarize", + semantic_version_str="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert ( + "Unable to find model manifest for 'ai21-summarize' with version '*'. " + "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " + "for updated list of models. " + "Did you mean to use model ID 'ai21-summarization'?" + ) in str(e.value) + with pytest.raises(KeyError): cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4", @@ -224,6 +285,27 @@ def test_jumpstart_cache_get_header(): semantic_version_str="*", ) + with pytest.raises(KeyError): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="1.1.004", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError): + cache.get_header( + model_id="ai21-summarization", + semantic_version_str="2.*", + model_type=JumpStartModelType.PROPRIETARY, + ) + @patch("boto3.client") def test_jumpstart_cache_handles_boto3_issues(mock_boto3_client): @@ -423,11 +505,11 @@ def test_jumpstart_cache_accepts_input_parameters(): assert cache._s3_cache._max_cache_items == max_s3_cache_items assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon assert ( - cache._model_id_semantic_version_manifest_key_cache._max_cache_items + cache._open_source_model_id_manifest_key_cache._max_cache_items == max_semantic_version_cache_items ) assert ( - cache._model_id_semantic_version_manifest_key_cache._expiration_horizon + cache._open_source_model_id_manifest_key_cache._expiration_horizon == semantic_version_cache_expiration_horizon ) @@ -583,7 +665,7 @@ def test_jumpstart_cache_makes_correct_s3_calls( with patch("logging.Logger.warning") as mocked_warning_log: cache.get_specs( - model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + model_id="pytorch-ic-imagenet-inception-v3-classification-4", version_str="*" ) mocked_warning_log.assert_called_once_with( "Using model 'pytorch-ic-imagenet-inception-v3-classification-4' with wildcard " @@ -593,7 +675,7 @@ def test_jumpstart_cache_makes_correct_s3_calls( ) mocked_warning_log.reset_mock() cache.get_specs( - model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + model_id="pytorch-ic-imagenet-inception-v3-classification-4", version_str="*" ) mocked_warning_log.assert_not_called() @@ -605,13 +687,97 @@ def test_jumpstart_cache_makes_correct_s3_calls( mock_boto3_client.return_value.head_object.assert_not_called() +@patch("sagemaker.jumpstart.cache.utils.emit_logs_based_on_model_specs") +@patch("boto3.client") +def test_jumpstart_cache_proprietary_manifest_makes_correct_s3_calls( + mock_boto3_client, mock_emit_logs_based_on_model_specs +): + + # test get_header + mock_manifest_json = json.dumps( + [ + { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], + }, + ] + ) + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_manifest_json, "utf-8")), content_length=len(mock_manifest_json) + ), + "ETag": "etag", + } + + mock_boto3_client.return_value.head_object.return_value = {"ETag": "some-hash"} + + bucket_name = get_jumpstart_content_bucket("us-west-2") + client_config = botocore.config.Config(signature_version="my_signature_version") + cache = JumpStartModelsCache( + s3_bucket_name=bucket_name, s3_client_config=client_config, region="us-west-2" + ) + cache.get_header( + model_id="ai21-summarization", + semantic_version_str="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY + ) + mock_boto3_client.return_value.head_object.assert_not_called() + + mock_boto3_client.assert_called_with("s3", region_name="us-west-2", config=client_config) + + # test get_specs. manifest already in cache, so only s3 call will be to get specs. + mock_json = json.dumps(BASE_PROPRIETARY_SPEC) + + mock_boto3_client.return_value.reset_mock() + + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) + ), + "ETag": "etag", + } + + with patch("logging.Logger.warning") as mocked_warning_log: + cache.get_specs( + model_id="ai21-summarization", + version_str="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + mocked_warning_log.assert_called_once_with( + "Using model 'ai21-summarization' with wildcard " + "version identifier '*'. You can pin to version '1.1.003' for more " + "stable results. Note that models may have different input/output " + "signatures after a major version upgrade." + ) + mocked_warning_log.reset_mock() + cache.get_specs( + model_id="ai21-summarization", + version_str="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + mocked_warning_log.assert_not_called() + + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, + Key="proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + ) + mock_boto3_client.return_value.head_object.assert_not_called() + + @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") cache.clear = MagicMock() - cache._model_id_semantic_version_manifest_key_cache = MagicMock() - cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [ + cache._open_source_model_id_manifest_key_cache = MagicMock() + cache._open_source_model_id_manifest_key_cache.get.side_effect = [ ( JumpStartVersionedModelId( "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" @@ -640,7 +806,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache.clear.assert_called_once() cache.clear.reset_mock() - cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [ + cache._open_source_model_id_manifest_key_cache.get.side_effect = [ ( JumpStartVersionedModelId( "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" @@ -668,7 +834,18 @@ def test_jumpstart_get_full_manifest(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") raw_manifest = [header.to_json() for header in cache.get_manifest()] - raw_manifest == BASE_MANIFEST + assert raw_manifest == BASE_MANIFEST + + +@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) +@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") +def test_jumpstart_get_full_proprietary_manifest(): + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + raw_manifest = [ + header.to_json() for header in cache.get_manifest(model_type=JumpStartModelType.PROPRIETARY) + ] + + assert raw_manifest == BASE_PROPRIETARY_MANIFEST @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @@ -678,54 +855,89 @@ def test_jumpstart_cache_get_specs(): model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version ) model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" assert get_spec_from_base_spec(model_id=model_id, version="2.0.0") == cache.get_specs( - model_id=model_id, semantic_version_str="2.0.*" + model_id=model_id, version_str="2.0.*" ) model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0" assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version ) model_id = "pytorch-ic-imagenet-inception-v3-classification-4" assert get_spec_from_base_spec(model_id=model_id, version="1.0.0") == cache.get_specs( - model_id=model_id, semantic_version_str="1.*" + model_id=model_id, version_str="1.*" ) model_id = "pytorch-ic-imagenet-inception-v3-classification-4" assert get_spec_from_base_spec(model_id=model_id, version="1.0.0") == cache.get_specs( - model_id=model_id, semantic_version_str="1.0.*" + model_id=model_id, version_str="1.0.*" + ) + + assert get_spec_from_base_spec( + model_id="ai21-summarization", + version="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) == cache.get_specs( + model_id="ai21-summarization", + version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + + assert get_spec_from_base_spec( + model_id="ai21-summarization", + version="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) == cache.get_specs( + model_id="ai21-summarization", + version_str="*", + model_type=JumpStartModelType.PROPRIETARY, ) with pytest.raises(KeyError): - cache.get_specs(model_id=model_id + "bak", semantic_version_str="*") + cache.get_specs(model_id=model_id + "bak", version_str="*") with pytest.raises(KeyError): - cache.get_specs(model_id=model_id, semantic_version_str="9.*") + cache.get_specs(model_id=model_id, version_str="9.*") with pytest.raises(KeyError): - cache.get_specs(model_id=model_id, semantic_version_str="BAD") + cache.get_specs(model_id=model_id, version_str="BAD") with pytest.raises(KeyError): cache.get_specs( model_id=model_id, - semantic_version_str="2.1.*", + version_str="2.1.*", ) with pytest.raises(KeyError): cache.get_specs( model_id=model_id, - semantic_version_str="3.9.*", + version_str="3.9.*", ) with pytest.raises(KeyError): cache.get_specs( model_id=model_id, - semantic_version_str="5.*", + version_str="5.*", + ) + + model_id, version = "ai21-summarization", "2.0.0" + with pytest.raises(KeyError): + cache.get_specs( + model_id=model_id, + version_str="BAD", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError): + cache.get_specs( + model_id=model_id, + version_str="9.*", + model_type=JumpStartModelType.PROPRIETARY, ) @@ -794,9 +1006,7 @@ def test_jumpstart_local_metadata_override_specs( cache = JumpStartModelsCache(s3_bucket_name="some_bucket") model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" - assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs( - model_id=model_id, semantic_version_str=version - ) + assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs(model_id=model_id, version_str=version) mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root") mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") @@ -840,7 +1050,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( cache = JumpStartModelsCache(s3_bucket_name="some_bucket") assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version ) mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 1a7108579c..76a1b684a0 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -17,6 +17,7 @@ get_prototype_manifest, get_prototype_model_spec, ) +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.notebook_utils import ( _generate_jumpstart_model_versions, get_model_url, @@ -63,7 +64,7 @@ def test_list_jumpstart_scripts( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 3 assert patched_get_model_specs.call_count == 1 patched_get_model_specs.reset_mock() @@ -670,12 +671,12 @@ def test_list_jumpstart_models_multiple_level_index( list_jumpstart_models("hosting_ecr_specs.py_version == py3") +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_get_model_url( - patched_get_model_specs: Mock, -): +def test_get_model_url(patched_get_model_specs: Mock, patched_validate_model_id_and_get_type: Mock): patched_get_model_specs.side_effect = get_prototype_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, version = "xgboost-classification-model", "1.0.0" assert "https://xgboost.readthedocs.io/en/latest/" == get_model_url(model_id, version) @@ -702,4 +703,5 @@ def test_get_model_url( version=version, region=region, s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 7ab9cdd1cc..f2adcf5298 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -14,10 +14,8 @@ from sagemaker.jumpstart.utils import verify_model_region_and_return_specs -from sagemaker.serializers import IdentitySerializer -from tests.unit.sagemaker.jumpstart.utils import ( - get_special_model_spec, -) +from sagemaker.serializers import IdentitySerializer, JSONSerializer +from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec, get_spec_from_base_spec @patch("sagemaker.predictor.get_model_id_version_from_endpoint") @@ -54,6 +52,40 @@ def test_jumpstart_predictor_support( assert js_predictor.accept == MIMEType.JSON +@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_proprietary_predictor_support( + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_get_jumpstart_model_id_version_from_endpoint, +): + + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs + patched_get_model_specs.side_effect = get_spec_from_base_spec + + # version not needed for JumpStart predictor + model_id, model_version = "ai21-summarization", "*" + + patched_get_jumpstart_model_id_version_from_endpoint.return_value = ( + model_id, + model_version, + None, + ) + + js_predictor = predictor.retrieve_default( + endpoint_name="blah", model_id=model_id, model_version=model_version + ) + + patched_get_jumpstart_model_id_version_from_endpoint.assert_not_called() + + assert js_predictor.content_type == MIMEType.JSON + assert isinstance(js_predictor.serializer, JSONSerializer) + + assert isinstance(js_predictor.deserializer, JSONDeserializer) + assert js_predictor.accept == MIMEType.JSON + + @patch("sagemaker.predictor.Predictor") @patch("sagemaker.predictor.get_default_predictor") @patch("sagemaker.predictor.get_model_id_version_from_endpoint") @@ -125,19 +157,19 @@ def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( @patch("sagemaker.predictor.get_model_id_version_from_endpoint") @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") -@patch("sagemaker.jumpstart.model.is_valid_model_id") +@patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_serializable_payload_with_predictor( patched_get_model_specs, patched_verify_model_region_and_return_specs, - patched_is_valid_model_id, + patched_validate_model_id_and_get_type, patched_get_object_cached, patched_get_model_id_version_from_endpoint, ): patched_get_object_cached.return_value = base64.b64decode("encodedimage") - patched_is_valid_model_id.return_value = True + patched_validate_model_id_and_get_type.return_value = True patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 556b99bc9c..3ec6f8aec3 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -32,7 +32,7 @@ JumpStartScriptScope, ) from functools import partial -from sagemaker.jumpstart.enums import JumpStartTag, MIMEType +from sagemaker.jumpstart.enums import JumpStartTag, MIMEType, JumpStartModelType from sagemaker.jumpstart.exceptions import ( DeprecatedJumpStartModelError, VulnerableJumpStartModelError, @@ -68,7 +68,7 @@ def test_get_jumpstart_content_bucket_override(): with patch("logging.Logger.info") as mocked_info_log: random_region = "random_region" assert "some-val" == utils.get_jumpstart_content_bucket(random_region) - mocked_info_log.assert_called_once_with("Using JumpStart bucket override: 'some-val'") + mocked_info_log.assert_called_with("Using JumpStart bucket override: 'some-val'") def test_get_jumpstart_gated_content_bucket(): @@ -1180,7 +1180,7 @@ def test_mime_type_enum_from_str(): class TestIsValidModelId(TestCase): @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") - def test_is_valid_model_id_true( + def test_validate_model_id_and_get_type_true( self, mock_get_model_specs: Mock, mock_get_manifest: Mock, @@ -1194,12 +1194,16 @@ def test_is_valid_model_id_true( mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION mock_s3_client_value = mock_session_value.s3_client - patched = partial(utils.is_valid_model_id, sagemaker_session=mock_session_value) + patched = partial( + utils.validate_model_id_and_get_type, sagemaker_session=mock_session_value + ) - with patch("sagemaker.jumpstart.utils.is_valid_model_id", patched): - self.assertTrue(utils.is_valid_model_id("bee")) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + with patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type", patched): + self.assertTrue(utils.validate_model_id_and_get_type("bee")) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) mock_get_model_specs.assert_not_called() @@ -1213,14 +1217,20 @@ def test_is_valid_model_id_true( ] mock_get_model_specs.return_value = Mock(training_supported=True) - self.assertTrue(utils.is_valid_model_id("bee", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + self.assertTrue( + utils.validate_model_id_and_get_type("bee", script=JumpStartScriptScope.TRAINING) + ) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") - def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_manifest: Mock): + def test_validate_model_id_and_get_type_false( + self, mock_get_model_specs: Mock, mock_get_manifest: Mock + ): mock_get_manifest.return_value = [ Mock(model_id="ay"), Mock(model_id="bee"), @@ -1230,18 +1240,18 @@ def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_mani mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION mock_s3_client_value = mock_session_value.s3_client - patched = partial(utils.is_valid_model_id, sagemaker_session=mock_session_value) + patched = partial( + utils.validate_model_id_and_get_type, sagemaker_session=mock_session_value + ) - with patch("sagemaker.jumpstart.utils.is_valid_model_id", patched): + with patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type", patched): - self.assertFalse(utils.is_valid_model_id("dee")) - self.assertFalse(utils.is_valid_model_id("")) - self.assertFalse(utils.is_valid_model_id(None)) - self.assertFalse(utils.is_valid_model_id(set())) + self.assertFalse(utils.validate_model_id_and_get_type("dee")) + self.assertFalse(utils.validate_model_id_and_get_type("")) + self.assertFalse(utils.validate_model_id_and_get_type(None)) + self.assertFalse(utils.validate_model_id_and_get_type(set())) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value - ) + mock_get_manifest.assert_called() mock_get_model_specs.assert_not_called() @@ -1253,30 +1263,48 @@ def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_mani Mock(model_id="bee"), Mock(model_id="see"), ] - self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + self.assertFalse( + utils.validate_model_id_and_get_type("dee", script=JumpStartScriptScope.TRAINING) + ) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) mock_get_manifest.reset_mock() - self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id("", script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id(None, script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id(set(), script=JumpStartScriptScope.TRAINING)) + self.assertFalse( + utils.validate_model_id_and_get_type("dee", script=JumpStartScriptScope.TRAINING) + ) + self.assertFalse( + utils.validate_model_id_and_get_type("", script=JumpStartScriptScope.TRAINING) + ) + self.assertFalse( + utils.validate_model_id_and_get_type(None, script=JumpStartScriptScope.TRAINING) + ) + self.assertFalse( + utils.validate_model_id_and_get_type(set(), script=JumpStartScriptScope.TRAINING) + ) mock_get_model_specs.assert_not_called() - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) mock_get_manifest.reset_mock() mock_get_model_specs.reset_mock() mock_get_model_specs.return_value = Mock(training_supported=False) - self.assertTrue(utils.is_valid_model_id("ay", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + self.assertTrue( + utils.validate_model_id_and_get_type("ay", script=JumpStartScriptScope.TRAINING) + ) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 146c6fd1f7..f8c6384e81 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -28,12 +28,16 @@ JumpStartS3FileType, JumpStartModelHeader, ) +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.utils import get_formatted_manifest from tests.unit.sagemaker.jumpstart.constants import ( PROTOTYPICAL_MODEL_SPECS_DICT, BASE_MANIFEST, BASE_SPEC, + BASE_PROPRIETARY_MANIFEST, + BASE_PROPRIETARY_SPEC, BASE_HEADER, + BASE_PROPRIETARY_HEADER, SPECIAL_MODEL_SPECS_DICT, ) @@ -44,11 +48,19 @@ def get_header_from_base_header( model_id: str = None, semantic_version_str: str = None, version: str = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> JumpStartModelHeader: if version and semantic_version_str: raise ValueError("Cannot specify both `version` and `semantic_version_str` fields.") + if model_type == JumpStartModelType.PROPRIETARY: + spec = copy.deepcopy(BASE_PROPRIETARY_HEADER) + spec["version"] = version or semantic_version_str + spec["model_id"] = model_id + + return JumpStartModelHeader(spec) + if all( [ "pytorch" not in model_id, @@ -92,6 +104,7 @@ def get_prototype_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. @@ -107,6 +120,7 @@ def get_special_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. This is reserved @@ -122,6 +136,7 @@ def get_special_model_spec_for_inference_component_based_endpoint( model_id: str = None, version: str = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID and adding @@ -142,14 +157,22 @@ def get_spec_from_base_spec( _obj: JumpStartModelsCache = None, region: str = None, model_id: str = None, - semantic_version_str: str = None, + version_str: str = None, version: str = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> JumpStartModelSpecs: - if version and semantic_version_str: + if version and version_str: raise ValueError("Cannot specify both `version` and `semantic_version_str` fields.") + if model_type == JumpStartModelType.PROPRIETARY: + spec = copy.deepcopy(BASE_PROPRIETARY_SPEC) + spec["version"] = version or version_str + spec["model_id"] = model_id + + return JumpStartModelSpecs(spec) + if all( [ "pytorch" not in model_id, @@ -172,7 +195,7 @@ def get_spec_from_base_spec( spec = copy.deepcopy(BASE_SPEC) - spec["version"] = version or semantic_version_str + spec["version"] = version or version_str spec["model_id"] = model_id return JumpStartModelSpecs(spec) @@ -185,19 +208,35 @@ def patched_retrieval_function( ) -> JumpStartCachedS3ContentValue: filetype, s3_key = key.file_type, key.s3_key - if filetype == JumpStartS3FileType.MANIFEST: + if filetype == JumpStartS3FileType.OPEN_SOURCE_MANIFEST: return JumpStartCachedS3ContentValue( formatted_content=get_formatted_manifest(BASE_MANIFEST) ) - if filetype == JumpStartS3FileType.SPECS: + if filetype == JumpStartS3FileType.OPEN_SOURCE_SPECS: _, model_id, specs_version = s3_key.split("/") version = specs_version.replace("specs_v", "").replace(".json", "") return JumpStartCachedS3ContentValue( formatted_content=get_spec_from_base_spec(model_id=model_id, version=version) ) + if filetype == JumpStartS3FileType.PROPRIETARY_MANIFEST: + return JumpStartCachedS3ContentValue( + formatted_content=get_formatted_manifest(BASE_PROPRIETARY_MANIFEST) + ) + + if filetype == JumpStartS3FileType.PROPRIETARY_SPECS: + _, model_id, specs_version = s3_key.split("/") + version = specs_version.replace("proprietary_specs_", "").replace(".json", "") + return JumpStartCachedS3ContentValue( + formatted_content=get_spec_from_base_spec( + model_id=model_id, + version=version, + model_type=JumpStartModelType.PROPRIETARY, + ) + ) + raise ValueError(f"Bad value for filetype: {filetype}") diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index ffc6000c91..19927c9b16 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -18,6 +18,7 @@ import pytest from sagemaker import metric_definitions +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec @@ -25,10 +26,14 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_default_metric_definitions(patched_get_model_specs): +def test_jumpstart_default_metric_definitions( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -47,7 +52,11 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client + region=region, + model_id=model_id, + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() @@ -63,7 +72,11 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 000540e12e..f1d5441a71 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -19,19 +19,24 @@ from sagemaker import model_uris from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.model_uris.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_common_model_uri( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -47,6 +52,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -64,6 +70,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -82,6 +89,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -100,6 +108,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 28b53270f8..fa1900d3b3 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -18,14 +18,19 @@ import pytest from sagemaker import resource_requirements +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_resource_requirements(patched_get_model_specs): +def test_jumpstart_resource_requirements( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE region = "us-west-2" mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -46,13 +51,18 @@ def test_jumpstart_resource_requirements(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): +def test_jumpstart_no_supported_resource_requirements( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, model_version = "no-supported-instance-types-model", "*" region = "us-west-2" @@ -73,6 +83,7 @@ def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 3f38326608..0dfe677936 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -19,19 +19,24 @@ from sagemaker import script_uris from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.script_uris.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_common_script_uri( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -47,6 +52,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -64,6 +70,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -78,7 +85,11 @@ def test_jumpstart_common_script_uri( sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -97,6 +108,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index b22b61dc40..8eeb867d71 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -19,18 +19,23 @@ from sagemaker import base_serializers, serializers from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_serializers( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -50,19 +55,24 @@ def test_jumpstart_default_serializers( model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) patched_get_model_specs.reset_mock() +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_serializer_options( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -89,4 +99,5 @@ def test_jumpstart_serializer_options( model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) From 4d58cfd049427e9e97fb85926f2f499f1e22247e Mon Sep 17 00:00:00 2001 From: Haotian An Date: Thu, 29 Feb 2024 15:51:36 +0000 Subject: [PATCH 02/30] remove unused imports and fix docstyle --- src/sagemaker/accept_types.py | 2 -- src/sagemaker/content_types.py | 2 -- src/sagemaker/deserializers.py | 2 -- src/sagemaker/jumpstart/factory/model.py | 4 ++-- src/sagemaker/serializers.py | 2 -- 5 files changed, 2 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index cce0b653f8..bf081365ab 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -16,7 +16,6 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -73,7 +72,6 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index fa3d49fbba..e43e96be17 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -16,7 +16,6 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -73,7 +72,6 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 4cb596ca48..706ae56bda 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -35,7 +35,6 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -96,7 +95,6 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 4fda2cbf4c..585d79b2a6 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -204,8 +204,8 @@ def _add_instance_type_to_kwargs( def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: - """ - Sets image uri based on default or override, returns full kwargs. + """Sets image uri based on default or override, returns full kwargs. + Uses placeholder image uri for JumpStart proprietary models that uses ModelPackages """ diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index b056ff593c..fc76c0fa76 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -34,7 +34,6 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -92,7 +91,6 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, From 8ba6477c099f41c58214bd63a13fe5172e47c143 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Thu, 29 Feb 2024 16:45:33 +0000 Subject: [PATCH 03/30] fix: remove unused args --- src/sagemaker/jumpstart/factory/model.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 585d79b2a6..d18f77f4c7 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -93,7 +93,6 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, - model_type=model_type, ) predictor.deserializer = deserializers.retrieve_default( model_id=model_id, @@ -102,7 +101,6 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, - model_type=model_type, ) predictor.accept = accept_types.retrieve_default( model_id=model_id, @@ -111,7 +109,6 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, - model_type=model_type, ) predictor.content_type = content_types.retrieve_default( model_id=model_id, @@ -120,7 +117,6 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, - model_type=model_type, ) return predictor From 29ac23c8e8b8f2c58625f5aedfeb72d054ab4e2d Mon Sep 17 00:00:00 2001 From: Haotian An Date: Thu, 29 Feb 2024 16:53:33 +0000 Subject: [PATCH 04/30] fix: remove unused args --- src/sagemaker/jumpstart/factory/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index d18f77f4c7..c95f1c87e6 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -71,7 +71,6 @@ def get_default_predictor( tolerate_vulnerable_model: bool, tolerate_deprecated_model: bool, sagemaker_session: Session, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, ) -> Predictor: """Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one. From 57a7d3764dbbc470335117264d1a2542c574a7d8 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Thu, 29 Feb 2024 17:08:32 +0000 Subject: [PATCH 05/30] fix: more unused vars --- src/sagemaker/jumpstart/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 837291b813..d952c43258 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -605,7 +605,6 @@ def deploy( tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, - model_type=self._model_type, ) # If a predictor class was passed, do not mutate predictor From a8ffdc296bc4099a71323da09dfa47ca4dffd3be Mon Sep 17 00:00:00 2001 From: Haotian An Date: Thu, 29 Feb 2024 18:41:22 +0000 Subject: [PATCH 06/30] fix: slow tests --- .../sagemaker/jumpstart/estimator/test_jumpstart_estimator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py index 5faa40ccda..894a2d0f38 100644 --- a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py +++ b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py @@ -33,7 +33,7 @@ from sagemaker.jumpstart.utils import get_jumpstart_content_bucket -MAX_INIT_TIME_SECONDS = 5 +MAX_INIT_TIME_SECONDS = 15 GATED_TRAINING_MODEL_V1_SUPPORTED_REGIONS = { "us-west-2", From d71b72739bc0888fe5ba3aedbc7930c6c0288a60 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Thu, 29 Feb 2024 21:24:03 +0000 Subject: [PATCH 07/30] fix: unittests --- tests/unit/sagemaker/jumpstart/model/test_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index e0aea1bb2b..597c76e5af 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -775,7 +775,6 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, - model_type=JumpStartModelType.OPEN_SOURCE, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) From 359ea1c9edd8b06c5cf7db6f000158d3e4ea0121 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Thu, 29 Feb 2024 22:09:10 +0000 Subject: [PATCH 08/30] added more tests to cover some lines --- src/sagemaker/jumpstart/cache.py | 4 +- tests/unit/sagemaker/jumpstart/test_cache.py | 72 ++++++++++++++++++++ 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 888b1d07d1..6e7e192633 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -164,8 +164,8 @@ def set_manifest_file_s3_key( if not property_name: raise ValueError( f"Bad value when setting manifest '{file_type}': must be in" - f"{JumpStartS3FileType.OPEN_SOURCE_MANIFEST}" - f"{JumpStartS3FileType.PROPRIETARY_MANIFEST}" + f" {JumpStartS3FileType.OPEN_SOURCE_MANIFEST}" + f" {JumpStartS3FileType.PROPRIETARY_MANIFEST}" ) if key != property_name: setattr(self, property_name, key) diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 8b3f4be401..f75bec58df 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -36,6 +36,7 @@ JumpStartModelHeader, JumpStartModelSpecs, JumpStartVersionedModelId, + JumpStartS3FileType, ) from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import ( @@ -358,6 +359,12 @@ def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client): cache.set_manifest_file_s3_key("some_key1") cache.clear.assert_called_once() + cache.clear.reset_mock() + cache.set_manifest_file_s3_key("some_key1", file_type=JumpStartS3FileType.OPEN_SOURCE_MANIFEST) + cache.clear.assert_called_once() + with pytest.raises(ValueError): + cache.set_manifest_file_s3_key("some_key1", file_type="unknown_type") + def test_jumpstart_cache_handles_boto3_client_errors(): # Testing get_object @@ -514,6 +521,71 @@ def test_jumpstart_cache_accepts_input_parameters(): ) +def test_jumpstart_proprietary_cache_accepts_input_parameters(): + + region = "us-east-1" + max_s3_cache_items = 1 + s3_cache_expiration_horizon = datetime.timedelta(weeks=2) + max_semantic_version_cache_items = 3 + semantic_version_cache_expiration_horizon = datetime.timedelta(microseconds=4) + bucket = "my-amazing-bucket" + manifest_file_key = "some_s3_key" + proprietary_manifest_file_key = "some_proprietary_s3_key" + + cache = JumpStartModelsCache( + region=region, + max_s3_cache_items=max_s3_cache_items, + s3_cache_expiration_horizon=s3_cache_expiration_horizon, + max_semantic_version_cache_items=max_semantic_version_cache_items, + semantic_version_cache_expiration_horizon=semantic_version_cache_expiration_horizon, + s3_bucket_name=bucket, + manifest_file_s3_key=manifest_file_key, + proprietary_manifest_s3_key=proprietary_manifest_file_key, + ) + + assert ( + cache.get_manifest_file_s3_key(file_type=JumpStartS3FileType.PROPRIETARY_MANIFEST) + == proprietary_manifest_file_key + ) + assert cache.get_region() == region + assert cache.get_bucket() == bucket + assert cache._s3_cache._max_cache_items == max_s3_cache_items + assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon + assert ( + cache._proprietary_model_id_manifest_key_cache._max_cache_items + == max_semantic_version_cache_items + ) + assert ( + cache._proprietary_model_id_manifest_key_cache._expiration_horizon + == semantic_version_cache_expiration_horizon + ) + + +def test_jumpstart_cache_raise_unknown_file_type_exception(): + + region = "us-east-1" + max_s3_cache_items = 1 + s3_cache_expiration_horizon = datetime.timedelta(weeks=2) + max_semantic_version_cache_items = 3 + semantic_version_cache_expiration_horizon = datetime.timedelta(microseconds=4) + bucket = "my-amazing-bucket" + manifest_file_key = "some_s3_key" + proprietary_manifest_file_key = "some_proprietary_s3_key" + + cache = JumpStartModelsCache( + region=region, + max_s3_cache_items=max_s3_cache_items, + s3_cache_expiration_horizon=s3_cache_expiration_horizon, + max_semantic_version_cache_items=max_semantic_version_cache_items, + semantic_version_cache_expiration_horizon=semantic_version_cache_expiration_horizon, + s3_bucket_name=bucket, + manifest_file_s3_key=manifest_file_key, + proprietary_manifest_s3_key=proprietary_manifest_file_key, + ) + with pytest.raises(ValueError): + cache.get_manifest_file_s3_key(file_type="unknown_type") + + @patch("boto3.client") def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): From 04e1376fb89d408c3a86e8788ec381e0f99a05a2 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Fri, 1 Mar 2024 14:26:59 +0000 Subject: [PATCH 09/30] remove estimator warn check --- .../sagemaker/jumpstart/estimator/test_jumpstart_estimator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py index 894a2d0f38..cbaea08a43 100644 --- a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py +++ b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py @@ -235,5 +235,3 @@ def test_instatiating_estimator(mock_warning_logger, setup): elapsed_time = time.perf_counter() - start_time assert elapsed_time <= MAX_INIT_TIME_SECONDS - - mock_warning_logger.assert_called_once() From f74a3e4a57c84e156fce873528f0f0f1dd124e53 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Fri, 1 Mar 2024 20:32:00 +0000 Subject: [PATCH 10/30] chore: address comments re performance --- src/sagemaker/accept_types.py | 3 + src/sagemaker/content_types.py | 3 + src/sagemaker/deserializers.py | 3 + src/sagemaker/instance_types.py | 3 + .../jumpstart/artifacts/instance_types.py | 3 + src/sagemaker/jumpstart/artifacts/kwargs.py | 5 ++ .../jumpstart/artifacts/model_packages.py | 3 + .../jumpstart/artifacts/predictors.py | 15 ++++ .../jumpstart/artifacts/resource_names.py | 3 + .../artifacts/resource_requirements.py | 3 + src/sagemaker/jumpstart/cache.py | 36 +++++++--- src/sagemaker/jumpstart/estimator.py | 7 +- src/sagemaker/jumpstart/exceptions.py | 10 +++ src/sagemaker/jumpstart/factory/estimator.py | 4 +- src/sagemaker/jumpstart/factory/model.py | 31 +++++++++ src/sagemaker/jumpstart/model.py | 12 ++-- src/sagemaker/jumpstart/types.py | 12 ++++ src/sagemaker/jumpstart/utils.py | 4 +- src/sagemaker/predictor.py | 3 + src/sagemaker/resource_requirements.py | 3 + src/sagemaker/serializers.py | 3 + .../estimator/test_jumpstart_estimator.py | 4 +- .../jumpstart/model/test_jumpstart_model.py | 9 ++- .../sagemaker/jumpstart/model/test_model.py | 4 +- tests/unit/sagemaker/jumpstart/test_cache.py | 69 ++++++------------- .../jumpstart/test_notebook_utils.py | 2 +- .../sagemaker/jumpstart/test_predictor.py | 8 ++- 27 files changed, 187 insertions(+), 78 deletions(-) diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index bf081365ab..86b78dbbfc 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -16,6 +16,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -75,6 +76,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> str: """Retrieves the default accept type for the model matching the given arguments. @@ -114,4 +116,5 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index e43e96be17..6bbcff876c 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -16,6 +16,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -75,6 +76,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> str: """Retrieves the default content type for the model matching the given arguments. @@ -114,6 +116,7 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 706ae56bda..7bd8315d03 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -35,6 +35,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -95,6 +96,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. @@ -135,4 +137,5 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 0471f374ae..414b28fee2 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -20,6 +20,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -34,6 +35,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> str: """Retrieves the default instance type for the model matching the given arguments. @@ -85,6 +87,7 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, training_instance_type=training_instance_type, + model_type=model_type, ) diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 38e02e3ebd..47b7849dc5 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -22,6 +22,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -38,6 +39,7 @@ def _retrieve_default_instance_type( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> str: """Retrieves the default instance type for the model. @@ -84,6 +86,7 @@ def _retrieve_default_instance_type( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 7acad9b793..3a7012d4f9 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -22,6 +22,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -35,6 +36,7 @@ def _retrieve_model_init_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> dict: """Retrieves kwargs for `Model`. @@ -71,6 +73,7 @@ def _retrieve_model_init_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) kwargs = deepcopy(model_specs.model_kwargs) @@ -89,6 +92,7 @@ def _retrieve_model_deploy_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> dict: """Retrieves kwargs for `Model.deploy`. @@ -128,6 +132,7 @@ def _retrieve_model_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None: diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index bd0ae365d9..89c8e8ebd0 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -22,6 +22,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.session import Session @@ -35,6 +36,7 @@ def _retrieve_model_package_arn( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> Optional[str]: """Retrieves associated model pacakge arn for the model. @@ -74,6 +76,7 @@ def _retrieve_model_package_arn( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 8d599c89cc..4f47fee0be 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -26,6 +26,7 @@ from sagemaker.jumpstart.enums import ( JumpStartScriptScope, MIMEType, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -76,6 +77,7 @@ def _retrieve_default_deserializer( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> BaseDeserializer: """Retrieves the default deserializer for the model. @@ -108,6 +110,7 @@ def _retrieve_default_deserializer( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) return _retrieve_deserializer_from_accept_type(MIMEType.from_suffixed_type(default_accept_type)) @@ -120,6 +123,7 @@ def _retrieve_default_serializer( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> BaseSerializer: """Retrieves the default serializer for the model. @@ -151,6 +155,7 @@ def _retrieve_default_serializer( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) return _retrieve_serializer_from_content_type(MIMEType.from_suffixed_type(default_content_type)) @@ -163,6 +168,7 @@ def _retrieve_deserializer_options( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model. @@ -194,6 +200,7 @@ def _retrieve_deserializer_options( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) seen_classes: Set[Type] = set() @@ -276,6 +283,7 @@ def _retrieve_default_content_type( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default content type for the model. @@ -312,6 +320,7 @@ def _retrieve_default_content_type( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) default_content_type = model_specs.predictor_specs.default_content_type @@ -325,6 +334,7 @@ def _retrieve_default_accept_type( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> str: """Retrieves the default accept type for the model. @@ -360,6 +370,7 @@ def _retrieve_default_accept_type( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) default_accept_type = model_specs.predictor_specs.default_accept_type @@ -374,6 +385,7 @@ def _retrieve_supported_accept_types( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> List[str]: """Retrieves the supported accept types for the model. @@ -409,6 +421,7 @@ def _retrieve_supported_accept_types( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) supported_accept_types = model_specs.predictor_specs.supported_accept_types @@ -423,6 +436,7 @@ def _retrieve_supported_content_types( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> List[str]: """Retrieves the supported content types for the model. @@ -458,6 +472,7 @@ def _retrieve_supported_content_types( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) supported_content_types = model_specs.predictor_specs.supported_content_types diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index 6b05f07b15..a46191be95 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -19,6 +19,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -32,6 +33,7 @@ def _retrieve_resource_name_base( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns default resource name. @@ -68,6 +70,7 @@ def _retrieve_resource_name_base( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 8356d1efac..92e1c7dad9 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -21,6 +21,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -36,6 +37,7 @@ def _retrieve_default_resources( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> ResourceRequirements: """Retrieves the default resource requirements for the model. @@ -80,6 +82,7 @@ def _retrieve_default_resources( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 6e7e192633..12923a2f37 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -32,7 +32,10 @@ MODEL_TYPE_TO_MANIFEST_MAP, MODEL_TYPE_TO_SPECS_MAP, ) -from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg +from sagemaker.jumpstart.exceptions import ( + get_wildcard_model_version_msg, + get_wildcard_proprietary_model_version_msg, +) from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, @@ -220,7 +223,6 @@ def _model_id_retrieval_function( """ model_id, version = key.model_id, key.version - sm_version = utils.get_sagemaker_version() manifest = self._s3_cache.get( JumpStartCachedS3ContentKey( @@ -234,7 +236,7 @@ def _model_id_retrieval_function( ] sm_compatible_model_version = self._select_version( - version, versions_compatible_with_sagemaker + model_id, version, versions_compatible_with_sagemaker, model_type ) if sm_compatible_model_version is not None: @@ -245,7 +247,7 @@ def _model_id_retrieval_function( if header.model_id == model_id ] sm_incompatible_model_version = self._select_version( - version, versions_incompatible_with_sagemaker + model_id, version, versions_incompatible_with_sagemaker, model_type ) if sm_incompatible_model_version is not None: @@ -275,15 +277,17 @@ def _model_id_retrieval_function( f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. " ) - other_model_id_version = self._select_version( - "*", versions_incompatible_with_sagemaker - ) # all versions here are incompatible with sagemaker + other_model_id_version = None + if model_type != JumpStartModelType.PROPRIETARY: + other_model_id_version = self._select_version( + model_id, "*", versions_incompatible_with_sagemaker, model_type + ) # all versions here are incompatible with sagemaker + if other_model_id_version is not None: error_msg += ( f"Consider using model ID '{model_id}' with version " f"'{other_model_id_version}'." ) - else: possible_model_ids = [header.model_id for header in manifest.values()] # type: ignore closest_model_id = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)[0] @@ -439,14 +443,17 @@ def get_header( semantic_version_str (str): The semantic version for which to get a header. """ + return self._get_header_impl( model_id, semantic_version_str=semantic_version_str, model_type=model_type ) def _select_version( self, + model_id: str, semantic_version_str: str, available_versions: List[Version], + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> Optional[str]: """Perform semantic version search on available versions. @@ -455,6 +462,16 @@ def _select_version( available versions. available_versions (List[Version]): list of available versions. """ + + if model_type == JumpStartModelType.PROPRIETARY: + if "*" in semantic_version_str: + raise KeyError( + get_wildcard_proprietary_model_version_msg( + model_id, semantic_version_str + ) + ) + return semantic_version_str if semantic_version_str in available_versions else None + if semantic_version_str == "*": if len(available_versions) == 0: return None @@ -490,7 +507,6 @@ def _get_header_impl( versioned_model_id = self._open_source_model_id_manifest_key_cache.get( JumpStartVersionedModelId(model_id, semantic_version_str) )[0] - elif model_type == JumpStartModelType.PROPRIETARY: versioned_model_id = self._proprietary_model_id_manifest_key_cache.get( JumpStartVersionedModelId(model_id, semantic_version_str) @@ -508,7 +524,7 @@ def _get_header_impl( if attempt > 0: raise self.clear() - return self._get_header_impl(model_id, semantic_version_str, attempt + 1) + return self._get_header_impl(model_id, semantic_version_str, attempt + 1, model_type) def get_specs( self, diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index d245bf5aca..4dada409f5 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -513,14 +513,17 @@ def _validate_model_id_and_get_type_hook(): sagemaker_session=sagemaker_session, ) - if not _validate_model_id_and_get_type_hook(): + self.model_type = _validate_model_id_and_get_type_hook() + if not self.model_type: JumpStartModelsAccessor.reset_cache() - if not _validate_model_id_and_get_type_hook(): + self.model_type = _validate_model_id_and_get_type_hook() + if not self.model_type: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) estimator_init_kwargs = get_init_kwargs( model_id=model_id, model_version=model_version, + model_type=self.model_type, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, role=role, diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index c55c9081cb..23882f48d2 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -58,6 +58,16 @@ def get_wildcard_model_version_msg( ) +def get_wildcard_proprietary_model_version_msg(model_id: str, wildcard_model_version: str) -> str: + """Returns customer-facing message for passing wildcard version to proprietary models.""" + + return ( + f"Marketplace model '{model_id}' does not support " + f"wildcard version identifier '{wildcard_model_version}'. " + f"{MODEL_ID_LIST_WEB_URL} for list of supported model IDs. " + ) + + def get_old_model_version_msg( model_id: str, current_model_version: str, latest_model_version: str ) -> str: diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 7ccf57983b..0faeb92a42 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -50,7 +50,7 @@ TRAINING_ENTRY_POINT_SCRIPT_NAME, SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, ) -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.factory import model from sagemaker.jumpstart.types import ( JumpStartEstimatorDeployKwargs, @@ -77,6 +77,7 @@ def get_init_kwargs( model_id: str, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -134,6 +135,7 @@ def get_init_kwargs( estimator_init_kwargs: JumpStartEstimatorInitKwargs = JumpStartEstimatorInitKwargs( model_id=model_id, model_version=model_version, + model_type=model_type, role=role, region=region, instance_count=instance_count, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index c95f1c87e6..99ff51a0d7 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -71,6 +71,7 @@ def get_default_predictor( tolerate_vulnerable_model: bool, tolerate_deprecated_model: bool, sagemaker_session: Session, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> Predictor: """Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one. @@ -92,6 +93,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) predictor.deserializer = deserializers.retrieve_default( model_id=model_id, @@ -100,6 +102,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) predictor.accept = accept_types.retrieve_default( model_id=model_id, @@ -108,6 +111,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) predictor.content_type = content_types.retrieve_default( model_id=model_id, @@ -116,6 +120,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) return predictor @@ -187,6 +192,7 @@ def _add_instance_type_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, training_instance_type=kwargs.training_instance_type, + model_type=kwargs.model_type, ) if not disable_instance_type_logging and orig_instance_type is None: @@ -226,6 +232,10 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets model data based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.model_data = None + return kwargs + model_data: Union[str, dict] = kwargs.model_data or model_uris.retrieve( model_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, @@ -262,6 +272,10 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets source dir based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.source_dir = None + return kwargs + source_dir = kwargs.source_dir if _model_supports_inference_script_uri( @@ -290,6 +304,10 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets entry point based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.entry_point = None + return kwargs + entry_point = kwargs.entry_point if _model_supports_inference_script_uri( @@ -311,6 +329,10 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets env based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.env = None + return kwargs + env = kwargs.env if env is None: @@ -355,6 +377,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) kwargs.model_package_arn = model_package_arn @@ -371,6 +394,7 @@ def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelI tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) for key, value in model_kwargs_to_add.items(): @@ -406,6 +430,7 @@ def _add_endpoint_name_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) kwargs.endpoint_name = kwargs.endpoint_name or ( @@ -427,6 +452,7 @@ def _add_model_name_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) kwargs.name = kwargs.name or ( @@ -447,6 +473,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: @@ -468,6 +495,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) for key, value in deploy_kwargs_to_add.items(): @@ -488,6 +516,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) return kwargs @@ -496,6 +525,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -528,6 +558,7 @@ def get_deploy_kwargs( deploy_kwargs: JumpStartModelDeployKwargs = JumpStartModelDeployKwargs( model_id=model_id, model_version=model_version, + model_type=model_type, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index d952c43258..e6e7cf2495 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -279,18 +279,18 @@ def _validate_model_id_and_type(): sagemaker_session=sagemaker_session, ) - self._model_type = _validate_model_id_and_type() - if not self._model_type: + self.model_type = _validate_model_id_and_type() + if not self.model_type: JumpStartModelsAccessor.reset_cache() - self._model_type = _validate_model_id_and_type() - if not self._model_type: + self.model_type = _validate_model_id_and_type() + if not self.model_type: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) self._model_data_is_set = model_data is not None model_init_kwargs = get_init_kwargs( model_id=model_id, model_from_estimator=False, - model_type=self._model_type, + model_type=self.model_type, model_version=model_version, instance_type=instance_type, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -591,6 +591,7 @@ def deploy( resources=resources, managed_instance_scaling=managed_instance_scaling, endpoint_type=endpoint_type, + model_type=self.model_type, ) predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict()) @@ -605,6 +606,7 @@ def deploy( tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, + model_type=self.model_type, ) # If a predictor class was passed, do not mutate predictor diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 21bdfa938a..0205663771 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1137,6 +1137,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "model_type", "initial_instance_count", "instance_type", "region", @@ -1168,6 +1169,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): SERIALIZATION_EXCLUSION_SET = { "model_id", "model_version", + "model_type", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1179,6 +1181,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -1210,6 +1213,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.model_type = model_type self.initial_instance_count = initial_instance_count self.instance_type = instance_type self.region = region @@ -1244,6 +1248,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "model_type", "instance_type", "instance_count", "region", @@ -1303,12 +1308,14 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "model_id", "model_version", + "model_type", } def __init__( self, model_id: str, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, region: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, role: Optional[str] = None, @@ -1365,6 +1372,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.model_type = (model_type,) self.instance_type = instance_type self.instance_count = instance_count self.region = region @@ -1426,6 +1434,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "model_type", "region", "inputs", "wait", @@ -1440,6 +1449,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): SERIALIZATION_EXCLUSION_SET = { "model_id", "model_version", + "model_type", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1450,6 +1460,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, region: Optional[str] = None, inputs: Optional[Union[str, Dict, Any, Any]] = None, wait: Optional[bool] = None, @@ -1464,6 +1475,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.model_type = model_type self.region = region self.inputs = inputs self.wait = wait diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 6094b8cf53..dc3568654a 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -531,6 +531,7 @@ def verify_model_region_and_return_specs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_SOURCE, ) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided. @@ -573,9 +574,6 @@ def verify_model_region_and_return_specs( "JumpStart models only support scopes: " f"{', '.join(constants.SUPPORTED_JUMPSTART_SCOPES)}." ) - model_type = validate_model_id_and_get_type( - model_id=model_id, region=region, model_version=version, script=scope - ) model_specs = accessors.JumpStartModelsAccessor.get_model_specs( # type: ignore region=region, diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 42c2af0917..a1e5996bd0 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -15,6 +15,7 @@ from typing import Optional from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint @@ -41,6 +42,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> Predictor: """Retrieves the default predictor for the model matching the given arguments. @@ -107,4 +109,5 @@ def retrieve_default( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index 446d034bf3..c474df57eb 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -20,6 +20,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session LOGGER = logging.getLogger("sagemaker") @@ -32,6 +33,7 @@ def retrieve_default( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default resource requirements for the model matching the given arguments. @@ -78,5 +80,6 @@ def retrieve_default( region, tolerate_vulnerable_model, tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index fc76c0fa76..ca92ff1b53 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -34,6 +34,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -93,6 +94,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> BaseSerializer: """Retrieves the default serializer for the model matching the given arguments. @@ -134,4 +136,5 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py index cbaea08a43..5faa40ccda 100644 --- a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py +++ b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py @@ -33,7 +33,7 @@ from sagemaker.jumpstart.utils import get_jumpstart_content_bucket -MAX_INIT_TIME_SECONDS = 15 +MAX_INIT_TIME_SECONDS = 5 GATED_TRAINING_MODEL_V1_SUPPORTED_REGIONS = { "us-west-2", @@ -235,3 +235,5 @@ def test_instatiating_estimator(mock_warning_logger, setup): elapsed_time = time.perf_counter() - start_time assert elapsed_time <= MAX_INIT_TIME_SECONDS + + mock_warning_logger.assert_called_once() diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 26f811322f..e3c47beb57 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -39,7 +39,7 @@ "us-east-2", } -MAX_INIT_TIME_SECONDS = 15 +MAX_INIT_TIME_SECONDS = 5 GATED_INFERENCE_MODEL_PACKAGE_SUPPORTED_REGIONS = { "us-west-2", @@ -237,6 +237,8 @@ def test_instatiating_model(mock_warning_logger, setup): assert elapsed_time <= MAX_INIT_TIME_SECONDS + mock_warning_logger.assert_called_once() + def test_jumpstart_model_register(setup): model_id = "huggingface-txt2img-conflictx-complex-lineart" @@ -254,7 +256,6 @@ def test_jumpstart_model_register(setup): predictor = model_package.deploy( instance_type="ml.p3.2xlarge", initial_instance_count=1, - tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], ) response = predictor.predict("hello world!") @@ -276,7 +277,9 @@ def test_proprietary_jumpstart_model(setup): sagemaker_session=get_sm_session(), ) - predictor = model.deploy() + predictor = model.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}] + ) payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1} response = predictor.predict(payload) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 597c76e5af..737f356ef9 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -475,9 +475,6 @@ def test_proprietary_model_endpoint( mock_model_init.assert_called_once_with( image_uri="", - model_data="s3://jumpstart-cache-prod-us-west-2/None", - source_dir="s3://jumpstart-cache-prod-us-west-2/None", - entry_point="inference.py", predictor_cls=Predictor, role=execution_role, sagemaker_session=sagemaker_session, @@ -775,6 +772,7 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, + model_type=JumpStartModelType.OPEN_SOURCE, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index f75bec58df..59a47df41b 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -174,18 +174,16 @@ def test_jumpstart_cache_get_header(): model_type=JumpStartModelType.PROPRIETARY, ) - assert JumpStartModelHeader( - { - "model_id": "ai21-summarization", - "version": "1.1.003", - "min_version": "2.0.0", - "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", - "search_keywords": ["Text2Text", "Generation"], - } - ) == cache.get_header( - model_id="ai21-summarization", - semantic_version_str="*", - model_type=JumpStartModelType.PROPRIETARY, + with pytest.raises(KeyError) as e: + cache.get_header( + model_id="ai21-summarization", + semantic_version_str="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert ( + "Marketplace model 'ai21-summarization' does not support wildcard version identifier '*'. " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " + "for list of supported model IDs. " in str(e.value) ) with pytest.raises(KeyError) as e: @@ -243,19 +241,6 @@ def test_jumpstart_cache_get_header(): "Did you mean to use model ID 'ai21-summarization'?" ) in str(e.value) - with pytest.raises(KeyError) as e: - cache.get_header( - model_id="ai21-summarize", - semantic_version_str="*", - model_type=JumpStartModelType.PROPRIETARY, - ) - assert ( - "Unable to find model manifest for 'ai21-summarize' with version '*'. " - "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " - "for updated list of models. " - "Did you mean to use model ID 'ai21-summarization'?" - ) in str(e.value) - with pytest.raises(KeyError): cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4", @@ -793,7 +778,7 @@ def test_jumpstart_cache_proprietary_manifest_makes_correct_s3_calls( ) cache.get_header( model_id="ai21-summarization", - semantic_version_str="*", + semantic_version_str="1.1.003", model_type=JumpStartModelType.PROPRIETARY, ) @@ -819,19 +804,7 @@ def test_jumpstart_cache_proprietary_manifest_makes_correct_s3_calls( with patch("logging.Logger.warning") as mocked_warning_log: cache.get_specs( model_id="ai21-summarization", - version_str="*", - model_type=JumpStartModelType.PROPRIETARY, - ) - mocked_warning_log.assert_called_once_with( - "Using model 'ai21-summarization' with wildcard " - "version identifier '*'. You can pin to version '1.1.003' for more " - "stable results. Note that models may have different input/output " - "signatures after a major version upgrade." - ) - mocked_warning_log.reset_mock() - cache.get_specs( - model_id="ai21-summarization", - version_str="*", + version_str="1.1.003", model_type=JumpStartModelType.PROPRIETARY, ) mocked_warning_log.assert_not_called() @@ -960,14 +933,16 @@ def test_jumpstart_cache_get_specs(): model_type=JumpStartModelType.PROPRIETARY, ) - assert get_spec_from_base_spec( - model_id="ai21-summarization", - version="1.1.003", - model_type=JumpStartModelType.PROPRIETARY, - ) == cache.get_specs( - model_id="ai21-summarization", - version_str="*", - model_type=JumpStartModelType.PROPRIETARY, + with pytest.raises(KeyError) as e: + cache.get_specs( + model_id="ai21-summarization", + version_str="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert ( + "Marketplace model 'ai21-summarization' does not support wildcard version identifier '*'. " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " + "for list of supported model IDs. " in str(e.value) ) with pytest.raises(KeyError): diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 76a1b684a0..9662f8a54a 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -64,7 +64,7 @@ def test_list_jumpstart_scripts( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - assert patched_get_manifest.call_count == 3 + assert patched_get_manifest.call_count == 1 assert patched_get_model_specs.call_count == 1 patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index f2adcf5298..336025c448 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -7,7 +7,7 @@ import pytest from sagemaker.deserializers import JSONDeserializer from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import MIMEType +from sagemaker.jumpstart.enums import MIMEType, JumpStartModelType from sagemaker import predictor from sagemaker.jumpstart.model import JumpStartModel @@ -74,7 +74,10 @@ def test_proprietary_predictor_support( ) js_predictor = predictor.retrieve_default( - endpoint_name="blah", model_id=model_id, model_version=model_version + endpoint_name="blah", + model_id=model_id, + model_version=model_version, + model_type=JumpStartModelType.PROPRIETARY, ) patched_get_jumpstart_model_id_version_from_endpoint.assert_not_called() @@ -124,6 +127,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=mock_session, + model_type=JumpStartModelType.OPEN_SOURCE, ) From a96ea08e7c70f9b0fc830262bf6a382fb42f37ea Mon Sep 17 00:00:00 2001 From: Haotian An Date: Mon, 4 Mar 2024 16:39:17 +0000 Subject: [PATCH 11/30] fix: address comments --- src/sagemaker/jumpstart/cache.py | 10 ++++------ src/sagemaker/jumpstart/factory/model.py | 12 ++++++++++++ .../jumpstart/model/test_jumpstart_model.py | 1 + tests/unit/sagemaker/jumpstart/model/test_model.py | 6 ++---- tests/unit/sagemaker/jumpstart/test_accessors.py | 6 +++--- 5 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 12923a2f37..32287ff812 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -154,7 +154,7 @@ def set_manifest_file_s3_key( key: str, file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_SOURCE_MANIFEST, ) -> None: - """Set manifest file s3 key. Clears cache after new key is set. + """Set manifest file s3 key, clear cache after new key is set. Raises: ValueError: if the file type is not recognized @@ -214,8 +214,6 @@ def _model_id_retrieval_function( key (JumpStartVersionedModelId): Key for which to fetch versioned model ID. value (Optional[JumpStartVersionedModelId]): Unused variable for current value of old cached model ID/version. - UseSematicVersion (bool): boolean value to indicate whether the model versions follow - sematic versioning. Raises: KeyError: If the semantic version is not found in the manifest, or is found but @@ -300,7 +298,7 @@ def _get_open_source_manifest_key_from_model_id( key: JumpStartVersionedModelId, value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 ) -> JumpStartVersionedModelId: - """Get open source manifest key from model id.""" + """Retrieve model manifest key for open source model, by filtering supported versions.""" return self._model_id_retrieval_function( key, value, model_type=JumpStartModelType.OPEN_SOURCE ) @@ -310,7 +308,7 @@ def _get_proprietary_manifest_key_from_model_id( key: JumpStartVersionedModelId, value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 ) -> JumpStartVersionedModelId: - """Get proprietary manifest key from model id.""" + """Retrieve model manifest key for proprietary model, by filtering supported versions.""" return self._model_id_retrieval_function( key, value, model_type=JumpStartModelType.PROPRIETARY ) @@ -452,7 +450,7 @@ def _select_version( self, model_id: str, semantic_version_str: str, - available_versions: List[Version], + available_versions: List[str], model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> Optional[str]: """Perform semantic version search on available versions. diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 99ff51a0d7..5e69463c26 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -165,6 +165,16 @@ def _add_model_version_to_kwargs( return kwargs +def _log_model_type(kwargs: JumpStartModelInitKwargs) -> None: + """Log the model type being used""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + JUMPSTART_LOGGER.info( + "Marketplace model %s of version %s is being used.", + kwargs.model_id, + kwargs.model_version, + ) + + def _add_vulnerable_and_deprecated_status_to_kwargs( kwargs: JumpStartModelInitKwargs, ) -> JumpStartModelInitKwargs: @@ -781,4 +791,6 @@ def get_init_kwargs( model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs) + _log_model_type(kwargs=model_init_kwargs) + return model_init_kwargs diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index e3c47beb57..5205765e2f 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -273,6 +273,7 @@ def test_proprietary_jumpstart_model(setup): model = JumpStartModel( model_id=model_id, + model_version="2.0.004", role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), ) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 737f356ef9..ae899183e1 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -469,9 +469,7 @@ def test_proprietary_model_endpoint( mock_session.return_value = sagemaker_session - model = JumpStartModel( - model_id=model_id, - ) + model = JumpStartModel(model_id=model_id, model_version="2.0.004") mock_model_init.assert_called_once_with( image_uri="", @@ -489,7 +487,7 @@ def test_proprietary_model_endpoint( wait=True, tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "ai21-summarization"}, - {"Key": JumpStartTag.MODEL_VERSION, "Value": "*"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "2.0.004"}, ], endpoint_logging=False, model_data_download_timeout=3600, diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index b89e28d344..c57d2a958b 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -65,7 +65,7 @@ def test_jumpstart_models_cache_get_fxs(mock_cache): @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") -def test_jumpstart_proprietary_models_cache_get_fxs(mock_cache): +def test_jumpstart_proprietary_models_cache_get(mock_cache): mock_cache.get_manifest = Mock(return_value=BASE_MANIFEST) mock_cache.get_header = Mock(side_effect=get_header_from_base_header) @@ -108,7 +108,7 @@ def test_jumpstart_proprietary_models_cache_get_fxs(mock_cache): @patch("sagemaker.jumpstart.cache.JumpStartModelsCache") -def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): +def test_jumpstart_models_cache_set_reset(mock_model_cache: Mock): # test change of region resets cache accessors.JumpStartModelsAccessor.get_model_header( @@ -183,7 +183,7 @@ def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): @patch("sagemaker.jumpstart.cache.JumpStartModelsCache") -def test_jumpstart_proprietary_models_cache_set_reset_fxs(mock_model_cache: Mock): +def test_jumpstart_proprietary_models_cache_set_reset(mock_model_cache: Mock): # test change of region resets cache accessors.JumpStartModelsAccessor.get_model_header( From e21b98be36d7a038df90ef512cf95c9108433502 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Tue, 5 Mar 2024 14:19:38 +0000 Subject: [PATCH 12/30] complete list experience and other fixes --- src/sagemaker/jumpstart/cache.py | 2 +- src/sagemaker/jumpstart/exceptions.py | 34 +++++- src/sagemaker/jumpstart/model.py | 33 +++++- src/sagemaker/jumpstart/notebook_utils.py | 31 ++++- src/sagemaker/jumpstart/types.py | 2 + tests/unit/sagemaker/jumpstart/constants.py | 1 + .../sagemaker/jumpstart/model/test_model.py | 7 +- tests/unit/sagemaker/jumpstart/test_cache.py | 2 + .../jumpstart/test_notebook_utils.py | 112 ++++++++++++------ tests/unit/sagemaker/jumpstart/utils.py | 6 +- 10 files changed, 180 insertions(+), 50 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 32287ff812..c28e82bf5e 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -465,7 +465,7 @@ def _select_version( if "*" in semantic_version_str: raise KeyError( get_wildcard_proprietary_model_version_msg( - model_id, semantic_version_str + model_id, semantic_version_str, available_versions ) ) return semantic_version_str if semantic_version_str in available_versions else None diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index 23882f48d2..cecd014b9d 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -58,14 +58,18 @@ def get_wildcard_model_version_msg( ) -def get_wildcard_proprietary_model_version_msg(model_id: str, wildcard_model_version: str) -> str: +def get_wildcard_proprietary_model_version_msg( + model_id: str, wildcard_model_version: str, available_versions: List[str] +) -> str: """Returns customer-facing message for passing wildcard version to proprietary models.""" - - return ( + msg = ( f"Marketplace model '{model_id}' does not support " f"wildcard version identifier '{wildcard_model_version}'. " - f"{MODEL_ID_LIST_WEB_URL} for list of supported model IDs. " ) + if len(available_versions) > 0: + msg += f"You can pin to version '{available_versions[0]}'. " + msg += f"{MODEL_ID_LIST_WEB_URL} for list of supported model IDs. " + return msg def get_old_model_version_msg( @@ -179,3 +183,25 @@ def __init__( ) super().__init__(self.message) + + +class MarketplaceModelSubscriptionError(ValueError): + """Exception raised when trying to deploy a JumpStart Marketplace model but the + caller is not subscribed to the product.""" + + def __init__( + self, + model_subscription_link: Optional[str] = None, + message: Optional[str] = None, + ): + if message: + self.message = message + else: + if not model_subscription_link: + raise RuntimeError("Must specify `model_subscription_link` in arguments.") + self.message = ( + f"You have not subscribed to this Marketplace model. " + f"Please subscribe following this link {model_subscription_link}" + ) + + super().__init__(self.message) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index e6e7cf2495..b55974736e 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -15,6 +15,8 @@ from __future__ import absolute_import from typing import Dict, List, Optional, Union +from botocore.exceptions import ClientError + from sagemaker import payloads from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer @@ -22,7 +24,10 @@ from sagemaker.explainer.explainer_config import ExplainerConfig from sagemaker.jumpstart.accessors import JumpStartModelsAccessor from sagemaker.jumpstart.enums import JumpStartScriptScope -from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG +from sagemaker.jumpstart.exceptions import ( + INVALID_MODEL_ID_ERROR_MSG, + MarketplaceModelSubscriptionError, +) from sagemaker.jumpstart.factory.model import ( get_default_predictor, get_deploy_kwargs, @@ -30,7 +35,10 @@ get_register_kwargs, ) from sagemaker.jumpstart.types import JumpStartSerializablePayload -from sagemaker.jumpstart.utils import validate_model_id_and_get_type +from sagemaker.jumpstart.utils import ( + validate_model_id_and_get_type, + verify_model_region_and_return_specs, +) from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model import ( Model, @@ -560,6 +568,9 @@ def deploy( endpoint. endpoint_type (EndpointType): The type of endpoint used to deploy models. (Default: EndpointType.MODEL_BASED). + + Raises: + MarketplaceModelSubscriptionError: If the caller is not subscribed to the Marketplace model. """ deploy_kwargs = get_deploy_kwargs( @@ -593,8 +604,22 @@ def deploy( endpoint_type=endpoint_type, model_type=self.model_type, ) - - predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict()) + try: + predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict()) + except ClientError as e: + error_code = e.response["Error"]["Code"] + error_message = e.response["Error"]["Message"] + if error_code == "ValidationException" and "not subscribed" in error_message: + subscription_link = verify_model_region_and_return_specs( + region=self.region, + model_id=self.model_id, + version=self.model_version, + model_type=self.model_type, + scope=JumpStartScriptScope.INFERENCE, + ).model_subscription_link + raise MarketplaceModelSubscriptionError(subscription_link) + else: + raise # If no predictor class was passed, add defaults to predictor if self.orig_predictor_cls is None and async_inference_config is None: diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 1554025995..9318562d2f 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -25,7 +25,7 @@ DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.filters import ( SPECIAL_SUPPORTED_FILTER_KEYS, BooleanValues, @@ -38,6 +38,7 @@ get_jumpstart_content_bucket, get_sagemaker_version, verify_model_region_and_return_specs, + validate_model_id_and_get_type, ) from sagemaker.session import Session @@ -246,6 +247,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin list_incomplete_models: bool = False, list_old_models: bool = False, list_versions: bool = False, + marketplace_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[Union[Tuple[str], Tuple[str, str]]]: """List models for JumpStart, and optionally apply filters to result. @@ -266,6 +268,8 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin versions should be included in the returned result. (Default: False). list_versions (bool): Optional. True if versions for models should be returned in addition to the id of the model. (Default: False). + marketplace_models (bool): Optional. True if only listing JumpStart Marketplace models. + (Default: False). sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ @@ -275,6 +279,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin filter=filter, region=region, list_incomplete_models=list_incomplete_models, + marketplace_model=marketplace_model, sagemaker_session=sagemaker_session, ): if model_id not in model_id_version_dict: @@ -301,6 +306,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), region: str = JUMPSTART_DEFAULT_REGION_NAME, list_incomplete_models: bool = False, + marketplace_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Generator: """Generate models for JumpStart, and optionally apply filters to result. @@ -321,8 +327,20 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ - models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( - region=region, s3_client=sagemaker_session.s3_client + prop_models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, + s3_client=sagemaker_session.s3_client, + model_type=JumpStartModelType.PROPRIETARY, + ) + open_source_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, + s3_client=sagemaker_session.s3_client, + model_type=JumpStartModelType.OPEN_SOURCE, + ) + models_manifest_list = ( + prop_models_manifest_list + if marketplace_model + else (open_source_manifest_list + prop_models_manifest_list) ) if isinstance(filter, str): @@ -466,6 +484,12 @@ def get_model_url( sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to retrieve the model url. """ + model_type = validate_model_id_and_get_type( + model_id=model_id, + model_version=model_version, + region=region, + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( region=region, @@ -473,5 +497,6 @@ def get_model_url( version=model_version, sagemaker_session=sagemaker_session, scope=JumpStartScriptScope.INFERENCE, + model_type=model_type, ) return model_specs.url diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 697665fe9a..47a7e0e4e4 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -791,6 +791,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType): "training_instance_type_variants", "default_payloads", "gated_bucket", + "model_subscription_link", ] def __init__(self, spec: Dict[str, Any]): @@ -922,6 +923,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if json_obj.get("training_instance_type_variants") else None ) + self.model_subscription_link = json_obj.get("model_subscription_link") def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartModelSpecs object.""" diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 3dc1969ca0..ce8cc4ddfa 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -6113,6 +6113,7 @@ "deprecated_message": None, "hosting_model_package_arns": None, "hosting_eula_key": None, + "model_subscription_link": None, "hyperparameters": [ { "name": "epochs", diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index ae899183e1..53669daccb 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -36,6 +36,7 @@ get_special_model_spec, overwrite_dictionary, get_special_model_spec_for_inference_component_based_endpoint, + get_prototype_manifest, ) execution_role = "fake role! do not use!" @@ -395,7 +396,6 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, ): - mock_timestamp.return_value = "1234" mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE @@ -442,6 +442,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( ], ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -457,7 +458,11 @@ def test_proprietary_model_endpoint( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_manifest: mock.Mock, ): + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) mock_model_deploy.return_value = default_predictor mock_sagemaker_timestamp.return_value = "7777" diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 59a47df41b..e3adef0c26 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -182,6 +182,7 @@ def test_jumpstart_cache_get_header(): ) assert ( "Marketplace model 'ai21-summarization' does not support wildcard version identifier '*'. " + "You can pin to version '1.1.003'. " "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " "for list of supported model IDs. " in str(e.value) ) @@ -941,6 +942,7 @@ def test_jumpstart_cache_get_specs(): ) assert ( "Marketplace model 'ai21-summarization' does not support wildcard version identifier '*'. " + "You can pin to version '1.1.003'. " "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " "for list of supported model IDs. " in str(e.value) ) diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 9662f8a54a..168b69d704 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -64,7 +64,7 @@ def test_list_jumpstart_scripts( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - assert patched_get_manifest.call_count == 1 + assert patched_get_manifest.call_count == 2 assert patched_get_model_specs.call_count == 1 patched_get_model_specs.reset_mock() @@ -79,8 +79,8 @@ def test_list_jumpstart_scripts( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() - assert patched_read_s3_file.call_count == len(PROTOTYPICAL_MODEL_SPECS_DICT) + assert patched_get_manifest.call_count == 2 + assert patched_read_s3_file.call_count == 2 * len(PROTOTYPICAL_MODEL_SPECS_DICT) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -108,7 +108,7 @@ def test_list_jumpstart_tasks( ) # incomplete list, based on mocked metadata patched_generate_jumpstart_models.assert_called_once() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_model_specs.assert_not_called() patched_get_model_specs.reset_mock() @@ -123,7 +123,7 @@ def test_list_jumpstart_tasks( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_model_specs.assert_not_called() @@ -155,11 +155,11 @@ def test_list_jumpstart_frameworks( ) patched_generate_jumpstart_models.assert_called_once() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_model_specs.assert_not_called() patched_get_model_specs.reset_mock() - patched_get_manifest.reset_mock() + assert patched_get_manifest.call_count == 2 patched_generate_jumpstart_models.reset_mock() kwargs = { @@ -181,7 +181,7 @@ def test_list_jumpstart_frameworks( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 4 patched_get_model_specs.assert_not_called() @@ -239,7 +239,7 @@ def test_list_jumpstart_models_script_filter( kwargs = {"filter": f"training_supported != {val}"} list_jumpstart_models(**kwargs) assert patched_read_s3_file.call_count == manifest_length - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -256,7 +256,7 @@ def test_list_jumpstart_models_script_filter( ("xgboost-classification-model", "1.0.0"), ] assert patched_read_s3_file.call_count == manifest_length - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -265,7 +265,7 @@ def test_list_jumpstart_models_script_filter( models = list_jumpstart_models(**kwargs) assert [] == models assert patched_read_s3_file.call_count == manifest_length - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -288,7 +288,7 @@ def test_list_jumpstart_models_task_filter( kwargs = {"filter": f"task == {val}"} list_jumpstart_models(**kwargs) patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -296,7 +296,7 @@ def test_list_jumpstart_models_task_filter( kwargs = {"filter": f"task != {val}"} list_jumpstart_models(**kwargs) patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -313,7 +313,7 @@ def test_list_jumpstart_models_task_filter( ("xgboost-classification-model", "1.0.0"), ] patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -322,7 +322,7 @@ def test_list_jumpstart_models_task_filter( models = list_jumpstart_models(**kwargs) assert [] == models patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") @@ -349,7 +349,7 @@ def test_list_jumpstart_models_framework_filter( kwargs = {"filter": f"framework == {val}"} list_jumpstart_models(**kwargs) patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -357,7 +357,7 @@ def test_list_jumpstart_models_framework_filter( kwargs = {"filter": f"framework != {val}"} list_jumpstart_models(**kwargs) patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -373,7 +373,7 @@ def test_list_jumpstart_models_framework_filter( ("xgboost-classification-model", "1.0.0"), ] patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -391,8 +391,8 @@ def test_list_jumpstart_models_framework_filter( } models = list_jumpstart_models(**kwargs) assert [("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0")] == models - patched_read_s3_file.assert_called_once() - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -406,7 +406,7 @@ def test_list_jumpstart_models_framework_filter( models = list_jumpstart_models(**kwargs) assert [] == models patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -421,8 +421,10 @@ def test_list_jumpstart_models_region( list_jumpstart_models(region="some-region") - patched_get_manifest.assert_called_once_with( - region="some-region", s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client + patched_get_manifest.assert_called_with( + region="some-region", + s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, + model_type=JumpStartModelType.OPEN_SOURCE, ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -479,7 +481,7 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME): ] == list_jumpstart_models(list_old_models=True, list_versions=True) patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -527,8 +529,8 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): And("inference_vulnerable is false", "training_vulnerable is false") ) - assert patched_read_s3_file.call_count == num_specs - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 * num_specs + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -539,8 +541,8 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): And("inference_vulnerable is false", "training_vulnerable is false") ) - assert patched_read_s3_file.call_count == num_specs - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 * num_specs + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -571,8 +573,8 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str: num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) assert [] == list_jumpstart_models("deprecated equals false") - assert patched_read_s3_file.call_count == num_specs - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 * num_specs + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -608,6 +610,41 @@ def test_list_jumpstart_models_no_versions( assert list_jumpstart_models(list_versions=False) == all_model_ids + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + def test_list_jumpstart_proprietary_models( + self, + patched_get_model_specs: Mock, + patched_get_manifest: Mock, + ): + patched_get_model_specs.side_effect = get_prototype_model_spec + patched_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + + all_prop_model_ids = [ + "ai21-paraphrase", + "ai21-summarization", + "lighton-mini-instruct40b", + ] + + all_open_source_model_ids = [ + "catboost-classification-model", + "huggingface-spc-bert-base-cased", + "lightgbm-classification-model", + "mxnet-semseg-fcn-resnet50-ade", + "pytorch-eqa-bert-base-cased", + "sklearn-classification-linear", + "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", + "xgboost-classification-model", + ] + + assert list_jumpstart_models(marketplace_model=True) == all_prop_model_ids + + assert list_jumpstart_models(list_versions=False) == sorted( + all_prop_model_ids + all_open_source_model_ids + ) + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_complex_queries( @@ -671,12 +708,20 @@ def test_list_jumpstart_models_multiple_level_index( list_jumpstart_models("hosting_ecr_specs.py_version == py3") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_get_model_url(patched_get_model_specs: Mock, patched_validate_model_id_and_get_type: Mock): +def test_get_model_url( + patched_get_model_specs: Mock, + patched_validate_model_id_and_get_type: Mock, + patched_get_manifest: Mock, +): patched_get_model_specs.side_effect = get_prototype_model_spec patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) model_id, version = "xgboost-classification-model", "1.0.0" assert "https://xgboost.readthedocs.io/en/latest/" == get_model_url(model_id, version) @@ -687,7 +732,6 @@ def test_get_model_url(patched_get_model_specs: Mock, patched_validate_model_id_ ) model_id, version = "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0" - region = "fake-region" patched_get_model_specs.reset_mock() patched_get_model_specs.side_effect = lambda *largs, **kwargs: get_prototype_model_spec( @@ -696,12 +740,12 @@ def test_get_model_url(patched_get_model_specs: Mock, patched_validate_model_id_ **{key: value for key, value in kwargs.items() if key != "region"}, ) - get_model_url(model_id, version, region=region) + get_model_url(model_id, version, region="us-west-2") patched_get_model_specs.assert_called_once_with( model_id=model_id, version=version, - region=region, + region="us-west-2", s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, model_type=JumpStartModelType.OPEN_SOURCE, ) diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index f8c6384e81..cadc54c96d 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -56,9 +56,6 @@ def get_header_from_base_header( if model_type == JumpStartModelType.PROPRIETARY: spec = copy.deepcopy(BASE_PROPRIETARY_HEADER) - spec["version"] = version or semantic_version_str - spec["model_id"] = model_id - return JumpStartModelHeader(spec) if all( @@ -91,7 +88,10 @@ def get_header_from_base_header( def get_prototype_manifest( region: str = JUMPSTART_DEFAULT_REGION_NAME, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> List[JumpStartModelHeader]: + if model_type == JumpStartModelType.PROPRIETARY: + return [JumpStartModelHeader(spec) for spec in BASE_PROPRIETARY_MANIFEST] return [ get_header_from_base_header(region=region, model_id=model_id, version=version) for model_id in PROTOTYPICAL_MODEL_SPECS_DICT.keys() From feef2a2e2e9108d22b8faa07171e6014b85b7a5e Mon Sep 17 00:00:00 2001 From: Haotian An Date: Tue, 5 Mar 2024 14:41:58 +0000 Subject: [PATCH 13/30] fix: pylint --- src/sagemaker/jumpstart/model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index b55974736e..93f2ae0fc2 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -570,7 +570,7 @@ def deploy( (Default: EndpointType.MODEL_BASED). Raises: - MarketplaceModelSubscriptionError: If the caller is not subscribed to the Marketplace model. + MarketplaceModelSubscriptionError: If the caller is not subscribed to the model. """ deploy_kwargs = get_deploy_kwargs( @@ -618,8 +618,7 @@ def deploy( scope=JumpStartScriptScope.INFERENCE, ).model_subscription_link raise MarketplaceModelSubscriptionError(subscription_link) - else: - raise + raise # If no predictor class was passed, add defaults to predictor if self.orig_predictor_cls is None and async_inference_config is None: From 33a2b59a06185e37807669736c9f2dc022c6f148 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Tue, 5 Mar 2024 15:49:58 +0000 Subject: [PATCH 14/30] add doc utils and fix pylint --- doc/doc_utils/jumpstart_doc_utils.py | 57 +++++++++++++++++++++++---- src/sagemaker/jumpstart/exceptions.py | 5 ++- 2 files changed, 53 insertions(+), 9 deletions(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 348de7adeb..d1e87b1740 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -74,6 +74,7 @@ class Frameworks(str, Enum): JUMPSTART_REGION = "eu-west-2" SDK_MANIFEST_FILE = "models_manifest.json" +PROPRIETARY_SDK_MANIFEST_FILE = "proprietary-sdk-manifest.json" JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format( JUMPSTART_REGION, JUMPSTART_REGION ) @@ -159,6 +160,13 @@ def get_jumpstart_sdk_manifest(): return json.loads(models_manifest) +def get_proprietary_sdk_manifest(): + url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, PROPRIETARY_SDK_MANIFEST_FILE) + with request.urlopen(url) as f: + models_manifest = f.read().decode("utf-8") + return json.loads(models_manifest) + + def get_jumpstart_sdk_spec(key): url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, key) with request.urlopen(url) as f: @@ -196,6 +204,36 @@ def get_model_source(url): return "Source" +def create_marketplace_model_table(): + sdk_manifest = get_proprietary_sdk_manifest() + + marketpkace_content_intro = [] + marketpkace_content_intro.append("\n") + marketpkace_content_intro.append(".. list-table:: Available Models\n") + marketpkace_content_intro.append(" :widths: 50 20 20 20 20\n") + marketpkace_content_intro.append(" :header-rows: 1\n") + marketpkace_content_intro.append(" :class: datatable\n") + marketpkace_content_intro.append("\n") + marketpkace_content_intro.append(" * - Model ID\n") + marketpkace_content_intro.append(" - Fine Tunable?\n") + marketpkace_content_intro.append(" - Supported Version\n") + marketpkace_content_intro.append(" - Min SDK Version\n") + marketpkace_content_intro.append(" - Source\n") + + marketplace_content_entries = [] + for model in sdk_manifest: + model_spec = get_jumpstart_sdk_spec(model["spec_key"]) + model_source = get_model_source(model_spec["url"]) + marketplace_content_entries.append(" * - {}\n".format(model_spec["model_id"])) + marketplace_content_entries.append(" - {}\n".format(False)) # TODO: support training + marketplace_content_entries.append(" - {}\n".format(model["version"])) + marketplace_content_entries.append(" - {}\n".format(model["min_version"])) + marketplace_content_entries.append( + " - `{} <{}>`__ |external-link|\n".format(model_source, model_spec.get("url")) + ) + return marketpkace_content_intro + marketplace_content_entries + ["\n"] + + def create_jumpstart_model_table(): sdk_manifest = get_jumpstart_sdk_manifest() sdk_manifest_top_versions_for_models = {} @@ -249,19 +287,19 @@ def create_jumpstart_model_table(): file_content_intro.append(" - Source\n") dynamic_table_files = [] - file_content_entries = [] + open_source_content_entries = [] for model in sdk_manifest_top_versions_for_models.values(): model_spec = get_jumpstart_sdk_spec(model["spec_key"]) model_task = get_model_task(model_spec["model_id"]) string_model_task = get_string_model_task(model_spec["model_id"]) model_source = get_model_source(model_spec["url"]) - file_content_entries.append(" * - {}\n".format(model_spec["model_id"])) - file_content_entries.append(" - {}\n".format(model_spec["training_supported"])) - file_content_entries.append(" - {}\n".format(model["version"])) - file_content_entries.append(" - {}\n".format(model["min_version"])) - file_content_entries.append(" - {}\n".format(model_task)) - file_content_entries.append( + open_source_content_entries.append(" * - {}\n".format(model_spec["model_id"])) + open_source_content_entries.append(" - {}\n".format(model_spec["training_supported"])) + open_source_content_entries.append(" - {}\n".format(model["version"])) + open_source_content_entries.append(" - {}\n".format(model["min_version"])) + open_source_content_entries.append(" - {}\n".format(model_task)) + open_source_content_entries.append( " - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"]) ) @@ -299,7 +337,10 @@ def create_jumpstart_model_table(): f.writelines(file_content_single_entry) f.close() + marketplace_content_entries = create_marketplace_model_table() + f = open("doc_utils/pretrainedmodels.rst", "a") f.writelines(file_content_intro) - f.writelines(file_content_entries) + f.writelines(open_source_content_entries) + f.writelines(marketplace_content_entries) f.close() diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index cecd014b9d..31f0d8b21b 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -187,7 +187,10 @@ def __init__( class MarketplaceModelSubscriptionError(ValueError): """Exception raised when trying to deploy a JumpStart Marketplace model but the - caller is not subscribed to the product.""" + caller is not subscribed to the product. + + A caller is required to subscribe to the Marketplace product in order to deploy. + """ def __init__( self, From 8c906413d89c199cc45c32f589c56a83f6382faf Mon Sep 17 00:00:00 2001 From: Haotian An Date: Tue, 5 Mar 2024 16:53:07 +0000 Subject: [PATCH 15/30] fix: docstyle --- doc/doc_utils/jumpstart_doc_utils.py | 2 +- src/sagemaker/jumpstart/exceptions.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index d1e87b1740..9d4c3809c7 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -223,7 +223,7 @@ def create_marketplace_model_table(): marketplace_content_entries = [] for model in sdk_manifest: model_spec = get_jumpstart_sdk_spec(model["spec_key"]) - model_source = get_model_source(model_spec["url"]) + model_source = get_model_source(model_spec.get("url")) marketplace_content_entries.append(" * - {}\n".format(model_spec["model_id"])) marketplace_content_entries.append(" - {}\n".format(False)) # TODO: support training marketplace_content_entries.append(" - {}\n".format(model["version"])) diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index 31f0d8b21b..98243cf96a 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -186,10 +186,11 @@ def __init__( class MarketplaceModelSubscriptionError(ValueError): - """Exception raised when trying to deploy a JumpStart Marketplace model but the - caller is not subscribed to the product. + """Exception raised when trying to deploy a JumpStart Marketplace model. A caller is required to subscribe to the Marketplace product in order to deploy. + This exception is raised when a caller tries to deploy a JumpStart Marketplace + but the caller is not subscribed to the model. """ def __init__( From 5bccb8e21a81bcbf92340e02fa494e0a5f540dac Mon Sep 17 00:00:00 2001 From: Haotian An Date: Tue, 5 Mar 2024 17:41:46 +0000 Subject: [PATCH 16/30] fix: doc --- doc/doc_utils/jumpstart_doc_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 9d4c3809c7..3722703263 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -223,13 +223,12 @@ def create_marketplace_model_table(): marketplace_content_entries = [] for model in sdk_manifest: model_spec = get_jumpstart_sdk_spec(model["spec_key"]) - model_source = get_model_source(model_spec.get("url")) marketplace_content_entries.append(" * - {}\n".format(model_spec["model_id"])) marketplace_content_entries.append(" - {}\n".format(False)) # TODO: support training marketplace_content_entries.append(" - {}\n".format(model["version"])) marketplace_content_entries.append(" - {}\n".format(model["min_version"])) marketplace_content_entries.append( - " - `{} <{}>`__ |external-link|\n".format(model_source, model_spec.get("url")) + " - `{} <{}>`__ |external-link|\n".format("Source", model_spec.get("url")) ) return marketpkace_content_intro + marketplace_content_entries + ["\n"] From f8258b7d4f0048250e352075838429d431ea1f0b Mon Sep 17 00:00:00 2001 From: Haotian An Date: Tue, 5 Mar 2024 19:04:48 +0000 Subject: [PATCH 17/30] fix: default payloads --- src/sagemaker/jumpstart/artifacts/payloads.py | 3 +++ src/sagemaker/jumpstart/model.py | 1 + src/sagemaker/payloads.py | 7 +++++++ .../jumpstart/model/test_jumpstart_model.py | 14 ++++++++++++++ 4 files changed, 25 insertions(+) diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 3ea2c16f80..b7d51cd995 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -20,6 +20,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( @@ -35,6 +36,7 @@ def _retrieve_example_payloads( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, ) -> Optional[Dict[str, JumpStartSerializablePayload]]: """Returns example payloads. @@ -72,6 +74,7 @@ def _retrieve_example_payloads( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) default_payloads = model_specs.default_payloads diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 93f2ae0fc2..bc06a56051 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -374,6 +374,7 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: return payloads.retrieve_example( model_id=self.model_id, model_version=self.model_version, + model_type=self.model_type, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index 52d633ed4e..89286e217c 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -21,6 +21,7 @@ from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.payload_utils import PayloadSerializer from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -31,6 +32,7 @@ def retrieve_all_examples( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, serialize: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -83,6 +85,7 @@ def retrieve_all_examples( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if unserialized_payload_dict is None: @@ -120,6 +123,7 @@ def retrieve_example( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, serialize: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -133,6 +137,8 @@ def retrieve_example( the model payload. model_version (str): The version of the JumpStart model for which to retrieve the model payload. + model_type (str): The model type of the JumpStart model, either is open source + or marketplace (proprietary). serialize (bool): Whether to serialize byte-stream valued payloads by downloading binary files from s3 and applying encoding, or to keep payload in pre-serialized state. Set this option to False if you want to avoid s3 downloads or if you @@ -162,6 +168,7 @@ def retrieve_example( region=region, model_id=model_id, model_version=model_version, + model_type=model_type, serialize=serialize, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 5205765e2f..9f56a67b68 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -286,3 +286,17 @@ def test_proprietary_jumpstart_model(setup): response = predictor.predict(payload) assert response is not None + + +def test_jumpstart_payload(setup): + model_id = "ai21-jurassic-2-light" + + model = JumpStartModel( + model_id=model_id, + model_version="2.0.004", + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + resp = model.retrieve_example_payload() + print(resp) \ No newline at end of file From 896a2cf15c680cd8501ae1f2ee1a9f9b8222742d Mon Sep 17 00:00:00 2001 From: Haotian An Date: Wed, 6 Mar 2024 18:54:13 +0000 Subject: [PATCH 18/30] fix: doc and tags and enums --- doc/doc_utils/jumpstart_doc_utils.py | 16 ++++++++++--- src/sagemaker/jumpstart/cache.py | 14 ++++++++++- src/sagemaker/jumpstart/enums.py | 3 +++ src/sagemaker/jumpstart/exceptions.py | 13 ++++++++++ src/sagemaker/jumpstart/factory/model.py | 2 +- src/sagemaker/jumpstart/model.py | 24 +++++++++++++++++++ src/sagemaker/jumpstart/utils.py | 9 +++++++ .../sagemaker/jumpstart/model/test_model.py | 1 + 8 files changed, 77 insertions(+), 5 deletions(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 3722703263..0a35dcfa0c 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -205,8 +205,6 @@ def get_model_source(url): def create_marketplace_model_table(): - sdk_manifest = get_proprietary_sdk_manifest() - marketpkace_content_intro = [] marketpkace_content_intro.append("\n") marketpkace_content_intro.append(".. list-table:: Available Models\n") @@ -220,8 +218,20 @@ def create_marketplace_model_table(): marketpkace_content_intro.append(" - Min SDK Version\n") marketpkace_content_intro.append(" - Source\n") - marketplace_content_entries = [] + sdk_manifest = get_proprietary_sdk_manifest() + sdk_manifest_top_versions_for_models = {} + for model in sdk_manifest: + if model["model_id"] not in sdk_manifest_top_versions_for_models: + sdk_manifest_top_versions_for_models[model["model_id"]] = model + else: + if Version( + sdk_manifest_top_versions_for_models[model["model_id"]]["version"] + ) < Version(model["version"]): + sdk_manifest_top_versions_for_models[model["model_id"]] = model + + marketplace_content_entries = [] + for model in sdk_manifest_top_versions_for_models.values(): model_spec = get_jumpstart_sdk_spec(model["spec_key"]) marketplace_content_entries.append(" * - {}\n".format(model_spec["model_id"])) marketplace_content_entries.append(" - {}\n".format(False)) # TODO: support training diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index c28e82bf5e..44b7a3f9ae 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -214,6 +214,8 @@ def _model_id_retrieval_function( key (JumpStartVersionedModelId): Key for which to fetch versioned model ID. value (Optional[JumpStartVersionedModelId]): Unused variable for current value of old cached model ID/version. + model_type (JumpStartModelType): JumpStart model type to indicate whether it is + open weights model or proprietary (Marketplace) model. Raises: KeyError: If the semantic version is not found in the manifest, or is found but @@ -276,10 +278,20 @@ def _model_id_retrieval_function( ) other_model_id_version = None - if model_type != JumpStartModelType.PROPRIETARY: + if model_type == JumpStartModelType.OPEN_SOURCE: other_model_id_version = self._select_version( model_id, "*", versions_incompatible_with_sagemaker, model_type ) # all versions here are incompatible with sagemaker + elif model_type == JumpStartModelType.PROPRIETARY: + all_possible_model_id_version = [ + header.version for header in manifest.values() # type: ignore + if header.model_id == model_id + ] + other_model_id_version = ( + None + if not all_possible_model_id_version + else all_possible_model_id_version[0] + ) if other_model_id_version is not None: error_msg += ( diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index 2188294f06..148790d66c 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -89,6 +89,9 @@ class JumpStartTag(str, Enum): MODEL_ID = "sagemaker-sdk:jumpstart-model-id" MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" + MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" + + MARKETPLACE_MODEL_TYPE_VALUE = "SageMakerJumpStartMarketplace" class SerializerType(str, Enum): diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index 98243cf96a..d4771d8bc4 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -58,6 +58,19 @@ def get_wildcard_model_version_msg( ) +def get_proprietary_model_subscription_msg( + model_id: str, + model_version: str, + subscription_link: str, +) -> str: + """Returns customer-facing message for using a Marketplace model.""" + + return ( + f"Using Marketplace model '{model_id}' with version identifier '{model_version}'. " + f"Please make sure to subscribe to the model from {subscription_link}" + ) + + def get_wildcard_proprietary_model_version_msg( model_id: str, wildcard_model_version: str, available_versions: List[str] ) -> str: diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 1219491654..60ed79c9ad 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -488,7 +488,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: if kwargs.sagemaker_session.settings.include_jumpstart_tags: kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version + kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type ) return kwargs diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index bc06a56051..e706771e29 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -27,6 +27,7 @@ from sagemaker.jumpstart.exceptions import ( INVALID_MODEL_ID_ERROR_MSG, MarketplaceModelSubscriptionError, + get_proprietary_model_subscription_msg, ) from sagemaker.jumpstart.factory.model import ( get_default_predictor, @@ -39,6 +40,8 @@ validate_model_id_and_get_type, verify_model_region_and_return_specs, ) +from sagemaker.jumpstart.constants import JUMPSTART_LOGGER +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model import ( Model, @@ -336,10 +339,29 @@ def _validate_model_id_and_type(): self.region = model_init_kwargs.region self.sagemaker_session = model_init_kwargs.sagemaker_session + if self.model_type == JumpStartModelType.PROPRIETARY: + self.log_subscription_warning() + super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) self.model_package_arn = model_init_kwargs.model_package_arn + def log_subscription_warning(self) -> None: + """Logs customer facing message for subscribe to the proprietary model.""" + subscription_link = verify_model_region_and_return_specs( + region=self.region, + model_id=self.model_id, + version=self.model_version, + model_type=self.model_type, + scope=JumpStartScriptScope.INFERENCE, + sagemaker_session=self.sagemaker_session, + ).model_subscription_link + JUMPSTART_LOGGER.warning( + get_proprietary_model_subscription_msg( + self.model_id, self.model_version, subscription_link + ) + ) + def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: """Returns all example payloads associated with the model. @@ -357,6 +379,7 @@ def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, + model_type=self.model_type, ) def retrieve_example_payload(self) -> JumpStartSerializablePayload: @@ -617,6 +640,7 @@ def deploy( version=self.model_version, model_type=self.model_type, scope=JumpStartScriptScope.INFERENCE, + sagemaker_session=self.sagemaker_session, ).model_subscription_link raise MarketplaceModelSubscriptionError(subscription_link) raise diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index dc3568654a..a9e2555ec7 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -315,6 +315,7 @@ def add_single_jumpstart_tag( ( tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags) + or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags) ) if is_uri else False @@ -349,6 +350,7 @@ def add_jumpstart_model_id_version_tags( tags: Optional[List[TagsDict]], model_id: str, model_version: str, + model_type: Optional[enums.JumpStartModelType] = None, ) -> List[TagsDict]: """Add custom model ID and version tags to JumpStart related resources.""" if model_id is None or model_version is None: @@ -365,6 +367,13 @@ def add_jumpstart_model_id_version_tags( tags, is_uri=False, ) + if model_type == enums.JumpStartModelType.PROPRIETARY: + tags = add_single_jumpstart_tag( + enums.JumpStartTag.MARKETPLACE_MODEL_TYPE_VALUE, + enums.JumpStartTag.MODEL_TYPE, + tags, + is_uri=False, + ) return tags diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 53669daccb..e51545adb8 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -493,6 +493,7 @@ def test_proprietary_model_endpoint( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "ai21-summarization"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "2.0.004"}, + {"Key": JumpStartTag.MODEL_TYPE, "Value": "SageMakerJumpStartMarketplace"}, ], endpoint_logging=False, model_data_download_timeout=3600, From d701211b2104a3fc8149e6c59285ee28833edac1 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Wed, 6 Mar 2024 19:59:44 +0000 Subject: [PATCH 19/30] fix: jumpstart doc --- doc/doc_utils/jumpstart_doc_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 0a35dcfa0c..8595fb9729 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -78,6 +78,8 @@ class Frameworks(str, Enum): JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format( JUMPSTART_REGION, JUMPSTART_REGION ) +PROPRIETARY_DOC_BUCKET = "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com" + TASK_MAP = { Tasks.IC: ProblemTypes.IMAGE_CLASSIFICATION, Tasks.IC_EMBEDDING: ProblemTypes.IMAGE_EMBEDDING, @@ -161,14 +163,14 @@ def get_jumpstart_sdk_manifest(): def get_proprietary_sdk_manifest(): - url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, PROPRIETARY_SDK_MANIFEST_FILE) + url = "{}/{}".format(PROPRIETARY_DOC_BUCKET, PROPRIETARY_SDK_MANIFEST_FILE) with request.urlopen(url) as f: models_manifest = f.read().decode("utf-8") return json.loads(models_manifest) def get_jumpstart_sdk_spec(key): - url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, key) + url = "{}/{}".format(PROPRIETARY_DOC_BUCKET, key) with request.urlopen(url) as f: model_spec = f.read().decode("utf-8") return json.loads(model_spec) From e3e64ba5a9007b11f69eb2859dac5080bb06f96e Mon Sep 17 00:00:00 2001 From: Haotian An Date: Fri, 8 Mar 2024 14:12:52 +0000 Subject: [PATCH 20/30] rename to open_weights and fix filtering --- doc/doc_utils/jumpstart_doc_utils.py | 16 ++-- src/sagemaker/accept_types.py | 2 +- src/sagemaker/content_types.py | 2 +- src/sagemaker/deserializers.py | 2 +- src/sagemaker/instance_types.py | 2 +- src/sagemaker/jumpstart/accessors.py | 6 +- .../jumpstart/artifacts/instance_types.py | 2 +- src/sagemaker/jumpstart/artifacts/kwargs.py | 7 +- .../jumpstart/artifacts/model_packages.py | 2 +- src/sagemaker/jumpstart/artifacts/payloads.py | 2 +- .../jumpstart/artifacts/predictors.py | 14 ++-- .../jumpstart/artifacts/resource_names.py | 2 +- .../artifacts/resource_requirements.py | 2 +- src/sagemaker/jumpstart/cache.py | 74 +++++++++---------- src/sagemaker/jumpstart/constants.py | 4 +- src/sagemaker/jumpstart/enums.py | 2 +- src/sagemaker/jumpstart/exceptions.py | 3 +- src/sagemaker/jumpstart/factory/estimator.py | 2 +- src/sagemaker/jumpstart/factory/model.py | 10 +-- src/sagemaker/jumpstart/filters.py | 1 + src/sagemaker/jumpstart/model.py | 4 +- src/sagemaker/jumpstart/notebook_utils.py | 41 ++++++---- src/sagemaker/jumpstart/types.py | 12 +-- src/sagemaker/jumpstart/utils.py | 14 ++-- src/sagemaker/model.py | 2 +- src/sagemaker/payloads.py | 4 +- src/sagemaker/predictor.py | 2 +- src/sagemaker/resource_requirements.py | 2 +- src/sagemaker/serializers.py | 2 +- .../jumpstart/model/test_jumpstart_model.py | 18 +---- .../jumpstart/test_accept_types.py | 8 +- .../jumpstart/test_content_types.py | 8 +- .../jumpstart/test_deserializers.py | 8 +- .../jumpstart/test_default.py | 12 +-- .../hyperparameters/jumpstart/test_default.py | 8 +- .../jumpstart/test_validate.py | 12 +-- .../image_uris/jumpstart/test_common.py | 10 +-- .../jumpstart/test_instance_types.py | 10 +-- .../jumpstart/estimator/test_estimator.py | 46 ++++++------ .../estimator/test_sagemaker_config.py | 16 ++-- .../sagemaker/jumpstart/model/test_model.py | 51 +++++++------ .../jumpstart/model/test_sagemaker_config.py | 16 ++-- .../sagemaker/jumpstart/test_artifacts.py | 4 +- tests/unit/sagemaker/jumpstart/test_cache.py | 28 +++---- .../jumpstart/test_notebook_utils.py | 13 ++-- .../sagemaker/jumpstart/test_predictor.py | 2 +- tests/unit/sagemaker/jumpstart/utils.py | 16 ++-- .../jumpstart/test_default.py | 6 +- .../model_uris/jumpstart/test_common.py | 10 +-- .../jumpstart/test_resource_requirements.py | 8 +- .../script_uris/jumpstart/test_common.py | 10 +-- .../serializers/jumpstart/test_serializers.py | 8 +- 52 files changed, 282 insertions(+), 286 deletions(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 8595fb9729..af078888b7 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -298,19 +298,19 @@ def create_jumpstart_model_table(): file_content_intro.append(" - Source\n") dynamic_table_files = [] - open_source_content_entries = [] + open_weight_content_entries = [] for model in sdk_manifest_top_versions_for_models.values(): model_spec = get_jumpstart_sdk_spec(model["spec_key"]) model_task = get_model_task(model_spec["model_id"]) string_model_task = get_string_model_task(model_spec["model_id"]) model_source = get_model_source(model_spec["url"]) - open_source_content_entries.append(" * - {}\n".format(model_spec["model_id"])) - open_source_content_entries.append(" - {}\n".format(model_spec["training_supported"])) - open_source_content_entries.append(" - {}\n".format(model["version"])) - open_source_content_entries.append(" - {}\n".format(model["min_version"])) - open_source_content_entries.append(" - {}\n".format(model_task)) - open_source_content_entries.append( + open_weight_content_entries.append(" * - {}\n".format(model_spec["model_id"])) + open_weight_content_entries.append(" - {}\n".format(model_spec["training_supported"])) + open_weight_content_entries.append(" - {}\n".format(model["version"])) + open_weight_content_entries.append(" - {}\n".format(model["min_version"])) + open_weight_content_entries.append(" - {}\n".format(model_task)) + open_weight_content_entries.append( " - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"]) ) @@ -352,6 +352,6 @@ def create_jumpstart_model_table(): f = open("doc_utils/pretrainedmodels.rst", "a") f.writelines(file_content_intro) - f.writelines(open_source_content_entries) + f.writelines(open_weight_content_entries) f.writelines(marketplace_content_entries) f.close() diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 86b78dbbfc..43abd5d1a1 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -76,7 +76,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> str: """Retrieves the default accept type for the model matching the given arguments. diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 6bbcff876c..efdbf6846c 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -76,7 +76,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> str: """Retrieves the default content type for the model matching the given arguments. diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 7bd8315d03..27ae946450 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -96,7 +96,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 414b28fee2..c24fd57bc7 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -35,7 +35,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> str: """Retrieves the default instance type for the model matching the given arguments. diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 482cfdeee7..d49bcb9bb8 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -200,7 +200,7 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None: def _get_manifest( region: str = JUMPSTART_DEFAULT_REGION_NAME, s3_client: Optional[boto3.client] = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest. @@ -229,7 +229,7 @@ def get_model_header( region: str, model_id: str, version: str, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> JumpStartModelHeader: """Returns model header from JumpStart models cache. @@ -254,7 +254,7 @@ def get_model_specs( model_id: str, version: str, s3_client: Optional[boto3.client] = None, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 47b7849dc5..dfe21f21a9 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -39,7 +39,7 @@ def _retrieve_default_instance_type( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> str: """Retrieves the default instance type for the model. diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 3a7012d4f9..f057864cdd 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -36,7 +36,7 @@ def _retrieve_model_init_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> dict: """Retrieves kwargs for `Model`. @@ -81,6 +81,9 @@ def _retrieve_model_init_kwargs( if model_specs.inference_enable_network_isolation is not None: kwargs.update({"enable_network_isolation": model_specs.inference_enable_network_isolation}) + if model_type == JumpStartModelType.PROPRIETARY: + kwargs.update({"enable_network_isolation": True}) + return kwargs @@ -92,7 +95,7 @@ def _retrieve_model_deploy_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> dict: """Retrieves kwargs for `Model.deploy`. diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index 89c8e8ebd0..c87088f4fb 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -36,7 +36,7 @@ def _retrieve_model_package_arn( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> Optional[str]: """Retrieves associated model pacakge arn for the model. diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index b7d51cd995..6db511f4db 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -36,7 +36,7 @@ def _retrieve_example_payloads( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> Optional[Dict[str, JumpStartSerializablePayload]]: """Returns example payloads. diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 4f47fee0be..c16ba4eaac 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -77,7 +77,7 @@ def _retrieve_default_deserializer( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> BaseDeserializer: """Retrieves the default deserializer for the model. @@ -123,7 +123,7 @@ def _retrieve_default_serializer( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> BaseSerializer: """Retrieves the default serializer for the model. @@ -168,7 +168,7 @@ def _retrieve_deserializer_options( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model. @@ -283,7 +283,7 @@ def _retrieve_default_content_type( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default content type for the model. @@ -334,7 +334,7 @@ def _retrieve_default_accept_type( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> str: """Retrieves the default accept type for the model. @@ -385,7 +385,7 @@ def _retrieve_supported_accept_types( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> List[str]: """Retrieves the supported accept types for the model. @@ -436,7 +436,7 @@ def _retrieve_supported_content_types( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> List[str]: """Retrieves the supported content types for the model. diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index a46191be95..eae4c1a300 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -33,7 +33,7 @@ def _retrieve_resource_name_base( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns default resource name. diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index c7c7ef8561..5464c30937 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -51,7 +51,7 @@ def _retrieve_default_resources( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, ) -> ResourceRequirements: diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 44b7a3f9ae..a4514a97a9 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -108,12 +108,12 @@ def __init__( expiration_horizon=s3_cache_expiration_horizon, retrieval_function=self._retrieval_function, ) - self._open_source_model_id_manifest_key_cache = LRUCache[ + self._open_weight_model_id_manifest_key_cache = LRUCache[ JumpStartVersionedModelId, JumpStartVersionedModelId ]( max_cache_items=max_semantic_version_cache_items, expiration_horizon=semantic_version_cache_expiration_horizon, - retrieval_function=self._get_open_source_manifest_key_from_model_id, + retrieval_function=self._get_open_weight_manifest_key_from_model_id, ) self._proprietary_model_id_manifest_key_cache = LRUCache[ JumpStartVersionedModelId, JumpStartVersionedModelId @@ -125,7 +125,7 @@ def __init__( self._manifest_file_s3_key = manifest_file_s3_key self._proprietary_manifest_s3_key = proprietary_manifest_s3_key self._manifest_file_s3_map = { - JumpStartModelType.OPEN_SOURCE: self._manifest_file_s3_key, + JumpStartModelType.OPEN_WEIGHT: self._manifest_file_s3_key, JumpStartModelType.PROPRIETARY: self._proprietary_manifest_s3_key, } self.s3_bucket_name = ( @@ -152,7 +152,7 @@ def get_region(self) -> str: def set_manifest_file_s3_key( self, key: str, - file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_SOURCE_MANIFEST, + file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, ) -> None: """Set manifest file s3 key, clear cache after new key is set. @@ -160,14 +160,14 @@ def set_manifest_file_s3_key( ValueError: if the file type is not recognized """ file_mapping = { - JumpStartS3FileType.OPEN_SOURCE_MANIFEST: self._manifest_file_s3_key, + JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: self._manifest_file_s3_key, JumpStartS3FileType.PROPRIETARY_MANIFEST: self._proprietary_manifest_s3_key, } property_name = file_mapping.get(file_type) if not property_name: raise ValueError( f"Bad value when setting manifest '{file_type}': must be in" - f" {JumpStartS3FileType.OPEN_SOURCE_MANIFEST}" + f" {JumpStartS3FileType.OPEN_WEIGHT_MANIFEST}" f" {JumpStartS3FileType.PROPRIETARY_MANIFEST}" ) if key != property_name: @@ -175,16 +175,16 @@ def set_manifest_file_s3_key( self.clear() def get_manifest_file_s3_key( - self, file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_SOURCE_MANIFEST + self, file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_WEIGHT_MANIFEST ) -> str: """Return manifest file s3 key for cache.""" - if file_type == JumpStartS3FileType.OPEN_SOURCE_MANIFEST: + if file_type == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: return self._manifest_file_s3_key if file_type == JumpStartS3FileType.PROPRIETARY_MANIFEST: return self._proprietary_manifest_s3_key raise ValueError( f"Bad value when getting manifest '{file_type}':" - f"must be in {JumpStartS3FileType.OPEN_SOURCE_MANIFEST}" + f"must be in {JumpStartS3FileType.OPEN_WEIGHT_MANIFEST}" f"{JumpStartS3FileType.PROPRIETARY_MANIFEST}" ) @@ -278,7 +278,7 @@ def _model_id_retrieval_function( ) other_model_id_version = None - if model_type == JumpStartModelType.OPEN_SOURCE: + if model_type == JumpStartModelType.OPEN_WEIGHT: other_model_id_version = self._select_version( model_id, "*", versions_incompatible_with_sagemaker, model_type ) # all versions here are incompatible with sagemaker @@ -305,14 +305,14 @@ def _model_id_retrieval_function( raise KeyError(error_msg) - def _get_open_source_manifest_key_from_model_id( + def _get_open_weight_manifest_key_from_model_id( self, key: JumpStartVersionedModelId, value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 ) -> JumpStartVersionedModelId: """Retrieve model manifest key for open source model, by filtering supported versions.""" return self._model_id_retrieval_function( - key, value, model_type=JumpStartModelType.OPEN_SOURCE + key, value, model_type=JumpStartModelType.OPEN_WEIGHT ) def _get_proprietary_manifest_key_from_model_id( @@ -369,11 +369,11 @@ def _get_json_file_from_local_override( filetype: JumpStartS3FileType ) -> Union[dict, list]: """Reads json file from local filesystem and returns data.""" - if filetype == JumpStartS3FileType.OPEN_SOURCE_MANIFEST: + if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: metadata_local_root = ( os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE] ) - elif filetype == JumpStartS3FileType.OPEN_SOURCE_SPECS: + elif filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS: metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE] else: raise ValueError(f"Unsupported file type for local override: {filetype}") @@ -402,7 +402,7 @@ def _retrieval_function( file_type, s3_key = key.file_type, key.s3_key if file_type in { - JumpStartS3FileType.OPEN_SOURCE_MANIFEST, + JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, JumpStartS3FileType.PROPRIETARY_MANIFEST, }: if value is not None and not self._is_local_metadata_mode(): @@ -415,7 +415,7 @@ def _retrieval_function( md5_hash=etag, ) if file_type in { - JumpStartS3FileType.OPEN_SOURCE_SPECS, + JumpStartS3FileType.OPEN_WEIGHT_SPECS, JumpStartS3FileType.PROPRIETARY_SPECS, }: formatted_body, _ = self._get_json_file(s3_key, file_type) @@ -424,13 +424,13 @@ def _retrieval_function( return JumpStartCachedS3ContentValue(formatted_content=model_specs) raise ValueError( f"Bad value for key '{key}': must be in" - f"{JumpStartS3FileType.OPEN_SOURCE_MANIFEST, JumpStartS3FileType.OPEN_SOURCE_SPECS}" + f"{JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, JumpStartS3FileType.OPEN_WEIGHT_SPECS}" f"{JumpStartS3FileType.PROPRIETARY_SPECS, JumpStartS3FileType.PROPRIETARY_MANIFEST}" ) def get_manifest( self, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest.""" manifest_dict = self._s3_cache.get( @@ -444,7 +444,7 @@ def get_header( self, model_id: str, semantic_version_str: str, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> JumpStartModelHeader: """Return header for a given JumpStart model ID and semantic version. @@ -461,36 +461,36 @@ def get_header( def _select_version( self, model_id: str, - semantic_version_str: str, + version_str: str, available_versions: List[str], - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> Optional[str]: """Perform semantic version search on available versions. Args: - semantic_version_str (str): the semantic version for which to filter + version_str (str): the semantic version for which to filter available versions. available_versions (List[Version]): list of available versions. """ + if version_str == "*": + if len(available_versions) == 0: + return None + return str(max(available_versions)) + if model_type == JumpStartModelType.PROPRIETARY: - if "*" in semantic_version_str: + if "*" in version_str: raise KeyError( get_wildcard_proprietary_model_version_msg( - model_id, semantic_version_str, available_versions + model_id, version_str, available_versions ) ) - return semantic_version_str if semantic_version_str in available_versions else None - - if semantic_version_str == "*": - if len(available_versions) == 0: - return None - return str(max(available_versions)) + return version_str if version_str in available_versions else None try: - spec = SpecifierSet(f"=={semantic_version_str}") + spec = SpecifierSet(f"=={version_str}") except InvalidSpecifier: - raise KeyError(f"Bad semantic version: {semantic_version_str}") + raise KeyError(f"Bad semantic version: {version_str}") available_versions_filtered = list(spec.filter(available_versions)) return ( str(max(available_versions_filtered)) if available_versions_filtered != [] else None @@ -501,7 +501,7 @@ def _get_header_impl( model_id: str, semantic_version_str: str, attempt: int = 0, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT ) -> JumpStartModelHeader: """Lower-level function to return header. @@ -513,8 +513,8 @@ def _get_header_impl( header. attempt (int): attempt number at retrieving a header. """ - if model_type == JumpStartModelType.OPEN_SOURCE: - versioned_model_id = self._open_source_model_id_manifest_key_cache.get( + if model_type == JumpStartModelType.OPEN_WEIGHT: + versioned_model_id = self._open_weight_model_id_manifest_key_cache.get( JumpStartVersionedModelId(model_id, semantic_version_str) )[0] elif model_type == JumpStartModelType.PROPRIETARY: @@ -540,7 +540,7 @@ def get_specs( self, model_id: str, version_str: str, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT ) -> JumpStartModelSpecs: """Return specs for a given JumpStart model ID and semantic version. @@ -568,5 +568,5 @@ def get_specs( def clear(self) -> None: """Clears the model ID/version and s3 cache.""" self._s3_cache.clear() - self._open_source_model_id_manifest_key_cache.clear() + self._open_weight_model_id_manifest_key_cache.clear() self._proprietary_model_id_manifest_key_cache.clear() diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 114c54963b..7af414327d 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -216,12 +216,12 @@ } MODEL_TYPE_TO_MANIFEST_MAP: Dict[Type[JumpStartModelType], Type[JumpStartS3FileType]] = { - JumpStartModelType.OPEN_SOURCE: JumpStartS3FileType.OPEN_SOURCE_MANIFEST, + JumpStartModelType.OPEN_WEIGHT: JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, JumpStartModelType.PROPRIETARY: JumpStartS3FileType.PROPRIETARY_MANIFEST, } MODEL_TYPE_TO_SPECS_MAP: Dict[Type[JumpStartModelType], Type[JumpStartS3FileType]] = { - JumpStartModelType.OPEN_SOURCE: JumpStartS3FileType.OPEN_SOURCE_SPECS, + JumpStartModelType.OPEN_WEIGHT: JumpStartS3FileType.OPEN_WEIGHT_SPECS, JumpStartModelType.PROPRIETARY: JumpStartS3FileType.PROPRIETARY_SPECS, } diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index 148790d66c..8f73230142 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -41,7 +41,7 @@ class JumpStartModelType(str, Enum): Proprietary model refers to external provider owned Marketplace models. """ - OPEN_SOURCE = "opensource" + OPEN_WEIGHT = "open_weight" PROPRIETARY = "proprietary" diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index d4771d8bc4..a8b8e3525e 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -60,13 +60,12 @@ def get_wildcard_model_version_msg( def get_proprietary_model_subscription_msg( model_id: str, - model_version: str, subscription_link: str, ) -> str: """Returns customer-facing message for using a Marketplace model.""" return ( - f"Using Marketplace model '{model_id}' with version identifier '{model_version}'. " + f"INFO: Using Marketplace model '{model_id}'. " f"Please make sure to subscribe to the model from {subscription_link}" ) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 0faeb92a42..ec99210b54 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -77,7 +77,7 @@ def get_init_kwargs( model_id: str, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 60ed79c9ad..c60db4e3fc 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -71,7 +71,7 @@ def get_default_predictor( tolerate_vulnerable_model: bool, tolerate_deprecated_model: bool, sagemaker_session: Session, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> Predictor: """Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one. @@ -221,7 +221,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel """ if kwargs.model_type == JumpStartModelType.PROPRIETARY: - kwargs.image_uri = "" + kwargs.image_uri = None return kwargs kwargs.image_uri = kwargs.image_uri or image_uris.retrieve( @@ -536,7 +536,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -703,7 +703,7 @@ def get_init_kwargs( model_id: str, model_from_estimator: bool = False, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, instance_type: Optional[str] = None, @@ -780,14 +780,12 @@ def get_init_kwargs( # we use the model artifact from the training job output if not model_from_estimator: model_init_kwargs = _add_model_data_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_source_dir_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_entry_point_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_env_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_predictor_cls_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_extra_model_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_role_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs) diff --git a/src/sagemaker/jumpstart/filters.py b/src/sagemaker/jumpstart/filters.py index b045435ed0..2829593ced 100644 --- a/src/sagemaker/jumpstart/filters.py +++ b/src/sagemaker/jumpstart/filters.py @@ -49,6 +49,7 @@ class SpecialSupportedFilterKeys(str, Enum): TASK = "task" FRAMEWORK = "framework" + MODEL_TYPE = "model_type" FILTER_OPERATOR_STRING_MAPPINGS = { diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index e706771e29..0145a26fe4 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -357,9 +357,7 @@ def log_subscription_warning(self) -> None: sagemaker_session=self.sagemaker_session, ).model_subscription_link JUMPSTART_LOGGER.warning( - get_proprietary_model_subscription_msg( - self.model_id, self.model_version, subscription_link - ) + get_proprietary_model_subscription_msg(self.model_id, subscription_link) ) def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 9318562d2f..35ee6b52c9 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -125,15 +125,11 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]: Args: model_id (str): The model ID for which to extract the framework/task/model. - - Raises: - ValueError: If the model ID cannot be parsed into at least 3 components seperated by - "-" character. """ _id_parts = model_id.split("-") if len(_id_parts) < 3: - raise ValueError(f"incorrect model ID: {model_id}.") + return "", "", "" framework = _id_parts[0] task = _id_parts[1] @@ -142,6 +138,20 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]: return framework, task, name +def extract_model_type(spec_key: str) -> str: + """Parses model spec key, determine if the model is proprietary or open weight. + + Args: + spek_key (str): The model spec key for which to extract the model type. + """ + model_type = spec_key.split("/")[0] + + if model_type == "proprietary-models": + return JumpStartModelType.PROPRIETARY.value + + return JumpStartModelType.OPEN_WEIGHT.value + + def list_jumpstart_tasks( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), region: str = JUMPSTART_DEFAULT_REGION_NAME, @@ -247,7 +257,6 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin list_incomplete_models: bool = False, list_old_models: bool = False, list_versions: bool = False, - marketplace_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[Union[Tuple[str], Tuple[str, str]]]: """List models for JumpStart, and optionally apply filters to result. @@ -279,7 +288,6 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin filter=filter, region=region, list_incomplete_models=list_incomplete_models, - marketplace_model=marketplace_model, sagemaker_session=sagemaker_session, ): if model_id not in model_id_version_dict: @@ -306,7 +314,6 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), region: str = JUMPSTART_DEFAULT_REGION_NAME, list_incomplete_models: bool = False, - marketplace_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Generator: """Generate models for JumpStart, and optionally apply filters to result. @@ -332,21 +339,17 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin s3_client=sagemaker_session.s3_client, model_type=JumpStartModelType.PROPRIETARY, ) - open_source_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + open_weight_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( region=region, s3_client=sagemaker_session.s3_client, - model_type=JumpStartModelType.OPEN_SOURCE, - ) - models_manifest_list = ( - prop_models_manifest_list - if marketplace_model - else (open_source_manifest_list + prop_models_manifest_list) + model_type=JumpStartModelType.OPEN_WEIGHT, ) + models_manifest_list = open_weight_manifest_list + prop_models_manifest_list if isinstance(filter, str): filter = Identity(filter) - manifest_keys = set(models_manifest_list[0].__slots__) + manifest_keys = set(models_manifest_list[0].__slots__ + prop_models_manifest_list[0].__slots__) all_keys: Set[str] = set() @@ -369,6 +372,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys + is_model_type_filter = SpecialSupportedFilterKeys.MODEL_TYPE in all_keys def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, str]]: @@ -391,6 +395,11 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, SpecialSupportedFilterKeys.FRAMEWORK ] = extract_framework_task_model(model_manifest.model_id)[0] + if is_model_type_filter: + manifest_specs_cached_values[ + SpecialSupportedFilterKeys.MODEL_TYPE + ] = extract_model_type(model_manifest.spec_key) + if Version(model_manifest.min_version) > Version(get_sagemaker_version()): return None diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 47a7e0e4e4..045187551d 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -103,8 +103,8 @@ def __repr__(self) -> str: class JumpStartS3FileType(str, Enum): """Type of files published in JumpStart S3 distribution buckets.""" - OPEN_SOURCE_MANIFEST = "manifest" - OPEN_SOURCE_SPECS = "specs" + OPEN_WEIGHT_MANIFEST = "manifest" + OPEN_WEIGHT_SPECS = "specs" PROPRIETARY_MANIFEST = "proptietary_manifest" PROPRIETARY_SPECS = "proprietary_specs" @@ -1099,7 +1099,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, region: Optional[str] = None, instance_type: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, @@ -1206,7 +1206,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -1340,7 +1340,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, region: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, role: Optional[str] = None, @@ -1485,7 +1485,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, region: Optional[str] = None, inputs: Optional[Union[str, Dict, Any, Any]] = None, wait: Optional[bool] = None, diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index a9e2555ec7..17547b209a 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -540,7 +540,7 @@ def verify_model_region_and_return_specs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_SOURCE, + model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHT, ) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided. @@ -759,12 +759,12 @@ def validate_model_id_and_get_type( def _get_model_type( model_id: str, - open_source_models: Set[str], + open_weight_models: Set[str], proprietary_models: Set[str], script: enums.JumpStartScriptScope, ) -> Optional[enums.JumpStartModelType]: - if model_id in open_source_models: - return enums.JumpStartModelType.OPEN_SOURCE + if model_id in open_weight_models: + return enums.JumpStartModelType.OPEN_WEIGHT if model_id in proprietary_models: if script == enums.JumpStartScriptScope.INFERENCE: return enums.JumpStartModelType.PROPRIETARY @@ -780,16 +780,16 @@ def _get_model_type( region = region or constants.JUMPSTART_DEFAULT_REGION_NAME model_version = model_version or "*" models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( - region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.OPEN_SOURCE + region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.OPEN_WEIGHT ) - open_source_model_id_set = {model.model_id for model in models_manifest_list} + open_weight_model_id_set = {model.model_id for model in models_manifest_list} proprietary_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.PROPRIETARY ) proprietary_model_id_set = {model.model_id for model in proprietary_manifest_list} - return _get_model_type(model_id, open_source_model_id_set, proprietary_model_id_set, script) + return _get_model_type(model_id, open_weight_model_id_set, proprietary_model_id_set, script) def get_jumpstart_model_id_version_from_resource_arn( diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index ff340b58e9..92e184c7bb 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -138,7 +138,7 @@ class Model(ModelBase, InferenceRecommenderMixin): def __init__( self, - image_uri: Union[str, PipelineVariable], + image_uri: Optional[Union[str, PipelineVariable]] = None, model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, predictor_cls: Optional[callable] = None, diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index 89286e217c..4d35a62632 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -32,7 +32,7 @@ def retrieve_all_examples( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, serialize: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -123,7 +123,7 @@ def retrieve_example( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_SOURCE, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, serialize: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index a1e5996bd0..4d564c2afc 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -42,7 +42,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> Predictor: """Retrieves the default predictor for the model matching the given arguments. diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index a560d0075d..342770a410 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -34,7 +34,7 @@ def retrieve_default( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, ) -> ResourceRequirements: diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index ca92ff1b53..bcd34f722e 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -94,7 +94,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> BaseSerializer: """Retrieves the default serializer for the model matching the given arguments. diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 9f56a67b68..2024a487ab 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -17,9 +17,11 @@ import pytest from sagemaker.enums import EndpointType +from sagemaker.jumpstart.filters import And from sagemaker.predictor import retrieve_default import tests.integ +from sagemaker.jumpstart import notebook_utils from sagemaker.jumpstart.model import JumpStartModel from tests.integ.sagemaker.jumpstart.constants import ( @@ -273,7 +275,7 @@ def test_proprietary_jumpstart_model(setup): model = JumpStartModel( model_id=model_id, - model_version="2.0.004", + model_version="*", role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), ) @@ -286,17 +288,3 @@ def test_proprietary_jumpstart_model(setup): response = predictor.predict(payload) assert response is not None - - -def test_jumpstart_payload(setup): - model_id = "ai21-jurassic-2-light" - - model = JumpStartModel( - model_id=model_id, - model_version="2.0.004", - role=get_sm_session().get_caller_identity_arn(), - sagemaker_session=get_sm_session(), - ) - - resp = model.retrieve_example_payload() - print(resp) \ No newline at end of file diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index 4272684d35..1c12f777a0 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -37,7 +37,7 @@ def test_jumpstart_default_accept_types( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -55,7 +55,7 @@ def test_jumpstart_default_accept_types( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) @@ -70,7 +70,7 @@ def test_jumpstart_supported_accept_types( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -91,5 +91,5 @@ def test_jumpstart_supported_accept_types( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index 50250d8f3b..ae698a7d94 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -36,7 +36,7 @@ def test_jumpstart_default_content_types( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -54,7 +54,7 @@ def test_jumpstart_default_content_types( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) @@ -69,7 +69,7 @@ def test_jumpstart_supported_content_types( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -89,5 +89,5 @@ def test_jumpstart_supported_content_types( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index 3e917d55c7..4b06ac8c4e 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -38,7 +38,7 @@ def test_jumpstart_default_deserializers( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -56,7 +56,7 @@ def test_jumpstart_default_deserializers( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) @@ -71,7 +71,7 @@ def test_jumpstart_deserializer_options( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -96,5 +96,5 @@ def test_jumpstart_deserializer_options( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index f1a5176263..2ee9ecec38 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -34,7 +34,7 @@ def test_jumpstart_default_environment_variables( ): patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -58,7 +58,7 @@ def test_jumpstart_default_environment_variables( model_id=model_id, version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() @@ -82,7 +82,7 @@ def test_jumpstart_default_environment_variables( model_id=model_id, version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() @@ -119,7 +119,7 @@ def test_jumpstart_sdk_environment_variables( ): patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -144,7 +144,7 @@ def test_jumpstart_sdk_environment_variables( model_id=model_id, version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() @@ -169,7 +169,7 @@ def test_jumpstart_sdk_environment_variables( model_id=model_id, version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index 3ac10c9109..babae7f86c 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -34,7 +34,7 @@ def test_jumpstart_default_hyperparameters( ): patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -52,7 +52,7 @@ def test_jumpstart_default_hyperparameters( model_id=model_id, version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() @@ -70,7 +70,7 @@ def test_jumpstart_default_hyperparameters( model_id=model_id, version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() @@ -96,7 +96,7 @@ def test_jumpstart_default_hyperparameters( model_id=model_id, version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index cf7a321b79..f3ef886621 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -112,7 +112,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): return spec patched_get_model_specs.side_effect = add_options_to_hyperparameter - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -144,7 +144,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() @@ -424,7 +424,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): return spec patched_get_model_specs.side_effect = add_options_to_hyperparameter - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -450,7 +450,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() @@ -484,7 +484,7 @@ def test_jumpstart_validate_all_hyperparameters( ): patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -512,7 +512,7 @@ def test_jumpstart_validate_all_hyperparameters( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 912ef6fb49..1171161be0 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -35,7 +35,7 @@ def test_jumpstart_common_image_uri( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -54,7 +54,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -75,7 +75,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -96,7 +96,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -117,7 +117,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index 748cfca0e2..298a1e9eb2 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -28,7 +28,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_model_id_and_get_type): patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, model_version = "huggingface-eqa-bert-base-cased", "*" region = "us-west-2" @@ -50,7 +50,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() @@ -69,7 +69,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() @@ -94,7 +94,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() @@ -121,7 +121,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index c55f1779cd..1ca13cfed9 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -84,7 +84,7 @@ def test_non_prepacked( mock_jumpstart_model_factory_logger: mock.Mock, mock_jumpstart_estimator_factory_logger: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT mock_sagemaker_timestamp.return_value = "9876" @@ -94,7 +94,7 @@ def test_non_prepacked( mock_get_model_specs.side_effect = get_special_model_spec - mock_get_model_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_get_model_type.return_value = JumpStartModelType.OPEN_WEIGHT mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session @@ -207,7 +207,7 @@ def test_prepacked( ): mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model-prepacked", "*" @@ -310,7 +310,7 @@ def test_gated_model_s3_uri( mock_timestamp.return_value = "8675309" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-gated-artifact-trainable-model", "*" @@ -448,7 +448,7 @@ def test_gated_model_non_model_package_s3_uri( mock_get_jumpstart_gated_content_bucket.return_value = "top-secret-private-models-bucket" mock_timestamp.return_value = "8675309" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-gated-artifact-non-model-package-trainable-model", "*" @@ -601,7 +601,7 @@ def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT with pytest.raises(ValueError) as e: JumpStartEstimator(model_id=model_id, region="eu-north-1") @@ -628,7 +628,7 @@ def test_deprecated( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "deprecated_model", "*" @@ -661,7 +661,7 @@ def test_vulnerable( mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "vulnerable_model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -810,7 +810,7 @@ def evaluate_estimator_workflow_with_kwargs( mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, model_version = "js-trainable-model", "*" @@ -921,7 +921,7 @@ def test_jumpstart_estimator_tags_disabled( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model-prepacked", "*" @@ -960,7 +960,7 @@ def test_jumpstart_estimator_tags( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model-prepacked", "*" @@ -1004,7 +1004,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( mock_attach: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT get_model_id_version_from_training_job.return_value = ( "js-trainable-model-prepacked", @@ -1047,7 +1047,7 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case( mock_attach: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT get_model_id_version_from_training_job.side_effect = ValueError() @@ -1126,7 +1126,7 @@ def test_validate_model_id_and_get_type( mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT JumpStartEstimator(model_id="valid_model_id") mock_validate_model_id_and_get_type.return_value = False @@ -1158,7 +1158,7 @@ def test_no_predictor_returns_default_predictor( mock_get_default_predictor.return_value = default_predictor_with_presets - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model-prepacked", "*" @@ -1217,7 +1217,7 @@ def test_no_predictor_yes_async_inference_config( mock_get_default_predictor.return_value = default_predictor_with_presets - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model-prepacked", "*" @@ -1267,7 +1267,7 @@ def test_yes_predictor_returns_unmodified_predictor( mock_get_default_predictor.return_value = default_predictor_with_presets - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model-prepacked", "*" @@ -1316,7 +1316,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( mock_supports_incremental_training: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT mock_estimator_deploy.return_value = default_predictor @@ -1370,7 +1370,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( mock_supports_incremental_training: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT mock_estimator_deploy.return_value = default_predictor @@ -1419,7 +1419,7 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT mock_sagemaker_timestamp.return_value = "3456" @@ -1480,7 +1480,7 @@ def test_training_passes_role_to_deploy( mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT mock_sagemaker_timestamp.return_value = "3456" @@ -1560,7 +1560,7 @@ def test_training_passes_session_to_deploy( mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT mock_sagemaker_timestamp.return_value = "3456" @@ -1711,7 +1711,7 @@ def test_model_artifact_variant_estimator( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "model-artifact-variant-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py index 1535039570..90f0472924 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py @@ -104,7 +104,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" @@ -161,7 +161,7 @@ def test_without_arg_overwrites_with_kwarg_collisions_with_config( mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" @@ -234,7 +234,7 @@ def test_with_arg_overwrites_with_kwarg_collisions_with_config( mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" @@ -316,7 +316,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_with_config( mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" @@ -393,7 +393,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" @@ -454,7 +454,7 @@ def test_without_arg_overwrites_with_kwarg_collisions_without_config( mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT mock_get_caller_identity_arn.return_value = execution_role model_id, _ = "js-trainable-model", "*" @@ -525,7 +525,7 @@ def test_with_arg_overwrites_with_kwarg_collisions_without_config( ): mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" @@ -599,7 +599,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( ): mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index e51545adb8..b6fe735dff 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -75,7 +75,7 @@ def test_non_prepacked( mock_sagemaker_timestamp.return_value = "7777" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -149,7 +149,7 @@ def test_non_prepacked_inference_component_based_endpoint( mock_sagemaker_timestamp.return_value = "7777" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = ( @@ -229,7 +229,7 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom mock_sagemaker_timestamp.return_value = "7777" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-model-class-model-prepacked", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -300,7 +300,7 @@ def test_prepacked( ): mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-model-class-model-prepacked", "*" @@ -364,7 +364,7 @@ def test_no_compiled_model_warning_log_js_models( mock_timestamp.return_value = "1234" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "gated_llama_neuron_model", "*" @@ -398,7 +398,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( ): mock_timestamp.return_value = "1234" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "gated_variant-model", "*" @@ -477,11 +477,10 @@ def test_proprietary_model_endpoint( model = JumpStartModel(model_id=model_id, model_version="2.0.004") mock_model_init.assert_called_once_with( - image_uri="", predictor_cls=Predictor, role=execution_role, sagemaker_session=sagemaker_session, - enable_network_isolation=False, + enable_network_isolation=True, ) model.deploy() @@ -514,7 +513,7 @@ def test_deprecated( ): mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "deprecated_model", "*" @@ -539,7 +538,7 @@ def test_vulnerable( mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT mock_model_deploy.return_value = default_predictor @@ -624,7 +623,7 @@ def evaluate_model_workflow_with_kwargs( mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT mock_session.return_value = sagemaker_session @@ -727,7 +726,7 @@ def test_validate_model_id_and_get_type( mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT JumpStartModel(model_id="valid_model_id") mock_validate_model_id_and_get_type.return_value = False @@ -754,7 +753,7 @@ def test_no_predictor_returns_default_predictor( mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-model-class-model-prepacked", "*" @@ -776,7 +775,7 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @@ -801,7 +800,7 @@ def test_no_predictor_yes_async_inference_config( mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-model-class-model-prepacked", "*" @@ -837,7 +836,7 @@ def test_yes_predictor_returns_default_predictor( mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-model-class-model-prepacked", "*" @@ -910,7 +909,7 @@ def test_model_id_not_found_refeshes_cache_inference( mock_validate_model_id_and_get_type.side_effect = [ False, - JumpStartModelType.OPEN_SOURCE, + JumpStartModelType.OPEN_WEIGHT, ] JumpStartModel( model_id=model_id, @@ -945,7 +944,7 @@ def test_jumpstart_model_tags( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "env-var-variant-model", "*" @@ -981,7 +980,7 @@ def test_jumpstart_model_tags_disabled( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "env-var-variant-model", "*" @@ -1013,7 +1012,7 @@ def test_jumpstart_model_package_arn( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-model-package-arn", "*" @@ -1047,7 +1046,7 @@ def test_jumpstart_model_package_arn_override( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT # arbitrary model without model packarn arn model_id, _ = "js-trainable-model", "*" @@ -1091,7 +1090,7 @@ def test_jumpstart_model_package_arn_unsupported_region( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-model-package-arn", "*" @@ -1128,7 +1127,7 @@ def test_model_data_s3_prefix_override( mock_sagemaker_timestamp.return_value = "7777" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1190,7 +1189,7 @@ def test_model_data_s3_prefix_model( ): mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "model_data_s3_prefix_model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1234,7 +1233,7 @@ def test_model_artifact_variant_model( ): mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "model-artifact-variant-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1297,7 +1296,7 @@ def test_model_registry_accept_and_response_types( ): mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "model_data_s3_prefix_model", "*" mock_get_model_specs.side_effect = get_special_model_spec diff --git a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py index 02d170e54b..e32d063655 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py @@ -79,7 +79,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" mock_retrieve_kwargs.return_value = {} @@ -120,7 +120,7 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" @@ -166,7 +166,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" @@ -212,7 +212,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" @@ -260,7 +260,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" @@ -305,7 +305,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" @@ -353,7 +353,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" @@ -394,7 +394,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, _ = "js-trainable-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index 173fa923a4..a72544464c 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -337,7 +337,7 @@ def test_retrieve_model_package_arn( self, patched_get_model_specs: Mock, patched_validate_model_id_and_get_type: Mock ): patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id = "variant-model" region = "us-west-2" @@ -447,7 +447,7 @@ def test_retrieve_uri_from_gated_bucket( self, patched_get_model_specs, patched_validate_model_id_and_get_type ): patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id = "private-model" region = "us-west-2" diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index e3adef0c26..416424806e 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -177,11 +177,11 @@ def test_jumpstart_cache_get_header(): with pytest.raises(KeyError) as e: cache.get_header( model_id="ai21-summarization", - semantic_version_str="*", + semantic_version_str="3.*", model_type=JumpStartModelType.PROPRIETARY, ) assert ( - "Marketplace model 'ai21-summarization' does not support wildcard version identifier '*'. " + "Marketplace model 'ai21-summarization' does not support wildcard version identifier '3.*'. " "You can pin to version '1.1.003'. " "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " "for list of supported model IDs. " in str(e.value) @@ -275,21 +275,21 @@ def test_jumpstart_cache_get_header(): with pytest.raises(KeyError): cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4", - semantic_version_str="*", + semantic_version_str="1.1.004", model_type=JumpStartModelType.PROPRIETARY, ) with pytest.raises(KeyError): cache.get_header( - model_id="tensorflow-ic-imagenet-inception-v3-classification-4", - semantic_version_str="1.1.004", + model_id="ai21-summarization", + semantic_version_str="2.*", model_type=JumpStartModelType.PROPRIETARY, ) with pytest.raises(KeyError): cache.get_header( model_id="ai21-summarization", - semantic_version_str="2.*", + semantic_version_str="v*", model_type=JumpStartModelType.PROPRIETARY, ) @@ -346,7 +346,7 @@ def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client): cache.clear.assert_called_once() cache.clear.reset_mock() - cache.set_manifest_file_s3_key("some_key1", file_type=JumpStartS3FileType.OPEN_SOURCE_MANIFEST) + cache.set_manifest_file_s3_key("some_key1", file_type=JumpStartS3FileType.OPEN_WEIGHT_MANIFEST) cache.clear.assert_called_once() with pytest.raises(ValueError): cache.set_manifest_file_s3_key("some_key1", file_type="unknown_type") @@ -498,11 +498,11 @@ def test_jumpstart_cache_accepts_input_parameters(): assert cache._s3_cache._max_cache_items == max_s3_cache_items assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon assert ( - cache._open_source_model_id_manifest_key_cache._max_cache_items + cache._open_weight_model_id_manifest_key_cache._max_cache_items == max_semantic_version_cache_items ) assert ( - cache._open_source_model_id_manifest_key_cache._expiration_horizon + cache._open_weight_model_id_manifest_key_cache._expiration_horizon == semantic_version_cache_expiration_horizon ) @@ -822,8 +822,8 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") cache.clear = MagicMock() - cache._open_source_model_id_manifest_key_cache = MagicMock() - cache._open_source_model_id_manifest_key_cache.get.side_effect = [ + cache._open_weight_model_id_manifest_key_cache = MagicMock() + cache._open_weight_model_id_manifest_key_cache.get.side_effect = [ ( JumpStartVersionedModelId( "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" @@ -852,7 +852,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache.clear.assert_called_once() cache.clear.reset_mock() - cache._open_source_model_id_manifest_key_cache.get.side_effect = [ + cache._open_weight_model_id_manifest_key_cache.get.side_effect = [ ( JumpStartVersionedModelId( "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" @@ -937,11 +937,11 @@ def test_jumpstart_cache_get_specs(): with pytest.raises(KeyError) as e: cache.get_specs( model_id="ai21-summarization", - version_str="*", + version_str="3.*", model_type=JumpStartModelType.PROPRIETARY, ) assert ( - "Marketplace model 'ai21-summarization' does not support wildcard version identifier '*'. " + "Marketplace model 'ai21-summarization' does not support wildcard version identifier '3.*'. " "You can pin to version '1.1.003'. " "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " "for list of supported model IDs. " in str(e.value) diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 168b69d704..02b656497b 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -424,7 +424,7 @@ def test_list_jumpstart_models_region( patched_get_manifest.assert_called_with( region="some-region", s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -628,7 +628,7 @@ def test_list_jumpstart_proprietary_models( "lighton-mini-instruct40b", ] - all_open_source_model_ids = [ + all_open_weight_model_ids = [ "catboost-classification-model", "huggingface-spc-bert-base-cased", "lightgbm-classification-model", @@ -639,10 +639,11 @@ def test_list_jumpstart_proprietary_models( "xgboost-classification-model", ] - assert list_jumpstart_models(marketplace_model=True) == all_prop_model_ids + assert list_jumpstart_models("model_type == proprietary") == all_prop_model_ids + assert list_jumpstart_models("model_type == open_weight") == all_open_weight_model_ids assert list_jumpstart_models(list_versions=False) == sorted( - all_prop_model_ids + all_open_source_model_ids + all_prop_model_ids + all_open_weight_model_ids ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -718,7 +719,7 @@ def test_get_model_url( ): patched_get_model_specs.side_effect = get_prototype_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( region ) @@ -747,5 +748,5 @@ def test_get_model_url( version=version, region="us-west-2", s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 336025c448..51148021c1 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -127,7 +127,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=mock_session, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index cadc54c96d..cfbc4d82c9 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -48,7 +48,7 @@ def get_header_from_base_header( model_id: str = None, semantic_version_str: str = None, version: str = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> JumpStartModelHeader: if version and semantic_version_str: @@ -88,7 +88,7 @@ def get_header_from_base_header( def get_prototype_manifest( region: str = JUMPSTART_DEFAULT_REGION_NAME, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> List[JumpStartModelHeader]: if model_type == JumpStartModelType.PROPRIETARY: return [JumpStartModelHeader(spec) for spec in BASE_PROPRIETARY_MANIFEST] @@ -104,7 +104,7 @@ def get_prototype_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. @@ -120,7 +120,7 @@ def get_special_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. This is reserved @@ -136,7 +136,7 @@ def get_special_model_spec_for_inference_component_based_endpoint( model_id: str = None, version: str = None, s3_client: boto3.client = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID and adding @@ -160,7 +160,7 @@ def get_spec_from_base_spec( version_str: str = None, version: str = None, s3_client: boto3.client = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_SOURCE, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, ) -> JumpStartModelSpecs: if version and version_str: @@ -208,13 +208,13 @@ def patched_retrieval_function( ) -> JumpStartCachedS3ContentValue: filetype, s3_key = key.file_type, key.s3_key - if filetype == JumpStartS3FileType.OPEN_SOURCE_MANIFEST: + if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: return JumpStartCachedS3ContentValue( formatted_content=get_formatted_manifest(BASE_MANIFEST) ) - if filetype == JumpStartS3FileType.OPEN_SOURCE_SPECS: + if filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS: _, model_id, specs_version = s3_key.split("/") version = specs_version.replace("specs_v", "").replace(".json", "") return JumpStartCachedS3ContentValue( diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index 19927c9b16..ecfaec7214 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -33,7 +33,7 @@ def test_jumpstart_default_metric_definitions( ): patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -56,7 +56,7 @@ def test_jumpstart_default_metric_definitions( model_id=model_id, version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() @@ -76,7 +76,7 @@ def test_jumpstart_default_metric_definitions( model_id=model_id, version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index f1d5441a71..2187247702 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -36,7 +36,7 @@ def test_jumpstart_common_model_uri( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -52,7 +52,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -70,7 +70,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -89,7 +89,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -108,7 +108,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 37f891e770..c0369a595b 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -34,7 +34,7 @@ def test_jumpstart_resource_requirements( ): patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT region = "us-west-2" mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -55,7 +55,7 @@ def test_jumpstart_resource_requirements( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() @@ -116,7 +116,7 @@ def test_jumpstart_no_supported_resource_requirements( ): patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, model_version = "no-supported-instance-types-model", "*" region = "us-west-2" @@ -137,7 +137,7 @@ def test_jumpstart_no_supported_resource_requirements( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 0dfe677936..7b3ad26e15 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -36,7 +36,7 @@ def test_jumpstart_common_script_uri( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -52,7 +52,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -70,7 +70,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -89,7 +89,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -108,7 +108,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index 8eeb867d71..10d09c973c 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -35,7 +35,7 @@ def test_jumpstart_default_serializers( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -55,7 +55,7 @@ def test_jumpstart_default_serializers( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) patched_get_model_specs.reset_mock() @@ -72,7 +72,7 @@ def test_jumpstart_serializer_options( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_SOURCE + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -99,5 +99,5 @@ def test_jumpstart_serializer_options( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_SOURCE, + model_type=JumpStartModelType.OPEN_WEIGHT, ) From 27e14b9ef4598e4b0853019a9385e6aaf6c3e059 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Fri, 8 Mar 2024 15:05:32 +0000 Subject: [PATCH 21/30] update filter name --- src/sagemaker/jumpstart/constants.py | 3 +++ src/sagemaker/jumpstart/notebook_utils.py | 10 ++++++---- .../sagemaker/jumpstart/model/test_jumpstart_model.py | 1 - 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 7af414327d..7eb21073e6 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -190,6 +190,9 @@ SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY = "SageMakerGatedModelS3Uri" +PROPRIETARY_MODEL_SPEC_PREFIX = "proprietary-models" +PROPRIETARY_MODEL_FILTER_NAME = "marketplace" + CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP: Dict[MIMEType, SerializerType] = { MIMEType.X_IMAGE: SerializerType.RAW_BYTES, MIMEType.LIST_TEXT: SerializerType.JSON, diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 35ee6b52c9..cd75a717c2 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -24,6 +24,8 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, + PROPRIETARY_MODEL_SPEC_PREFIX, + PROPRIETARY_MODEL_FILTER_NAME, ) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.filters import ( @@ -128,7 +130,7 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]: """ _id_parts = model_id.split("-") - if len(_id_parts) < 3: + if len(_id_parts) != 3: return "", "", "" framework = _id_parts[0] @@ -144,10 +146,10 @@ def extract_model_type(spec_key: str) -> str: Args: spek_key (str): The model spec key for which to extract the model type. """ - model_type = spec_key.split("/")[0] + model_spec_prefix = spec_key.split("/")[0] - if model_type == "proprietary-models": - return JumpStartModelType.PROPRIETARY.value + if model_spec_prefix == PROPRIETARY_MODEL_SPEC_PREFIX: + return PROPRIETARY_MODEL_FILTER_NAME return JumpStartModelType.OPEN_WEIGHT.value diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 2024a487ab..e8e319c63d 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -21,7 +21,6 @@ from sagemaker.predictor import retrieve_default import tests.integ -from sagemaker.jumpstart import notebook_utils from sagemaker.jumpstart.model import JumpStartModel from tests.integ.sagemaker.jumpstart.constants import ( From ec816c3c647d4d1ffe7dde40933abac304617b7e Mon Sep 17 00:00:00 2001 From: Haotian An Date: Fri, 8 Mar 2024 15:23:04 +0000 Subject: [PATCH 22/30] doc update --- doc/doc_utils/jumpstart_doc_utils.py | 4 ++-- tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index af078888b7..03ed244440 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -227,9 +227,9 @@ def create_marketplace_model_table(): if model["model_id"] not in sdk_manifest_top_versions_for_models: sdk_manifest_top_versions_for_models[model["model_id"]] = model else: - if Version( + if str( sdk_manifest_top_versions_for_models[model["model_id"]]["version"] - ) < Version(model["version"]): + ) < str(model["version"]): sdk_manifest_top_versions_for_models[model["model_id"]] = model marketplace_content_entries = [] diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index e8e319c63d..3e60051529 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -17,7 +17,6 @@ import pytest from sagemaker.enums import EndpointType -from sagemaker.jumpstart.filters import And from sagemaker.predictor import retrieve_default import tests.integ From 07fa93ef936f9633e5490e963af74a549562ccba Mon Sep 17 00:00:00 2001 From: Haotian An Date: Fri, 8 Mar 2024 15:45:59 +0000 Subject: [PATCH 23/30] fix: black --- doc/doc_utils/jumpstart_doc_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 03ed244440..1ef4252c7b 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -227,9 +227,9 @@ def create_marketplace_model_table(): if model["model_id"] not in sdk_manifest_top_versions_for_models: sdk_manifest_top_versions_for_models[model["model_id"]] = model else: - if str( - sdk_manifest_top_versions_for_models[model["model_id"]]["version"] - ) < str(model["version"]): + if str(sdk_manifest_top_versions_for_models[model["model_id"]]["version"]) < str( + model["version"] + ): sdk_manifest_top_versions_for_models[model["model_id"]] = model marketplace_content_entries = [] From bceb17b1685671705ff653ca2e8f577d56944321 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Fri, 8 Mar 2024 21:11:03 +0000 Subject: [PATCH 24/30] rename to proprietary model and fix unittests --- doc/doc_utils/jumpstart_doc_utils.py | 22 +++++++++---------- src/sagemaker/jumpstart/enums.py | 2 -- src/sagemaker/jumpstart/exceptions.py | 6 ++--- src/sagemaker/jumpstart/factory/model.py | 2 +- src/sagemaker/jumpstart/filters.py | 16 ++++++++++++++ src/sagemaker/jumpstart/notebook_utils.py | 13 ++++++----- src/sagemaker/jumpstart/utils.py | 2 +- src/sagemaker/payloads.py | 4 ++-- .../sagemaker/jumpstart/model/test_model.py | 2 +- tests/unit/sagemaker/jumpstart/test_cache.py | 4 ++-- .../jumpstart/test_notebook_utils.py | 1 + 11 files changed, 46 insertions(+), 28 deletions(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 1ef4252c7b..917cb9c283 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -206,10 +206,10 @@ def get_model_source(url): return "Source" -def create_marketplace_model_table(): +def create_proprietary_model_table(): marketpkace_content_intro = [] marketpkace_content_intro.append("\n") - marketpkace_content_intro.append(".. list-table:: Available Models\n") + marketpkace_content_intro.append(".. list-table:: Available Proprietary Models\n") marketpkace_content_intro.append(" :widths: 50 20 20 20 20\n") marketpkace_content_intro.append(" :header-rows: 1\n") marketpkace_content_intro.append(" :class: datatable\n") @@ -232,17 +232,17 @@ def create_marketplace_model_table(): ): sdk_manifest_top_versions_for_models[model["model_id"]] = model - marketplace_content_entries = [] + proprietary_content_entries = [] for model in sdk_manifest_top_versions_for_models.values(): model_spec = get_jumpstart_sdk_spec(model["spec_key"]) - marketplace_content_entries.append(" * - {}\n".format(model_spec["model_id"])) - marketplace_content_entries.append(" - {}\n".format(False)) # TODO: support training - marketplace_content_entries.append(" - {}\n".format(model["version"])) - marketplace_content_entries.append(" - {}\n".format(model["min_version"])) - marketplace_content_entries.append( + proprietary_content_entries.append(" * - {}\n".format(model_spec["model_id"])) + proprietary_content_entries.append(" - {}\n".format(False)) # TODO: support training + proprietary_content_entries.append(" - {}\n".format(model["version"])) + proprietary_content_entries.append(" - {}\n".format(model["min_version"])) + proprietary_content_entries.append( " - `{} <{}>`__ |external-link|\n".format("Source", model_spec.get("url")) ) - return marketpkace_content_intro + marketplace_content_entries + ["\n"] + return marketpkace_content_intro + proprietary_content_entries + ["\n"] def create_jumpstart_model_table(): @@ -348,10 +348,10 @@ def create_jumpstart_model_table(): f.writelines(file_content_single_entry) f.close() - marketplace_content_entries = create_marketplace_model_table() + proprietary_content_entries = create_proprietary_model_table() f = open("doc_utils/pretrainedmodels.rst", "a") f.writelines(file_content_intro) f.writelines(open_weight_content_entries) - f.writelines(marketplace_content_entries) + f.writelines(proprietary_content_entries) f.close() diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index 8f73230142..b28f202fe7 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -91,8 +91,6 @@ class JumpStartTag(str, Enum): MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" - MARKETPLACE_MODEL_TYPE_VALUE = "SageMakerJumpStartMarketplace" - class SerializerType(str, Enum): """Enum class for serializers associated with JumpStart models.""" diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index a8b8e3525e..d98e98b9ab 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -62,10 +62,10 @@ def get_proprietary_model_subscription_msg( model_id: str, subscription_link: str, ) -> str: - """Returns customer-facing message for using a Marketplace model.""" + """Returns customer-facing message for using a proprietary model.""" return ( - f"INFO: Using Marketplace model '{model_id}'. " + f"INFO: Using proprietary model '{model_id}'. " f"Please make sure to subscribe to the model from {subscription_link}" ) @@ -75,7 +75,7 @@ def get_wildcard_proprietary_model_version_msg( ) -> str: """Returns customer-facing message for passing wildcard version to proprietary models.""" msg = ( - f"Marketplace model '{model_id}' does not support " + f"Proprietary model '{model_id}' does not support " f"wildcard version identifier '{wildcard_model_version}'. " ) if len(available_versions) > 0: diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index c60db4e3fc..d26d2bded0 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -169,7 +169,7 @@ def _log_model_type(kwargs: JumpStartModelInitKwargs) -> None: """Log the model type being used""" if kwargs.model_type == JumpStartModelType.PROPRIETARY: JUMPSTART_LOGGER.info( - "Marketplace model %s of version %s is being used.", + "Proprietary model %s of version %s is being used.", kwargs.model_id, kwargs.model_version, ) diff --git a/src/sagemaker/jumpstart/filters.py b/src/sagemaker/jumpstart/filters.py index 2829593ced..220a1bc9a2 100644 --- a/src/sagemaker/jumpstart/filters.py +++ b/src/sagemaker/jumpstart/filters.py @@ -430,6 +430,22 @@ def __init__(self, key: str, value: str, operator: str): self.value = value self.operator = operator + def set_key(self, key: str) -> None: + """Sets the key for the model filter. + + Args: + key (str): The key to be set. + """ + self.key = key + + def set_value(self, value: str) -> None: + """Sets the value for the model filter. + + Args: + value (str): The value to be set. + """ + self.value = value + def parse_filter_string(filter_string: str) -> ModelFilter: """Parse filter string and return a serialized ``ModelFilter`` object. diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index cd75a717c2..62806de6e6 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -25,7 +25,6 @@ DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, PROPRIETARY_MODEL_SPEC_PREFIX, - PROPRIETARY_MODEL_FILTER_NAME, ) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.filters import ( @@ -130,7 +129,7 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]: """ _id_parts = model_id.split("-") - if len(_id_parts) != 3: + if len(_id_parts) < 3: return "", "", "" framework = _id_parts[0] @@ -149,7 +148,7 @@ def extract_model_type(spec_key: str) -> str: model_spec_prefix = spec_key.split("/")[0] if model_spec_prefix == PROPRIETARY_MODEL_SPEC_PREFIX: - return PROPRIETARY_MODEL_FILTER_NAME + return JumpStartModelType.PROPRIETARY.value return JumpStartModelType.OPEN_WEIGHT.value @@ -228,6 +227,7 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + if (isinstance(filter, Constant) and filter.resolved_value == BooleanValues.TRUE) or ( isinstance(filter, str) and filter.lower() == BooleanValues.TRUE.lower() ): @@ -279,8 +279,6 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin versions should be included in the returned result. (Default: False). list_versions (bool): Optional. True if versions for models should be returned in addition to the id of the model. (Default: False). - marketplace_models (bool): Optional. True if only listing JumpStart Marketplace models. - (Default: False). sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ @@ -361,6 +359,11 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin model_filter = operator.unresolved_value key = model_filter.key all_keys.add(key) + if model_filter.key == SpecialSupportedFilterKeys.MODEL_TYPE and model_filter.value in [ + "marketplace", + "proprietary", + ]: + model_filter.set_value(JumpStartModelType.PROPRIETARY.value) model_filters.add(model_filter) for key in all_keys: diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 17547b209a..3b1b101c29 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -369,7 +369,7 @@ def add_jumpstart_model_id_version_tags( ) if model_type == enums.JumpStartModelType.PROPRIETARY: tags = add_single_jumpstart_tag( - enums.JumpStartTag.MARKETPLACE_MODEL_TYPE_VALUE, + enums.JumpStartModelType.PROPRIETARY.value, enums.JumpStartTag.MODEL_TYPE, tags, is_uri=False, diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index 4d35a62632..d21ca32480 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -137,8 +137,8 @@ def retrieve_example( the model payload. model_version (str): The version of the JumpStart model for which to retrieve the model payload. - model_type (str): The model type of the JumpStart model, either is open source - or marketplace (proprietary). + model_type (str): The model type of the JumpStart model, either is open weight + or proprietary. serialize (bool): Whether to serialize byte-stream valued payloads by downloading binary files from s3 and applying encoding, or to keep payload in pre-serialized state. Set this option to False if you want to avoid s3 downloads or if you diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index b6fe735dff..3d7c7cac7d 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -492,7 +492,7 @@ def test_proprietary_model_endpoint( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "ai21-summarization"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "2.0.004"}, - {"Key": JumpStartTag.MODEL_TYPE, "Value": "SageMakerJumpStartMarketplace"}, + {"Key": JumpStartTag.MODEL_TYPE, "Value": "proprietary"}, ], endpoint_logging=False, model_data_download_timeout=3600, diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 416424806e..ed660ed1f1 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -181,7 +181,7 @@ def test_jumpstart_cache_get_header(): model_type=JumpStartModelType.PROPRIETARY, ) assert ( - "Marketplace model 'ai21-summarization' does not support wildcard version identifier '3.*'. " + "Proprietary model 'ai21-summarization' does not support wildcard version identifier '3.*'. " "You can pin to version '1.1.003'. " "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " "for list of supported model IDs. " in str(e.value) @@ -941,7 +941,7 @@ def test_jumpstart_cache_get_specs(): model_type=JumpStartModelType.PROPRIETARY, ) assert ( - "Marketplace model 'ai21-summarization' does not support wildcard version identifier '3.*'. " + "Proprietary model 'ai21-summarization' does not support wildcard version identifier '3.*'. " "You can pin to version '1.1.003'. " "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " "for list of supported model IDs. " in str(e.value) diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 02b656497b..56ef8a63aa 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -640,6 +640,7 @@ def test_list_jumpstart_proprietary_models( ] assert list_jumpstart_models("model_type == proprietary") == all_prop_model_ids + assert list_jumpstart_models("model_type == marketplace") == all_prop_model_ids assert list_jumpstart_models("model_type == open_weight") == all_open_weight_model_ids assert list_jumpstart_models(list_versions=False) == sorted( From 6fb885ef2fc5139db00cc842fd6c05ef3eff48e8 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Mon, 11 Mar 2024 15:11:11 +0000 Subject: [PATCH 25/30] address comments --- doc/doc_utils/jumpstart_doc_utils.py | 50 ++++++++--------- src/sagemaker/accept_types.py | 2 +- src/sagemaker/content_types.py | 2 +- src/sagemaker/deserializers.py | 2 +- src/sagemaker/instance_types.py | 2 +- src/sagemaker/jumpstart/accessors.py | 6 +-- .../jumpstart/artifacts/instance_types.py | 2 +- src/sagemaker/jumpstart/artifacts/kwargs.py | 7 +-- .../jumpstart/artifacts/model_packages.py | 2 +- src/sagemaker/jumpstart/artifacts/payloads.py | 2 +- .../jumpstart/artifacts/predictors.py | 14 ++--- .../jumpstart/artifacts/resource_names.py | 2 +- .../artifacts/resource_requirements.py | 2 +- src/sagemaker/jumpstart/cache.py | 54 +++++++++++-------- src/sagemaker/jumpstart/constants.py | 4 +- src/sagemaker/jumpstart/enums.py | 2 +- src/sagemaker/jumpstart/exceptions.py | 31 +++++++---- src/sagemaker/jumpstart/factory/estimator.py | 2 +- src/sagemaker/jumpstart/factory/model.py | 18 ++----- src/sagemaker/jumpstart/model.py | 35 ++++++------ src/sagemaker/jumpstart/notebook_utils.py | 12 ++--- src/sagemaker/jumpstart/types.py | 8 +-- src/sagemaker/jumpstart/utils.py | 14 ++--- src/sagemaker/payloads.py | 4 +- src/sagemaker/predictor.py | 2 +- src/sagemaker/resource_requirements.py | 2 +- src/sagemaker/serializers.py | 2 +- .../jumpstart/model/test_jumpstart_model.py | 2 +- .../jumpstart/test_accept_types.py | 8 +-- .../jumpstart/test_content_types.py | 8 +-- .../jumpstart/test_deserializers.py | 8 +-- .../jumpstart/test_default.py | 12 ++--- .../hyperparameters/jumpstart/test_default.py | 8 +-- .../jumpstart/test_validate.py | 12 ++--- .../image_uris/jumpstart/test_common.py | 10 ++-- .../jumpstart/test_instance_types.py | 10 ++-- .../jumpstart/estimator/test_estimator.py | 46 ++++++++-------- .../estimator/test_sagemaker_config.py | 16 +++--- .../sagemaker/jumpstart/model/test_model.py | 50 ++++++++--------- .../jumpstart/model/test_sagemaker_config.py | 16 +++--- .../sagemaker/jumpstart/test_artifacts.py | 4 +- tests/unit/sagemaker/jumpstart/test_cache.py | 4 +- .../sagemaker/jumpstart/test_exceptions.py | 34 ++++++++++++ .../jumpstart/test_notebook_utils.py | 8 +-- .../sagemaker/jumpstart/test_predictor.py | 2 +- tests/unit/sagemaker/jumpstart/utils.py | 12 ++--- .../jumpstart/test_default.py | 6 +-- .../model_uris/jumpstart/test_common.py | 10 ++-- .../jumpstart/test_resource_requirements.py | 8 +-- .../script_uris/jumpstart/test_common.py | 10 ++-- .../serializers/jumpstart/test_serializers.py | 8 +-- 51 files changed, 323 insertions(+), 274 deletions(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 917cb9c283..fb80579b5a 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -155,25 +155,26 @@ class Frameworks(str, Enum): } -def get_jumpstart_sdk_manifest(): - url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, SDK_MANIFEST_FILE) +def get_public_s3_json_object(url): with request.urlopen(url) as f: models_manifest = f.read().decode("utf-8") return json.loads(models_manifest) +def get_jumpstart_sdk_manifest(): + return get_public_s3_json_object(f"{JUMPSTART_BUCKET_BASE_URL}/{SDK_MANIFEST_FILE}") + + def get_proprietary_sdk_manifest(): - url = "{}/{}".format(PROPRIETARY_DOC_BUCKET, PROPRIETARY_SDK_MANIFEST_FILE) - with request.urlopen(url) as f: - models_manifest = f.read().decode("utf-8") - return json.loads(models_manifest) + return get_public_s3_json_object(f"{PROPRIETARY_DOC_BUCKET}/{PROPRIETARY_SDK_MANIFEST_FILE}") -def get_jumpstart_sdk_spec(key): - url = "{}/{}".format(PROPRIETARY_DOC_BUCKET, key) - with request.urlopen(url) as f: - model_spec = f.read().decode("utf-8") - return json.loads(model_spec) +def get_jumpstart_sdk_spec(s3_key: str): + return get_public_s3_json_object(f"{JUMPSTART_BUCKET_BASE_URL}/{s3_key}") + + +def get_proprietary_sdk_spec(s3_key: str): + return get_public_s3_json_object(f"{PROPRIETARY_DOC_BUCKET}/{s3_key}") def get_model_task(id): @@ -207,18 +208,19 @@ def get_model_source(url): def create_proprietary_model_table(): - marketpkace_content_intro = [] - marketpkace_content_intro.append("\n") - marketpkace_content_intro.append(".. list-table:: Available Proprietary Models\n") - marketpkace_content_intro.append(" :widths: 50 20 20 20 20\n") - marketpkace_content_intro.append(" :header-rows: 1\n") - marketpkace_content_intro.append(" :class: datatable\n") - marketpkace_content_intro.append("\n") - marketpkace_content_intro.append(" * - Model ID\n") - marketpkace_content_intro.append(" - Fine Tunable?\n") - marketpkace_content_intro.append(" - Supported Version\n") - marketpkace_content_intro.append(" - Min SDK Version\n") - marketpkace_content_intro.append(" - Source\n") + marketpkace_content_intro = f""" + .. list-table:: Available Proprietary Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? + - Supported Version + - Min SDK Version + - Source + + """ sdk_manifest = get_proprietary_sdk_manifest() sdk_manifest_top_versions_for_models = {} @@ -234,7 +236,7 @@ def create_proprietary_model_table(): proprietary_content_entries = [] for model in sdk_manifest_top_versions_for_models.values(): - model_spec = get_jumpstart_sdk_spec(model["spec_key"]) + model_spec = get_proprietary_sdk_spec(model["spec_key"]) proprietary_content_entries.append(" * - {}\n".format(model_spec["model_id"])) proprietary_content_entries.append(" - {}\n".format(False)) # TODO: support training proprietary_content_entries.append(" - {}\n".format(model["version"])) diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 43abd5d1a1..78aa655e04 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -76,7 +76,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default accept type for the model matching the given arguments. diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index efdbf6846c..46d0361f67 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -76,7 +76,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default content type for the model matching the given arguments. diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 27ae946450..1a4be43897 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -96,7 +96,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index c24fd57bc7..48aaab0ac8 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -35,7 +35,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default instance type for the model matching the given arguments. diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index d49bcb9bb8..35df030ddc 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -200,7 +200,7 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None: def _get_manifest( region: str = JUMPSTART_DEFAULT_REGION_NAME, s3_client: Optional[boto3.client] = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest. @@ -229,7 +229,7 @@ def get_model_header( region: str, model_id: str, version: str, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelHeader: """Returns model header from JumpStart models cache. @@ -254,7 +254,7 @@ def get_model_specs( model_id: str, version: str, s3_client: Optional[boto3.client] = None, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index dfe21f21a9..608303c5e6 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -39,7 +39,7 @@ def _retrieve_default_instance_type( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default instance type for the model. diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index f057864cdd..c15f686805 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -36,7 +36,7 @@ def _retrieve_model_init_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> dict: """Retrieves kwargs for `Model`. @@ -81,9 +81,6 @@ def _retrieve_model_init_kwargs( if model_specs.inference_enable_network_isolation is not None: kwargs.update({"enable_network_isolation": model_specs.inference_enable_network_isolation}) - if model_type == JumpStartModelType.PROPRIETARY: - kwargs.update({"enable_network_isolation": True}) - return kwargs @@ -95,7 +92,7 @@ def _retrieve_model_deploy_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> dict: """Retrieves kwargs for `Model.deploy`. diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index c87088f4fb..5c8a2488c6 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -36,7 +36,7 @@ def _retrieve_model_package_arn( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[str]: """Retrieves associated model pacakge arn for the model. diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 6db511f4db..0424145119 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -36,7 +36,7 @@ def _retrieve_example_payloads( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[Dict[str, JumpStartSerializablePayload]]: """Returns example payloads. diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index c16ba4eaac..e9e0e8dfde 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -77,7 +77,7 @@ def _retrieve_default_deserializer( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> BaseDeserializer: """Retrieves the default deserializer for the model. @@ -123,7 +123,7 @@ def _retrieve_default_serializer( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> BaseSerializer: """Retrieves the default serializer for the model. @@ -168,7 +168,7 @@ def _retrieve_deserializer_options( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model. @@ -283,7 +283,7 @@ def _retrieve_default_content_type( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default content type for the model. @@ -334,7 +334,7 @@ def _retrieve_default_accept_type( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default accept type for the model. @@ -385,7 +385,7 @@ def _retrieve_supported_accept_types( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[str]: """Retrieves the supported accept types for the model. @@ -436,7 +436,7 @@ def _retrieve_supported_content_types( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[str]: """Retrieves the supported content types for the model. diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index eae4c1a300..60af520a6e 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -33,7 +33,7 @@ def _retrieve_resource_name_base( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns default resource name. diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 5464c30937..9f01a7af77 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -51,7 +51,7 @@ def _retrieve_default_resources( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, ) -> ResourceRequirements: diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index a4514a97a9..f44d44ce78 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -125,7 +125,7 @@ def __init__( self._manifest_file_s3_key = manifest_file_s3_key self._proprietary_manifest_s3_key = proprietary_manifest_s3_key self._manifest_file_s3_map = { - JumpStartModelType.OPEN_WEIGHT: self._manifest_file_s3_key, + JumpStartModelType.OPEN_WEIGHTS: self._manifest_file_s3_key, JumpStartModelType.PROPRIETARY: self._proprietary_manifest_s3_key, } self.s3_bucket_name = ( @@ -166,9 +166,7 @@ def set_manifest_file_s3_key( property_name = file_mapping.get(file_type) if not property_name: raise ValueError( - f"Bad value when setting manifest '{file_type}': must be in" - f" {JumpStartS3FileType.OPEN_WEIGHT_MANIFEST}" - f" {JumpStartS3FileType.PROPRIETARY_MANIFEST}" + self._file_type_error_msg(file_type, manifest_only=True) ) if key != property_name: setattr(self, property_name, key) @@ -183,9 +181,7 @@ def get_manifest_file_s3_key( if file_type == JumpStartS3FileType.PROPRIETARY_MANIFEST: return self._proprietary_manifest_s3_key raise ValueError( - f"Bad value when getting manifest '{file_type}':" - f"must be in {JumpStartS3FileType.OPEN_WEIGHT_MANIFEST}" - f"{JumpStartS3FileType.PROPRIETARY_MANIFEST}" + self._file_type_error_msg(file_type, manifest_only=True) ) def set_s3_bucket_name(self, s3_bucket_name: str) -> None: @@ -198,11 +194,24 @@ def get_bucket(self) -> str: """Return bucket used for cache.""" return self.s3_bucket_name + def _file_type_error_msg(self, file_type: str, manifest_only: bool = False) -> str: + """Return error message for bad model type.""" + if manifest_only: + return ( + f"Bad value when getting manifest '{file_type}': " + f"must be in {JumpStartS3FileType.OPEN_WEIGHT_MANIFEST} " + f"{JumpStartS3FileType.PROPRIETARY_MANIFEST}." + ) + return ( + f"Bad value when getting manifest '{file_type}': " + f"must be in '{' '.join([e.name for e in JumpStartS3FileType])}'." + ) + def _model_id_retrieval_function( self, key: JumpStartVersionedModelId, value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 - model_type: JumpStartModelType + model_type: JumpStartModelType, ) -> JumpStartVersionedModelId: """Return model ID and version in manifest that matches semantic version/id. @@ -278,7 +287,7 @@ def _model_id_retrieval_function( ) other_model_id_version = None - if model_type == JumpStartModelType.OPEN_WEIGHT: + if model_type == JumpStartModelType.OPEN_WEIGHTS: other_model_id_version = self._select_version( model_id, "*", versions_incompatible_with_sagemaker, model_type ) # all versions here are incompatible with sagemaker @@ -310,9 +319,11 @@ def _get_open_weight_manifest_key_from_model_id( key: JumpStartVersionedModelId, value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 ) -> JumpStartVersionedModelId: - """Retrieve model manifest key for open source model, by filtering supported versions.""" + """For open weights models, retrieve model manifest key for open source model. + + Filters models list by supported versions.""" return self._model_id_retrieval_function( - key, value, model_type=JumpStartModelType.OPEN_WEIGHT + key, value, model_type=JumpStartModelType.OPEN_WEIGHTS ) def _get_proprietary_manifest_key_from_model_id( @@ -320,7 +331,9 @@ def _get_proprietary_manifest_key_from_model_id( key: JumpStartVersionedModelId, value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 ) -> JumpStartVersionedModelId: - """Retrieve model manifest key for proprietary model, by filtering supported versions.""" + """For proprietary models, retrieve model manifest key for proprietary model. + + Filters models list by supported versions.""" return self._model_id_retrieval_function( key, value, model_type=JumpStartModelType.PROPRIETARY ) @@ -423,14 +436,12 @@ def _retrieval_function( utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) return JumpStartCachedS3ContentValue(formatted_content=model_specs) raise ValueError( - f"Bad value for key '{key}': must be in" - f"{JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, JumpStartS3FileType.OPEN_WEIGHT_SPECS}" - f"{JumpStartS3FileType.PROPRIETARY_SPECS, JumpStartS3FileType.PROPRIETARY_MANIFEST}" + self._file_type_error_msg(file_type) ) def get_manifest( self, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest.""" manifest_dict = self._s3_cache.get( @@ -444,7 +455,7 @@ def get_header( self, model_id: str, semantic_version_str: str, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelHeader: """Return header for a given JumpStart model ID and semantic version. @@ -463,7 +474,7 @@ def _select_version( model_id: str, version_str: str, available_versions: List[str], - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[str]: """Perform semantic version search on available versions. @@ -501,7 +512,7 @@ def _get_header_impl( model_id: str, semantic_version_str: str, attempt: int = 0, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS ) -> JumpStartModelHeader: """Lower-level function to return header. @@ -513,7 +524,7 @@ def _get_header_impl( header. attempt (int): attempt number at retrieving a header. """ - if model_type == JumpStartModelType.OPEN_WEIGHT: + if model_type == JumpStartModelType.OPEN_WEIGHTS: versioned_model_id = self._open_weight_model_id_manifest_key_cache.get( JumpStartVersionedModelId(model_id, semantic_version_str) )[0] @@ -540,7 +551,7 @@ def get_specs( self, model_id: str, version_str: str, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS ) -> JumpStartModelSpecs: """Return specs for a given JumpStart model ID and semantic version. @@ -548,6 +559,7 @@ def get_specs( model_id (str): model ID for which to get specs. semantic_version_str (str): The semantic version for which to get specs. + model_type (JumpStartModelType): The type of the model of interest. """ header = self.get_header(model_id, version_str, model_type) spec_key = header.spec_key diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 7eb21073e6..be66e8968e 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -219,12 +219,12 @@ } MODEL_TYPE_TO_MANIFEST_MAP: Dict[Type[JumpStartModelType], Type[JumpStartS3FileType]] = { - JumpStartModelType.OPEN_WEIGHT: JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, + JumpStartModelType.OPEN_WEIGHTS: JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, JumpStartModelType.PROPRIETARY: JumpStartS3FileType.PROPRIETARY_MANIFEST, } MODEL_TYPE_TO_SPECS_MAP: Dict[Type[JumpStartModelType], Type[JumpStartS3FileType]] = { - JumpStartModelType.OPEN_WEIGHT: JumpStartS3FileType.OPEN_WEIGHT_SPECS, + JumpStartModelType.OPEN_WEIGHTS: JumpStartS3FileType.OPEN_WEIGHT_SPECS, JumpStartModelType.PROPRIETARY: JumpStartS3FileType.PROPRIETARY_SPECS, } diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index b28f202fe7..62de9fc3c3 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -41,7 +41,7 @@ class JumpStartModelType(str, Enum): Proprietary model refers to external provider owned Marketplace models. """ - OPEN_WEIGHT = "open_weight" + OPEN_WEIGHTS = "open_weights" PROPRIETARY = "proprietary" diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index d98e98b9ab..4bc454c139 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -14,7 +14,12 @@ from __future__ import absolute_import from typing import List, Optional -from sagemaker.jumpstart.constants import MODEL_ID_LIST_WEB_URL, JumpStartScriptScope +from botocore.exceptions import ClientError + +from sagemaker.jumpstart.constants import ( + MODEL_ID_LIST_WEB_URL, + JumpStartScriptScope, +) NO_AVAILABLE_INSTANCES_ERROR_MSG = ( "No instances available in {region} that can support model ID '{model_id}'. " @@ -28,7 +33,7 @@ INVALID_MODEL_ID_ERROR_MSG = ( "Invalid model ID: '{model_id}'. Please visit " - f"{MODEL_ID_LIST_WEB_URL} for list of supported model IDs. " + f"{MODEL_ID_LIST_WEB_URL} for a list of valid model IDs. " "The module `sagemaker.jumpstart.notebook_utils` contains utilities for " "fetching model IDs. We recommend upgrading to the latest version of sagemaker " "to get access to the most models." @@ -66,7 +71,7 @@ def get_proprietary_model_subscription_msg( return ( f"INFO: Using proprietary model '{model_id}'. " - f"Please make sure to subscribe to the model from {subscription_link}" + f"Please make sure to subscribe to the model on {subscription_link}" ) @@ -80,7 +85,7 @@ def get_wildcard_proprietary_model_version_msg( ) if len(available_versions) > 0: msg += f"You can pin to version '{available_versions[0]}'. " - msg += f"{MODEL_ID_LIST_WEB_URL} for list of supported model IDs. " + msg += f"{MODEL_ID_LIST_WEB_URL} for a list of valid model IDs. " return msg @@ -96,6 +101,15 @@ def get_old_model_version_msg( ) +def get_proprietary_model_subscription_error(error: ClientError, subscription_link: str) -> None: + """Returns customer-facing message associated with a Marketplace subscription error.""" + + error_code = error.response["Error"]["Code"] + error_message = error.response["Error"]["Message"] + if error_code == "ValidationException" and "not subscribed" in error_message: + raise MarketplaceModelSubscriptionError(subscription_link) + + class JumpStartHyperparametersError(ValueError): """Exception raised for bad hyperparameters of a JumpStart model.""" @@ -213,11 +227,8 @@ def __init__( if message: self.message = message else: - if not model_subscription_link: - raise RuntimeError("Must specify `model_subscription_link` in arguments.") - self.message = ( - f"You have not subscribed to this Marketplace model. " - f"Please subscribe following this link {model_subscription_link}" - ) + self.message = "You have not subscribed to this Marketplace model. " + if model_subscription_link: + self.message += f"Please subscribe following this link {model_subscription_link}" super().__init__(self.message) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index ec99210b54..7c20c281f5 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -77,7 +77,7 @@ def get_init_kwargs( model_id: str, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index d26d2bded0..c12a14c4a4 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -71,7 +71,7 @@ def get_default_predictor( tolerate_vulnerable_model: bool, tolerate_deprecated_model: bool, sagemaker_session: Session, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Predictor: """Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one. @@ -165,16 +165,6 @@ def _add_model_version_to_kwargs( return kwargs -def _log_model_type(kwargs: JumpStartModelInitKwargs) -> None: - """Log the model type being used""" - if kwargs.model_type == JumpStartModelType.PROPRIETARY: - JUMPSTART_LOGGER.info( - "Proprietary model %s of version %s is being used.", - kwargs.model_id, - kwargs.model_version, - ) - - def _add_vulnerable_and_deprecated_status_to_kwargs( kwargs: JumpStartModelInitKwargs, ) -> JumpStartModelInitKwargs: @@ -536,7 +526,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -703,7 +693,7 @@ def get_init_kwargs( model_id: str, model_from_estimator: bool = False, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, instance_type: Optional[str] = None, @@ -790,6 +780,4 @@ def get_init_kwargs( model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs) - _log_model_type(kwargs=model_init_kwargs) - return model_init_kwargs diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 0145a26fe4..de1d06a2a8 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -21,12 +21,13 @@ from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer from sagemaker.base_serializers import BaseSerializer +from sagemaker.enums import EndpointType from sagemaker.explainer.explainer_config import ExplainerConfig from sagemaker.jumpstart.accessors import JumpStartModelsAccessor from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import ( INVALID_MODEL_ID_ERROR_MSG, - MarketplaceModelSubscriptionError, + get_proprietary_model_subscription_error, get_proprietary_model_subscription_msg, ) from sagemaker.jumpstart.factory.model import ( @@ -56,7 +57,6 @@ from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements -from sagemaker.enums import EndpointType class JumpStartModel(Model): @@ -347,7 +347,7 @@ def _validate_model_id_and_type(): self.model_package_arn = model_init_kwargs.model_package_arn def log_subscription_warning(self) -> None: - """Logs customer facing message for subscribe to the proprietary model.""" + """Log message prompting the customer to subscribe to the proprietary model.""" subscription_link = verify_model_region_and_return_specs( region=self.region, model_id=self.model_id, @@ -626,21 +626,26 @@ def deploy( endpoint_type=endpoint_type, model_type=self.model_type, ) + if ( + self.model_type == JumpStartModelType.PROPRIETARY + and endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED + ): + raise ValueError( + "EndpointType.INFERENCE_COMPONENT_BASED is not supported for Proprietary models." + ) + try: predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict()) except ClientError as e: - error_code = e.response["Error"]["Code"] - error_message = e.response["Error"]["Message"] - if error_code == "ValidationException" and "not subscribed" in error_message: - subscription_link = verify_model_region_and_return_specs( - region=self.region, - model_id=self.model_id, - version=self.model_version, - model_type=self.model_type, - scope=JumpStartScriptScope.INFERENCE, - sagemaker_session=self.sagemaker_session, - ).model_subscription_link - raise MarketplaceModelSubscriptionError(subscription_link) + subscription_link = verify_model_region_and_return_specs( + region=self.region, + model_id=self.model_id, + version=self.model_version, + model_type=self.model_type, + scope=JumpStartScriptScope.INFERENCE, + sagemaker_session=self.sagemaker_session, + ).model_subscription_link + get_proprietary_model_subscription_error(e, subscription_link) raise # If no predictor class was passed, add defaults to predictor diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 62806de6e6..41c03c0f62 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -139,7 +139,7 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]: return framework, task, name -def extract_model_type(spec_key: str) -> str: +def extract_model_type_filter_representation(spec_key: str) -> str: """Parses model spec key, determine if the model is proprietary or open weight. Args: @@ -150,7 +150,7 @@ def extract_model_type(spec_key: str) -> str: if model_spec_prefix == PROPRIETARY_MODEL_SPEC_PREFIX: return JumpStartModelType.PROPRIETARY.value - return JumpStartModelType.OPEN_WEIGHT.value + return JumpStartModelType.OPEN_WEIGHTS.value def list_jumpstart_tasks( # pylint: disable=redefined-builtin @@ -342,7 +342,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin open_weight_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( region=region, s3_client=sagemaker_session.s3_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) models_manifest_list = open_weight_manifest_list + prop_models_manifest_list @@ -359,10 +359,10 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin model_filter = operator.unresolved_value key = model_filter.key all_keys.add(key) - if model_filter.key == SpecialSupportedFilterKeys.MODEL_TYPE and model_filter.value in [ + if model_filter.key == SpecialSupportedFilterKeys.MODEL_TYPE and model_filter.value in { "marketplace", "proprietary", - ]: + }: model_filter.set_value(JumpStartModelType.PROPRIETARY.value) model_filters.add(model_filter) @@ -403,7 +403,7 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, if is_model_type_filter: manifest_specs_cached_values[ SpecialSupportedFilterKeys.MODEL_TYPE - ] = extract_model_type(model_manifest.spec_key) + ] = extract_model_type_filter_representation(model_manifest.spec_key) if Version(model_manifest.min_version) > Version(get_sagemaker_version()): return None diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 045187551d..9af44779f3 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1099,7 +1099,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, instance_type: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, @@ -1206,7 +1206,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -1340,7 +1340,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, role: Optional[str] = None, @@ -1485,7 +1485,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, inputs: Optional[Union[str, Dict, Any, Any]] = None, wait: Optional[bool] = None, diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 3b1b101c29..71a8067a6f 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -540,7 +540,7 @@ def verify_model_region_and_return_specs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHT, + model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided. @@ -759,13 +759,13 @@ def validate_model_id_and_get_type( def _get_model_type( model_id: str, - open_weight_models: Set[str], - proprietary_models: Set[str], + open_weights_model_ids: Set[str], + proprietary_model_ids: Set[str], script: enums.JumpStartScriptScope, ) -> Optional[enums.JumpStartModelType]: - if model_id in open_weight_models: - return enums.JumpStartModelType.OPEN_WEIGHT - if model_id in proprietary_models: + if model_id in open_weights_model_ids: + return enums.JumpStartModelType.OPEN_WEIGHTS + if model_id in proprietary_model_ids: if script == enums.JumpStartScriptScope.INFERENCE: return enums.JumpStartModelType.PROPRIETARY raise ValueError(f"Unsupported script for Marketplace models: {script}") @@ -780,7 +780,7 @@ def _get_model_type( region = region or constants.JUMPSTART_DEFAULT_REGION_NAME model_version = model_version or "*" models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( - region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.OPEN_WEIGHT + region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.OPEN_WEIGHTS ) open_weight_model_id_set = {model.model_id for model in models_manifest_list} diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index d21ca32480..de33f61b82 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -32,7 +32,7 @@ def retrieve_all_examples( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -123,7 +123,7 @@ def retrieve_example( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, - model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHT, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 4d564c2afc..6f846bba65 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -42,7 +42,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Predictor: """Retrieves the default predictor for the model matching the given arguments. diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index 342770a410..df14ac558f 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -34,7 +34,7 @@ def retrieve_default( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, ) -> ResourceRequirements: diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index bcd34f722e..aefb52bd97 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -94,7 +94,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> BaseSerializer: """Retrieves the default serializer for the model matching the given arguments. diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 3e60051529..5205765e2f 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -273,7 +273,7 @@ def test_proprietary_jumpstart_model(setup): model = JumpStartModel( model_id=model_id, - model_version="*", + model_version="2.0.004", role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), ) diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index 1c12f777a0..49c18beec2 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -37,7 +37,7 @@ def test_jumpstart_default_accept_types( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -55,7 +55,7 @@ def test_jumpstart_default_accept_types( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) @@ -70,7 +70,7 @@ def test_jumpstart_supported_accept_types( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -91,5 +91,5 @@ def test_jumpstart_supported_accept_types( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index ae698a7d94..7765d6eaad 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -36,7 +36,7 @@ def test_jumpstart_default_content_types( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -54,7 +54,7 @@ def test_jumpstart_default_content_types( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) @@ -69,7 +69,7 @@ def test_jumpstart_supported_content_types( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -89,5 +89,5 @@ def test_jumpstart_supported_content_types( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index 4b06ac8c4e..5328533da5 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -38,7 +38,7 @@ def test_jumpstart_default_deserializers( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -56,7 +56,7 @@ def test_jumpstart_default_deserializers( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) @@ -71,7 +71,7 @@ def test_jumpstart_deserializer_options( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -96,5 +96,5 @@ def test_jumpstart_deserializer_options( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index 2ee9ecec38..38cc5ebbf3 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -34,7 +34,7 @@ def test_jumpstart_default_environment_variables( ): patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -58,7 +58,7 @@ def test_jumpstart_default_environment_variables( model_id=model_id, version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -82,7 +82,7 @@ def test_jumpstart_default_environment_variables( model_id=model_id, version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -119,7 +119,7 @@ def test_jumpstart_sdk_environment_variables( ): patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -144,7 +144,7 @@ def test_jumpstart_sdk_environment_variables( model_id=model_id, version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -169,7 +169,7 @@ def test_jumpstart_sdk_environment_variables( model_id=model_id, version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index babae7f86c..a13fba87ae 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -34,7 +34,7 @@ def test_jumpstart_default_hyperparameters( ): patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -52,7 +52,7 @@ def test_jumpstart_default_hyperparameters( model_id=model_id, version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -70,7 +70,7 @@ def test_jumpstart_default_hyperparameters( model_id=model_id, version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -96,7 +96,7 @@ def test_jumpstart_default_hyperparameters( model_id=model_id, version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index f3ef886621..7a5df4ac93 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -112,7 +112,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): return spec patched_get_model_specs.side_effect = add_options_to_hyperparameter - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -144,7 +144,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -424,7 +424,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): return spec patched_get_model_specs.side_effect = add_options_to_hyperparameter - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -450,7 +450,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -484,7 +484,7 @@ def test_jumpstart_validate_all_hyperparameters( ): patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -512,7 +512,7 @@ def test_jumpstart_validate_all_hyperparameters( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 1171161be0..6c80c97f33 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -35,7 +35,7 @@ def test_jumpstart_common_image_uri( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -54,7 +54,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -75,7 +75,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -96,7 +96,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -117,7 +117,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index 298a1e9eb2..982c7f1702 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -28,7 +28,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_model_id_and_get_type): patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "huggingface-eqa-bert-base-cased", "*" region = "us-west-2" @@ -50,7 +50,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -69,7 +69,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -94,7 +94,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -121,7 +121,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 1ca13cfed9..4fa18f31aa 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -84,7 +84,7 @@ def test_non_prepacked( mock_jumpstart_model_factory_logger: mock.Mock, mock_jumpstart_estimator_factory_logger: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "9876" @@ -94,7 +94,7 @@ def test_non_prepacked( mock_get_model_specs.side_effect = get_special_model_spec - mock_get_model_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_get_model_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session @@ -207,7 +207,7 @@ def test_prepacked( ): mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -310,7 +310,7 @@ def test_gated_model_s3_uri( mock_timestamp.return_value = "8675309" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-gated-artifact-trainable-model", "*" @@ -448,7 +448,7 @@ def test_gated_model_non_model_package_s3_uri( mock_get_jumpstart_gated_content_bucket.return_value = "top-secret-private-models-bucket" mock_timestamp.return_value = "8675309" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-gated-artifact-non-model-package-trainable-model", "*" @@ -601,7 +601,7 @@ def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS with pytest.raises(ValueError) as e: JumpStartEstimator(model_id=model_id, region="eu-north-1") @@ -628,7 +628,7 @@ def test_deprecated( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "deprecated_model", "*" @@ -661,7 +661,7 @@ def test_vulnerable( mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "vulnerable_model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -810,7 +810,7 @@ def evaluate_estimator_workflow_with_kwargs( mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "js-trainable-model", "*" @@ -921,7 +921,7 @@ def test_jumpstart_estimator_tags_disabled( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -960,7 +960,7 @@ def test_jumpstart_estimator_tags( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -1004,7 +1004,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( mock_attach: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS get_model_id_version_from_training_job.return_value = ( "js-trainable-model-prepacked", @@ -1047,7 +1047,7 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case( mock_attach: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS get_model_id_version_from_training_job.side_effect = ValueError() @@ -1126,7 +1126,7 @@ def test_validate_model_id_and_get_type( mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS JumpStartEstimator(model_id="valid_model_id") mock_validate_model_id_and_get_type.return_value = False @@ -1158,7 +1158,7 @@ def test_no_predictor_returns_default_predictor( mock_get_default_predictor.return_value = default_predictor_with_presets - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -1217,7 +1217,7 @@ def test_no_predictor_yes_async_inference_config( mock_get_default_predictor.return_value = default_predictor_with_presets - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -1267,7 +1267,7 @@ def test_yes_predictor_returns_unmodified_predictor( mock_get_default_predictor.return_value = default_predictor_with_presets - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -1316,7 +1316,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( mock_supports_incremental_training: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_estimator_deploy.return_value = default_predictor @@ -1370,7 +1370,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( mock_supports_incremental_training: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_estimator_deploy.return_value = default_predictor @@ -1419,7 +1419,7 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "3456" @@ -1480,7 +1480,7 @@ def test_training_passes_role_to_deploy( mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "3456" @@ -1560,7 +1560,7 @@ def test_training_passes_session_to_deploy( mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "3456" @@ -1711,7 +1711,7 @@ def test_model_artifact_variant_estimator( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model-artifact-variant-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py index 90f0472924..073921d5ba 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py @@ -104,7 +104,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -161,7 +161,7 @@ def test_without_arg_overwrites_with_kwarg_collisions_with_config( mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -234,7 +234,7 @@ def test_with_arg_overwrites_with_kwarg_collisions_with_config( mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -316,7 +316,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_with_config( mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -393,7 +393,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -454,7 +454,7 @@ def test_without_arg_overwrites_with_kwarg_collisions_without_config( mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_get_caller_identity_arn.return_value = execution_role model_id, _ = "js-trainable-model", "*" @@ -525,7 +525,7 @@ def test_with_arg_overwrites_with_kwarg_collisions_without_config( ): mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -599,7 +599,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( ): mock_estimator_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 3d7c7cac7d..ba4ba0bb13 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -75,7 +75,7 @@ def test_non_prepacked( mock_sagemaker_timestamp.return_value = "7777" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -149,7 +149,7 @@ def test_non_prepacked_inference_component_based_endpoint( mock_sagemaker_timestamp.return_value = "7777" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = ( @@ -229,7 +229,7 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom mock_sagemaker_timestamp.return_value = "7777" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -300,7 +300,7 @@ def test_prepacked( ): mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -364,7 +364,7 @@ def test_no_compiled_model_warning_log_js_models( mock_timestamp.return_value = "1234" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "gated_llama_neuron_model", "*" @@ -398,7 +398,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( ): mock_timestamp.return_value = "1234" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "gated_variant-model", "*" @@ -480,7 +480,7 @@ def test_proprietary_model_endpoint( predictor_cls=Predictor, role=execution_role, sagemaker_session=sagemaker_session, - enable_network_isolation=True, + enable_network_isolation=False, ) model.deploy() @@ -513,7 +513,7 @@ def test_deprecated( ): mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "deprecated_model", "*" @@ -538,7 +538,7 @@ def test_vulnerable( mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_model_deploy.return_value = default_predictor @@ -623,7 +623,7 @@ def evaluate_model_workflow_with_kwargs( mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_session.return_value = sagemaker_session @@ -726,7 +726,7 @@ def test_validate_model_id_and_get_type( mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS JumpStartModel(model_id="valid_model_id") mock_validate_model_id_and_get_type.return_value = False @@ -753,7 +753,7 @@ def test_no_predictor_returns_default_predictor( mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -775,7 +775,7 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @@ -800,7 +800,7 @@ def test_no_predictor_yes_async_inference_config( mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -836,7 +836,7 @@ def test_yes_predictor_returns_default_predictor( mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -909,7 +909,7 @@ def test_model_id_not_found_refeshes_cache_inference( mock_validate_model_id_and_get_type.side_effect = [ False, - JumpStartModelType.OPEN_WEIGHT, + JumpStartModelType.OPEN_WEIGHTS, ] JumpStartModel( model_id=model_id, @@ -944,7 +944,7 @@ def test_jumpstart_model_tags( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "env-var-variant-model", "*" @@ -980,7 +980,7 @@ def test_jumpstart_model_tags_disabled( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "env-var-variant-model", "*" @@ -1012,7 +1012,7 @@ def test_jumpstart_model_package_arn( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-package-arn", "*" @@ -1046,7 +1046,7 @@ def test_jumpstart_model_package_arn_override( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS # arbitrary model without model packarn arn model_id, _ = "js-trainable-model", "*" @@ -1090,7 +1090,7 @@ def test_jumpstart_model_package_arn_unsupported_region( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-package-arn", "*" @@ -1127,7 +1127,7 @@ def test_model_data_s3_prefix_override( mock_sagemaker_timestamp.return_value = "7777" - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1189,7 +1189,7 @@ def test_model_data_s3_prefix_model( ): mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model_data_s3_prefix_model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1233,7 +1233,7 @@ def test_model_artifact_variant_model( ): mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model-artifact-variant-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1296,7 +1296,7 @@ def test_model_registry_accept_and_response_types( ): mock_model_deploy.return_value = default_predictor - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model_data_s3_prefix_model", "*" mock_get_model_specs.side_effect = get_special_model_spec diff --git a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py index e32d063655..70409704e6 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py @@ -79,7 +79,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_retrieve_kwargs.return_value = {} @@ -120,7 +120,7 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -166,7 +166,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -212,7 +212,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -260,7 +260,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -305,7 +305,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -353,7 +353,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -394,7 +394,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( mock_validate_model_id_and_get_type: mock.Mock, ): - mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index a72544464c..21112926a5 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -337,7 +337,7 @@ def test_retrieve_model_package_arn( self, patched_get_model_specs: Mock, patched_validate_model_id_and_get_type: Mock ): patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "variant-model" region = "us-west-2" @@ -447,7 +447,7 @@ def test_retrieve_uri_from_gated_bucket( self, patched_get_model_specs, patched_validate_model_id_and_get_type ): patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "private-model" region = "us-west-2" diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index ed660ed1f1..50fe6da0a6 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -184,7 +184,7 @@ def test_jumpstart_cache_get_header(): "Proprietary model 'ai21-summarization' does not support wildcard version identifier '3.*'. " "You can pin to version '1.1.003'. " "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " - "for list of supported model IDs. " in str(e.value) + "for a list of valid model IDs. " in str(e.value) ) with pytest.raises(KeyError) as e: @@ -944,7 +944,7 @@ def test_jumpstart_cache_get_specs(): "Proprietary model 'ai21-summarization' does not support wildcard version identifier '3.*'. " "You can pin to version '1.1.003'. " "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " - "for list of supported model IDs. " in str(e.value) + "for a list of valid model IDs. " in str(e.value) ) with pytest.raises(KeyError): diff --git a/tests/unit/sagemaker/jumpstart/test_exceptions.py b/tests/unit/sagemaker/jumpstart/test_exceptions.py index 555099a753..2307d22474 100644 --- a/tests/unit/sagemaker/jumpstart/test_exceptions.py +++ b/tests/unit/sagemaker/jumpstart/test_exceptions.py @@ -11,10 +11,15 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +import pytest + +from botocore.exceptions import ClientError from sagemaker.jumpstart.exceptions import ( get_wildcard_model_version_msg, get_old_model_version_msg, + get_proprietary_model_subscription_error, + MarketplaceModelSubscriptionError, ) @@ -35,3 +40,32 @@ def test_get_old_model_version_msg(): "Note that models may have different input/output signatures after a major " "version upgrade." == get_old_model_version_msg("mother_of_all_models", "1.0.0", "1.2.3") ) + + +def test_get_marketplace_subscription_error(): + error = ClientError( + error_response={ + "Error": { + "Code": "ValidationException", + "Message": "Caller is not subscribed to the Marketplace listing.", + }, + }, + operation_name="mock-operation", + ) + with pytest.raises(MarketplaceModelSubscriptionError): + get_proprietary_model_subscription_error(error, subscription_link="mock-link") + + error = ClientError( + error_response={ + "Error": { + "Code": "UnknownException", + "Message": "Unknown error raised.", + }, + }, + operation_name="mock-operation", + ) + + try: + get_proprietary_model_subscription_error(error, subscription_link="mock-link") + except MarketplaceModelSubscriptionError: + pytest.fail("MarketplaceModelSubscriptionError should not be raised for unknown error.") diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 56ef8a63aa..059cd7ccad 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -424,7 +424,7 @@ def test_list_jumpstart_models_region( patched_get_manifest.assert_called_with( region="some-region", s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -641,7 +641,7 @@ def test_list_jumpstart_proprietary_models( assert list_jumpstart_models("model_type == proprietary") == all_prop_model_ids assert list_jumpstart_models("model_type == marketplace") == all_prop_model_ids - assert list_jumpstart_models("model_type == open_weight") == all_open_weight_model_ids + assert list_jumpstart_models("model_type == open_weights") == all_open_weight_model_ids assert list_jumpstart_models(list_versions=False) == sorted( all_prop_model_ids + all_open_weight_model_ids @@ -720,7 +720,7 @@ def test_get_model_url( ): patched_get_model_specs.side_effect = get_prototype_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( region ) @@ -749,5 +749,5 @@ def test_get_model_url( version=version, region="us-west-2", s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 51148021c1..52f28f2da1 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -127,7 +127,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=mock_session, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index cfbc4d82c9..65fe10f7a7 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -48,7 +48,7 @@ def get_header_from_base_header( model_id: str = None, semantic_version_str: str = None, version: str = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelHeader: if version and semantic_version_str: @@ -88,7 +88,7 @@ def get_header_from_base_header( def get_prototype_manifest( region: str = JUMPSTART_DEFAULT_REGION_NAME, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[JumpStartModelHeader]: if model_type == JumpStartModelType.PROPRIETARY: return [JumpStartModelHeader(spec) for spec in BASE_PROPRIETARY_MANIFEST] @@ -104,7 +104,7 @@ def get_prototype_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. @@ -120,7 +120,7 @@ def get_special_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. This is reserved @@ -136,7 +136,7 @@ def get_special_model_spec_for_inference_component_based_endpoint( model_id: str = None, version: str = None, s3_client: boto3.client = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID and adding @@ -160,7 +160,7 @@ def get_spec_from_base_spec( version_str: str = None, version: str = None, s3_client: boto3.client = None, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: if version and version_str: diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index ecfaec7214..608a32a005 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -33,7 +33,7 @@ def test_jumpstart_default_metric_definitions( ): patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -56,7 +56,7 @@ def test_jumpstart_default_metric_definitions( model_id=model_id, version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -76,7 +76,7 @@ def test_jumpstart_default_metric_definitions( model_id=model_id, version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 2187247702..8d75731b06 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -36,7 +36,7 @@ def test_jumpstart_common_model_uri( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -52,7 +52,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -70,7 +70,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -89,7 +89,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -108,7 +108,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index c0369a595b..7b5e7a598d 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -34,7 +34,7 @@ def test_jumpstart_resource_requirements( ): patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS region = "us-west-2" mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -55,7 +55,7 @@ def test_jumpstart_resource_requirements( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -116,7 +116,7 @@ def test_jumpstart_no_supported_resource_requirements( ): patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "no-supported-instance-types-model", "*" region = "us-west-2" @@ -137,7 +137,7 @@ def test_jumpstart_no_supported_resource_requirements( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 7b3ad26e15..c797ba3559 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -36,7 +36,7 @@ def test_jumpstart_common_script_uri( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -52,7 +52,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -70,7 +70,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -89,7 +89,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -108,7 +108,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index 10d09c973c..c2253726bf 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -35,7 +35,7 @@ def test_jumpstart_default_serializers( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -55,7 +55,7 @@ def test_jumpstart_default_serializers( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -72,7 +72,7 @@ def test_jumpstart_serializer_options( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHT + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -99,5 +99,5 @@ def test_jumpstart_serializer_options( model_id=model_id, version=model_version, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHT, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) From 40542b75e4a7b882385f1f300184896ddee5d7c3 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Mon, 11 Mar 2024 15:48:04 +0000 Subject: [PATCH 26/30] fix: docstyle and flake8 --- doc/doc_utils/jumpstart_doc_utils.py | 4 ++-- src/sagemaker/jumpstart/cache.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index fb80579b5a..6958b36555 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -208,12 +208,12 @@ def get_model_source(url): def create_proprietary_model_table(): - marketpkace_content_intro = f""" + marketpkace_content_intro = """ .. list-table:: Available Proprietary Models :widths: 50 20 20 20 20 :header-rows: 1 :class: datatable - + * - Model ID - Fine Tunable? - Supported Version diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index f44d44ce78..f337edb4c2 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -321,7 +321,8 @@ def _get_open_weight_manifest_key_from_model_id( ) -> JumpStartVersionedModelId: """For open weights models, retrieve model manifest key for open source model. - Filters models list by supported versions.""" + Filters models list by supported versions. + """ return self._model_id_retrieval_function( key, value, model_type=JumpStartModelType.OPEN_WEIGHTS ) @@ -333,7 +334,8 @@ def _get_proprietary_manifest_key_from_model_id( ) -> JumpStartVersionedModelId: """For proprietary models, retrieve model manifest key for proprietary model. - Filters models list by supported versions.""" + Filters models list by supported versions. + """ return self._model_id_retrieval_function( key, value, model_type=JumpStartModelType.PROPRIETARY ) From abfadcaa93c4880260ed6272bfbb0661c5861a93 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Mon, 11 Mar 2024 16:36:59 +0000 Subject: [PATCH 27/30] address more comments and fix doc --- doc/doc_utils/jumpstart_doc_utils.py | 2 +- src/sagemaker/jumpstart/filters.py | 7 +++++++ src/sagemaker/jumpstart/model.py | 2 +- src/sagemaker/jumpstart/notebook_utils.py | 5 ++--- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 6958b36555..51857526b9 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -244,7 +244,7 @@ def create_proprietary_model_table(): proprietary_content_entries.append( " - `{} <{}>`__ |external-link|\n".format("Source", model_spec.get("url")) ) - return marketpkace_content_intro + proprietary_content_entries + ["\n"] + return [marketpkace_content_intro] + proprietary_content_entries + ["\n"] def create_jumpstart_model_table(): diff --git a/src/sagemaker/jumpstart/filters.py b/src/sagemaker/jumpstart/filters.py index 220a1bc9a2..fc5113315d 100644 --- a/src/sagemaker/jumpstart/filters.py +++ b/src/sagemaker/jumpstart/filters.py @@ -52,6 +52,13 @@ class SpecialSupportedFilterKeys(str, Enum): MODEL_TYPE = "model_type" +class ProprietaryModelFilterIdentifiers(str, Enum): + """Enum class for proprietary model filter keys.""" + + PROPRIETARY = "proprietary" + MARKETPLACE = "marketplace" + + FILTER_OPERATOR_STRING_MAPPINGS = { FilterOperators.EQUALS: ["===", "==", "equals", "is"], FilterOperators.NOT_EQUALS: ["!==", "!=", "not equals", "is not"], diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index de1d06a2a8..99c09b2375 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -631,7 +631,7 @@ def deploy( and endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED ): raise ValueError( - "EndpointType.INFERENCE_COMPONENT_BASED is not supported for Proprietary models." + f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models." ) try: diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 41c03c0f62..485354e802 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -29,6 +29,7 @@ from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.filters import ( SPECIAL_SUPPORTED_FILTER_KEYS, + ProprietaryModelFilterIdentifiers, BooleanValues, Identity, SpecialSupportedFilterKeys, @@ -227,7 +228,6 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ - if (isinstance(filter, Constant) and filter.resolved_value == BooleanValues.TRUE) or ( isinstance(filter, str) and filter.lower() == BooleanValues.TRUE.lower() ): @@ -360,8 +360,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin key = model_filter.key all_keys.add(key) if model_filter.key == SpecialSupportedFilterKeys.MODEL_TYPE and model_filter.value in { - "marketplace", - "proprietary", + identifier.value for identifier in ProprietaryModelFilterIdentifiers }: model_filter.set_value(JumpStartModelType.PROPRIETARY.value) model_filters.add(model_filter) From 46ae2931df187dfb0f05c8dbf64bdeeae3a1652e Mon Sep 17 00:00:00 2001 From: Haotian An Date: Mon, 11 Mar 2024 17:07:50 +0000 Subject: [PATCH 28/30] put back doc utils for future refactoring --- doc/doc_utils/jumpstart_doc_utils.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 51857526b9..7d8f96f664 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -208,19 +208,18 @@ def get_model_source(url): def create_proprietary_model_table(): - marketpkace_content_intro = """ - .. list-table:: Available Proprietary Models - :widths: 50 20 20 20 20 - :header-rows: 1 - :class: datatable - - * - Model ID - - Fine Tunable? - - Supported Version - - Min SDK Version - - Source - - """ + proprietary_content_intro = [] + proprietary_content_intro.append("\n") + proprietary_content_intro.append(".. list-table:: Available Models\n") + proprietary_content_intro.append(" :widths: 50 20 20 20 20\n") + proprietary_content_intro.append(" :header-rows: 1\n") + proprietary_content_intro.append(" :class: datatable\n") + proprietary_content_intro.append("\n") + proprietary_content_intro.append(" * - Model ID\n") + proprietary_content_intro.append(" - Fine Tunable?\n") + proprietary_content_intro.append(" - Supported Version\n") + proprietary_content_intro.append(" - Min SDK Version\n") + proprietary_content_intro.append(" - Source\n") sdk_manifest = get_proprietary_sdk_manifest() sdk_manifest_top_versions_for_models = {} @@ -244,7 +243,7 @@ def create_proprietary_model_table(): proprietary_content_entries.append( " - `{} <{}>`__ |external-link|\n".format("Source", model_spec.get("url")) ) - return [marketpkace_content_intro] + proprietary_content_entries + ["\n"] + return proprietary_content_intro + proprietary_content_entries + ["\n"] def create_jumpstart_model_table(): From 5f053d578277022315b30f1075a85e5d9f5cd613 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Mon, 11 Mar 2024 17:44:20 +0000 Subject: [PATCH 29/30] add prop model title in doc --- doc/doc_utils/jumpstart_doc_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 7d8f96f664..458da694d5 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -210,7 +210,7 @@ def get_model_source(url): def create_proprietary_model_table(): proprietary_content_intro = [] proprietary_content_intro.append("\n") - proprietary_content_intro.append(".. list-table:: Available Models\n") + proprietary_content_intro.append(".. list-table:: Available Proprietary Models\n") proprietary_content_intro.append(" :widths: 50 20 20 20 20\n") proprietary_content_intro.append(" :header-rows: 1\n") proprietary_content_intro.append(" :class: datatable\n") From 9ec6f8e3c28172b87b307dfec0af99acb63f71e5 Mon Sep 17 00:00:00 2001 From: Haotian An Date: Mon, 11 Mar 2024 21:23:14 +0000 Subject: [PATCH 30/30] doc update --- src/sagemaker/jumpstart/cache.py | 2 +- src/sagemaker/jumpstart/enums.py | 6 ++++-- src/sagemaker/jumpstart/exceptions.py | 11 +++++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index f337edb4c2..7682ab3817 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -319,7 +319,7 @@ def _get_open_weight_manifest_key_from_model_id( key: JumpStartVersionedModelId, value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 ) -> JumpStartVersionedModelId: - """For open weights models, retrieve model manifest key for open source model. + """For open weights models, retrieve model manifest key for open weight model. Filters models list by supported versions. """ diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index 62de9fc3c3..afe0df0dfe 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -37,8 +37,10 @@ class ModelFramework(str, Enum): class JumpStartModelType(str, Enum): """Enum class for JumpStart model type. - Open source model refers to JumpStart owned community models. - Proprietary model refers to external provider owned Marketplace models. + OPEN_WEIGHTS: Publicly available models have open weights + and are onboarded and maintained by JumpStart. + PROPRIETARY: Proprietary models from third-party providers do not have open weights. + You must subscribe to proprietary models in AWS Marketplace before use. """ OPEN_WEIGHTS = "open_weights" diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index 4bc454c139..742a6b8d3f 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -71,7 +71,7 @@ def get_proprietary_model_subscription_msg( return ( f"INFO: Using proprietary model '{model_id}'. " - f"Please make sure to subscribe to the model on {subscription_link}" + f"To subscribe to this model in AWS Marketplace, see {subscription_link}" ) @@ -215,7 +215,7 @@ class MarketplaceModelSubscriptionError(ValueError): """Exception raised when trying to deploy a JumpStart Marketplace model. A caller is required to subscribe to the Marketplace product in order to deploy. - This exception is raised when a caller tries to deploy a JumpStart Marketplace + This exception is raised when a caller tries to deploy a JumpStart Marketplace model but the caller is not subscribed to the model. """ @@ -227,8 +227,11 @@ def __init__( if message: self.message = message else: - self.message = "You have not subscribed to this Marketplace model. " + self.message = ( + "To use a proprietary JumpStart model, " + "you must first subscribe to the model in AWS Marketplace. " + ) if model_subscription_link: - self.message += f"Please subscribe following this link {model_subscription_link}" + self.message += f"To subscribe to this model, see {model_subscription_link}" super().__init__(self.message)