diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 9dc505a2ff..6cbc5b30cb 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -21,6 +21,8 @@ import botocore from packaging.version import Version from packaging.specifiers import SpecifierSet, InvalidSpecifier +from sagemaker.session import Session +from sagemaker.utilities.cache import LRUCache from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, @@ -28,9 +30,8 @@ JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, MODEL_ID_LIST_WEB_URL, + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub -from sagemaker.jumpstart.curated_hub.utils import get_info_from_hub_resource_arn from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, @@ -38,6 +39,7 @@ JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, ) +from sagemaker.jumpstart import utils from sagemaker.jumpstart.types import ( JumpStartCachedContentKey, JumpStartCachedContentValue, @@ -45,10 +47,12 @@ JumpStartModelSpecs, JumpStartS3FileType, JumpStartVersionedModelId, + DescribeHubResponse, + DescribeHubContentsResponse, + HubType, HubContentType, ) -from sagemaker.jumpstart import utils -from sagemaker.utilities.cache import LRUCache +from sagemaker.jumpstart.curated_hub import utils as hub_utils class JumpStartModelsCache: @@ -74,6 +78,7 @@ def __init__( s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, s3_client: Optional[boto3.client] = None, + sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> None: # fmt: on """Initialize a ``JumpStartModelsCache`` instance. 
@@ -95,6 +100,8 @@ def __init__( s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache. Default: None (no config). s3_client (Optional[boto3.client]): s3 client to use. Default: None. + sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object, + used for SageMaker interactions. Default: Session in region associated with boto3 session. """ self._region = region @@ -121,6 +128,7 @@ def __init__( if s3_client_config else boto3.client("s3", region_name=self._region) ) + self._sagemaker_session = sagemaker_session def set_region(self, region: str) -> None: """Set region for cache. Clears cache after new region is set.""" @@ -340,32 +348,34 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubContentType.MODEL: - info = get_info_from_hub_resource_arn( + info = hub_utils.get_info_from_hub_resource_arn( id_info ) - hub = CuratedHub(hub_name=info.hub_name, region=info.region) - hub_content = hub.describe_model( - model_name=info.hub_content_name, model_version=info.hub_content_version + hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=info.hub_name, + hub_content_name=info.hub_content_name, + hub_content_version=info.hub_content_version, + hub_content_type=data_type ) + + model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True) + utils.emit_logs_based_on_model_specs( - hub_content.content_document, + model_specs, self.get_region(), self._s3_client ) - model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True) return JumpStartCachedContentValue( formatted_content=model_specs ) - if data_type == HubContentType.HUB: - info = get_info_from_hub_resource_arn( - id_info - ) - hub = CuratedHub(hub_name=info.hub_name, region=info.region) - hub_info = hub.describe() - return JumpStartCachedContentValue(formatted_content=hub_info) + if data_type == HubType.HUB: + 
hub_name = hub_utils.get_info_from_hub_resource_arn(id_info).hub_name + response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name) + hub_description = DescribeHubResponse(response) + return JumpStartCachedContentValue(formatted_content=hub_description) raise ValueError( - f"Bad value for key '{key}': must be in", - f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubContentType.HUB, HubContentType.MODEL]}" + f"Bad value for key '{key}': must be in " + f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubType.HUB, HubContentType.MODEL]}" ) def get_manifest(self) -> List[JumpStartModelHeader]: @@ -490,7 +500,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]: hub_arn (str): Arn for the Hub to get info for """ - details, _ = self._content_cache.get(JumpStartCachedContentKey(HubContentType.HUB, hub_arn)) + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubType.HUB, hub_arn)) return details.formatted_content def clear(self) -> None: diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index b8885ff250..59a11df577 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -13,10 +13,16 @@ """This module provides the JumpStart Curated Hub class.""" from __future__ import absolute_import -from typing import Optional, Dict, Any -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from typing import Any, Dict, Optional from sagemaker.session import Session +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.types import ( + DescribeHubResponse, + DescribeHubContentsResponse, + HubContentType, +) +from sagemaker.jumpstart.curated_hub.utils import create_hub_bucket_if_it_does_not_exist class CuratedHub: @@ -25,30 +31,85 @@ class CuratedHub: def __init__( self, hub_name: str, - region: 
str, - session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ): + """Instantiates a SageMaker ``CuratedHub``. + + Args: + hub_name (str): The name of the Hub to create. + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. + """ self.hub_name = hub_name - self.region = region - self._sm_session = session + self.region = sagemaker_session.boto_region_name + self._sagemaker_session = sagemaker_session - def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]: - """Returns descriptive information about the Hub Model""" + def create( + self, + description: str, + display_name: Optional[str] = None, + search_keywords: Optional[str] = None, + bucket_name: Optional[str] = None, + tags: Optional[str] = None, + ) -> Dict[str, str]: + """Creates a hub with the given description""" - hub_content = self._sm_session.describe_hub_content( - model_name, "Model", self.hub_name, model_version + bucket_name = create_hub_bucket_if_it_does_not_exist(bucket_name, self._sagemaker_session) + + return self._sagemaker_session.create_hub( + hub_name=self.hub_name, + hub_description=description, + hub_display_name=display_name, + hub_search_keywords=search_keywords, + hub_bucket_name=bucket_name, + tags=tags, ) - # TODO: Parse HubContent - # TODO: Parse HubContentDocument + def describe(self) -> DescribeHubResponse: + """Returns descriptive information about the Hub""" - return hub_content + hub_description = DescribeHubResponse( + self._sagemaker_session.describe_hub(hub_name=self.hub_name) + ) - def describe(self) -> Dict[str, Any]: - """Returns descriptive information about the Hub""" + return hub_description - hub_info = self._sm_session.describe_hub(hub_name=self.hub_name) + def list_models(self, **kwargs) -> Dict[str, Any]: + """Lists the models in this Curated Hub - # TODO: Validations? 
+ **kwargs: Passed to invocation of ``Session:list_hub_contents``. + """ + # TODO: Validate kwargs and fast-fail? + + hub_content_summaries = self._sagemaker_session.list_hub_contents( + hub_name=self.hub_name, hub_content_type=HubContentType.MODEL, **kwargs + ) + # TODO: Handle pagination + return hub_content_summaries + + def describe_model( + self, model_name: str, model_version: str = "*" + ) -> DescribeHubContentsResponse: + """Returns descriptive information about the Hub Model""" + + hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=self.hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL, + ) + + return DescribeHubContentsResponse(hub_content_description) + + def delete_model(self, model_name: str, model_version: str = "*") -> None: + """Deletes a model from this CuratedHub.""" + return self._sagemaker_session.delete_hub_content( + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL, + hub_name=self.hub_name, + ) - return hub_info + def delete(self) -> None: + """Deletes this Curated Hub""" + return self._sagemaker_session.delete_hub(self.hub_name) diff --git a/src/sagemaker/jumpstart/curated_hub/types.py b/src/sagemaker/jumpstart/curated_hub/types.py deleted file mode 100644 index d400137905..0000000000 --- a/src/sagemaker/jumpstart/curated_hub/types.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. 
See the License for the specific -# language governing permissions and limitations under the License. -"""This module stores types related to SageMaker JumpStart CuratedHub.""" -from __future__ import absolute_import -from typing import Optional - -from sagemaker.jumpstart.types import JumpStartDataHolderType - - -class HubArnExtractedInfo(JumpStartDataHolderType): - """Data class for info extracted from Hub arn.""" - - __slots__ = [ - "partition", - "region", - "account_id", - "hub_name", - "hub_content_type", - "hub_content_name", - "hub_content_version", - ] - - def __init__( - self, - partition: str, - region: str, - account_id: str, - hub_name: str, - hub_content_type: Optional[str] = None, - hub_content_name: Optional[str] = None, - hub_content_version: Optional[str] = None, - ) -> None: - """Instantiates HubArnExtractedInfo object.""" - - self.partition = partition - self.region = region - self.account_id = account_id - self.hub_name = hub_name - self.hub_content_type = hub_content_type - self.hub_content_name = hub_content_name - self.hub_content_version = hub_content_version diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py index 7758277ee1..ac01da45ca 100644 --- a/src/sagemaker/jumpstart/curated_hub/utils.py +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -14,12 +14,13 @@ from __future__ import absolute_import import re from typing import Optional -from sagemaker.jumpstart import constants - -from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo -from sagemaker.jumpstart.types import HubContentType from sagemaker.session import Session from sagemaker.utils import aws_partition +from sagemaker.jumpstart.types import ( + HubContentType, + HubArnExtractedInfo, +) +from sagemaker.jumpstart import constants def get_info_from_hub_resource_arn( @@ -109,3 +110,45 @@ def generate_hub_arn_for_init_kwargs( else: hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, 
session=session) return hub_arn + + +def generate_default_hub_bucket_name( + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions. + + Returns: + str: The name of the default bucket. If the name was not explicitly specified through + the Session or sagemaker_config, the bucket will take the form: + ``sagemaker-hubs-{region}-{AWS account ID}``. + """ + + region: str = sagemaker_session.boto_region_name + account_id: str = sagemaker_session.account_id() + + # TODO: Validate and fast fail + + return f"sagemaker-hubs-{region}-{account_id}" + + +def create_hub_bucket_if_it_does_not_exist( + bucket_name: Optional[str] = None, + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Creates the default SageMaker Hub bucket if it does not exist. + + Returns: + str: The name of the default bucket. Takes the form: + ``sagemaker-hubs-{region}-{AWS account ID}``. 
+ """ + + region: str = sagemaker_session.boto_region_name + if bucket_name is None: + bucket_name: str = generate_default_hub_bucket_name(sagemaker_session) + + sagemaker_session._create_s3_bucket_if_it_does_not_exist( + bucket_name=bucket_name, + region=region, + ) + + return bucket_name diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index e511a052d1..93e1bc3bd0 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -15,11 +15,11 @@ from __future__ import absolute_import from typing import Optional, Tuple -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn from sagemaker.session import Session from sagemaker.utils import aws_partition +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn def get_model_id_version_from_endpoint( diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 99753a3763..9c457a5626 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -15,15 +15,14 @@ from copy import deepcopy from enum import Enum from typing import Any, Dict, List, Optional, Set, Union +from sagemaker.session import Session from sagemaker.utils import get_instance_type_family, format_tags, Tags +from sagemaker.enums import EndpointType from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines - -from sagemaker.session import Session from sagemaker.workflow.entities import PipelineVariable from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements -from sagemaker.enums import EndpointType class JumpStartDataHolderType: @@ -98,6 +97,25 @@ 
def __repr__(self) -> str: } return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}" + def to_json(self) -> Dict[str, Any]: + """Returns json representation of object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + else: + json_obj[att] = cur_val + return json_obj + class JumpStartS3FileType(str, Enum): """Type of files published in JumpStart S3 distribution buckets.""" @@ -106,15 +124,20 @@ class JumpStartS3FileType(str, Enum): SPECS = "specs" -class HubContentType(str, Enum): - """Enum for Hub data storage objects.""" +class HubType(str, Enum): + """Enum for Hub objects.""" HUB = "Hub" + + +class HubContentType(str, Enum): + """Enum for Hub content objects.""" + MODEL = "Model" NOTEBOOK = "Notebook" -JumpStartContentDataType = Union[JumpStartS3FileType, HubContentType] +JumpStartContentDataType = Union[JumpStartS3FileType, HubType, HubContentType] class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): @@ -918,25 +941,6 @@ def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None: """ # TODO: Implement - def to_json(self) -> Dict[str, Any]: - """Returns json representation of JumpStartModelSpecs object.""" - json_obj = {} - for att in self.__slots__: - if hasattr(self, att): - cur_val = getattr(self, att) - if issubclass(type(cur_val), JumpStartDataHolderType): - json_obj[att] = cur_val.to_json() - elif isinstance(cur_val, list): - json_obj[att] = [] - for obj in cur_val: - if issubclass(type(obj), JumpStartDataHolderType): - json_obj[att].append(obj.to_json()) - else: - json_obj[att].append(obj) - else: - json_obj[att] = cur_val - return json_obj - def 
supports_prepacked_inference(self) -> bool: """Returns True if the model has a prepacked inference artifact.""" return getattr(self, "hosting_prepacked_artifact_key", None) is not None @@ -1027,6 +1031,203 @@ def __init__( self.md5_hash = md5_hash +class HubArnExtractedInfo(JumpStartDataHolderType): + """Data class for info extracted from Hub arn.""" + + __slots__ = [ + "partition", + "region", + "account_id", + "hub_name", + "hub_content_type", + "hub_content_name", + "hub_content_version", + ] + + def __init__( + self, + partition: str, + region: str, + account_id: str, + hub_name: str, + hub_content_type: Optional[str] = None, + hub_content_name: Optional[str] = None, + hub_content_version: Optional[str] = None, + ) -> None: + """Instantiates HubArnExtractedInfo object.""" + + self.partition = partition + self.region = region + self.account_id = account_id + self.hub_name = hub_name + self.hub_content_type = hub_content_type + self.hub_content_name = hub_content_name + self.hub_content_version = hub_content_version + + +class HubContentDependency(JumpStartDataHolderType): + """Data class for any dependencies related to hub content. + + Content can be scripts, model artifacts, datasets, or notebooks. + """ + + __slots__ = ["dependency_copy_path", "dependency_origin_path"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubContentDependency object + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. 
+ """ + + self.dependency_copy_path: Optional[str] = json_obj.get("dependency_copy_path", "") + self.dependency_origin_path: Optional[str] = json_obj.get("dependency_origin_path", "") + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of HubContentDependency object.""" + json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + return json_obj + + +class DescribeHubContentsResponse(JumpStartDataHolderType): + """Data class for the Hub Content from session.describe_hub_contents()""" + + __slots__ = [ + "creation_time", + "document_schema_version", + "failure_reason", + "hub_arn", + "hub_content_arn", + "hub_content_dependencies", + "hub_content_description", + "hub_content_display_name", + "hub_content_document", + "hub_content_markdown", + "hub_content_name", + "hub_content_search_keywords", + "hub_content_status", + "hub_content_type", + "hub_content_version", + "hub_name", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates DescribeHubContentsResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. 
+ """ + self.creation_time: int = int(json_obj["creation_time"]) + self.document_schema_version: str = json_obj["document_schema_version"] + self.failure_reason: str = json_obj["failure_reason"] + self.hub_arn: str = json_obj["hub_arn"] + self.hub_content_arn: str = json_obj["hub_content_arn"] + self.hub_content_dependencies: List[HubContentDependency] = [ + HubContentDependency(dep) for dep in json_obj["hub_content_dependencies"] + ] + self.hub_content_description: str = json_obj["hub_content_description"] + self.hub_content_display_name: str = json_obj["hub_content_display_name"] + self.hub_content_document: str = json_obj["hub_content_document"] + self.hub_content_markdown: str = json_obj["hub_content_markdown"] + self.hub_content_name: str = json_obj["hub_content_name"] + self.hub_content_search_keywords: str = json_obj["hub_content_search_keywords"] + self.hub_content_status: str = json_obj["hub_content_status"] + self.hub_content_type: HubContentType = json_obj["hub_content_type"] + self.hub_content_version: str = json_obj["hub_content_version"] + self.hub_name: str = json_obj["hub_name"] + + +class HubS3StorageConfig(JumpStartDataHolderType): + """Data class for any dependencies related to hub content. + + Includes scripts, model artifacts, datasets, or notebooks. + """ + + __slots__ = ["s3_output_path"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubS3StorageConfig object + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. 
+ """ + + self.s3_output_path: Optional[str] = json_obj.get("s3_output_path", "") + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of HubS3StorageConfig object.""" + return {"s3_output_path": self.s3_output_path} + + +class DescribeHubResponse(JumpStartDataHolderType): + """Data class for the Hub from session.describe_hub()""" + + __slots__ = [ + "creation_time", + "failure_reason", + "hub_arn", + "hub_description", + "hub_display_name", + "hub_name", + "hub_search_keywords", + "hub_status", + "last_modified_time", + "s3_storage_config", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates DescribeHubResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + + self.creation_time: int = int(json_obj["creation_time"]) + self.failure_reason: str = json_obj["failure_reason"] + self.hub_arn: str = json_obj["hub_arn"] + self.hub_description: str = json_obj["hub_description"] + self.hub_display_name: str = json_obj["hub_display_name"] + self.hub_name: str = json_obj["hub_name"] + self.hub_search_keywords: List[str] = json_obj["hub_search_keywords"] + self.hub_status: str = json_obj["hub_status"] + self.last_modified_time: int = int(json_obj["last_modified_time"]) + self.s3_storage_config: HubS3StorageConfig = HubS3StorageConfig( + json_obj["s3_storage_config"] + ) + + class JumpStartKwargs(JumpStartDataHolderType): """Data class for JumpStart object kwargs.""" diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py index e69de29bb2..2448721520 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py +++ 
b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py @@ -0,0 +1,130 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import pytest +from mock import Mock +from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub + +REGION = "us-east-1" +ACCOUNT_ID = "123456789123" +HUB_NAME = "mock-hub-name" + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name="boto_session") + sagemaker_session_mock = Mock( + name="sagemaker_session", boto_session=boto_mock, boto_region_name=REGION + ) + sagemaker_session_mock._client_config.user_agent = ( + "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" + ) + sagemaker_session_mock.account_id.return_value = ACCOUNT_ID + return sagemaker_session_mock + + +def test_instantiates(sagemaker_session): + hub = CuratedHub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + assert hub.hub_name == HUB_NAME + assert hub.region == "us-east-1" + assert hub._sagemaker_session == sagemaker_session + + +@pytest.mark.parametrize( + ("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"), + [ + pytest.param("MockHub1", "this is my sagemaker hub", None, None, None, None), + pytest.param( + "MockHub2", + "this is my sagemaker hub two", + None, + "DisplayMockHub2", + ["mock", "hub", "123"], + [{"Key": "tag-key-1", "Value": "tag-value-1"}], + ), + ], +) +def test_create_with_no_bucket_name( + 
sagemaker_session, + hub_name, + hub_description, + hub_bucket_name, + hub_display_name, + hub_search_keywords, + tags, +): + create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + sagemaker_session.create_hub = Mock(return_value=create_hub) + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + request = { + "hub_name": hub_name, + "hub_description": hub_description, + "hub_bucket_name": "sagemaker-hubs-us-east-1-123456789123", + "hub_display_name": hub_display_name, + "hub_search_keywords": hub_search_keywords, + "tags": tags, + } + response = hub.create( + description=hub_description, + display_name=hub_display_name, + bucket_name=hub_bucket_name, + search_keywords=hub_search_keywords, + tags=tags, + ) + sagemaker_session.create_hub.assert_called_with(**request) + assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + + +@pytest.mark.parametrize( + ("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"), + [ + pytest.param("MockHub1", "this is my sagemaker hub", "mock-bucket-123", None, None, None), + pytest.param( + "MockHub2", + "this is my sagemaker hub two", + "mock-bucket-123", + "DisplayMockHub2", + ["mock", "hub", "123"], + [{"Key": "tag-key-1", "Value": "tag-value-1"}], + ), + ], +) +def test_create_with_bucket_name( + sagemaker_session, + hub_name, + hub_description, + hub_bucket_name, + hub_display_name, + hub_search_keywords, + tags, +): + create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + sagemaker_session.create_hub = Mock(return_value=create_hub) + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + request = { + "hub_name": hub_name, + "hub_description": hub_description, + "hub_bucket_name": hub_bucket_name, + "hub_display_name": hub_display_name, + "hub_search_keywords": hub_search_keywords, + "tags": tags, + } + response = hub.create( + description=hub_description, + 
display_name=hub_display_name, + bucket_name=hub_bucket_name, + search_keywords=hub_search_keywords, + tags=tags, + ) + sagemaker_session.create_hub.assert_called_with(**request) + assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index 892a2ed980..59a0a8f958 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -11,11 +11,11 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import + from unittest.mock import Mock +from sagemaker.jumpstart.types import HubArnExtractedInfo from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME - from sagemaker.jumpstart.curated_hub import utils -from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo def test_get_info_from_hub_resource_arn(): @@ -139,4 +139,36 @@ def test_generate_hub_arn_for_init_kwargs(): utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn ) + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session) + == hub_arn + ) + + +def test_generate_default_hub_bucket_name(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.boto_region_name = "us-east-1" + + assert ( + utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session) + == "sagemaker-hubs-us-east-1-123456789123" + ) + + +def test_create_hub_bucket_if_it_does_not_exist(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.client("sts").get_caller_identity.return_value = { + "Account": "123456789123" + } + 
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None + mock_sagemaker_session.boto_region_name = "us-east-1" + bucket_name = "sagemaker-hubs-us-east-1-123456789123" + created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( + sagemaker_session=mock_sagemaker_session + ) + + mock_sagemaker_session._create_s3_bucket_if_it_does_not_exist.assert_called_once() + assert created_hub_bucket_name == bucket_name assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 69c8659148..423dbf5e02 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -22,7 +22,7 @@ from mock.mock import MagicMock import pytest from mock import patch - +from sagemaker.session_settings import SessionSettings from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, @@ -45,6 +45,30 @@ from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +REGION = "us-east-1" +REGION2 = "us-east-2" +ACCOUNT_ID = "123456789123" + + +@pytest.fixture() +def sagemaker_session(): + mocked_boto_session = Mock(name="boto_session") + mocked_s3_client = Mock(name="s3_client") + mocked_sagemaker_session = Mock( + name="sagemaker_session", + boto_session=mocked_boto_session, + s3_client=mocked_s3_client, + boto_region_name=REGION, + config=None, + ) + mocked_sagemaker_session.sagemaker_config = {} + mocked_sagemaker_session._client_config.user_agent = ( + "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" + ) + mocked_sagemaker_session.account_id.return_value = ACCOUNT_ID + return mocked_sagemaker_session + + @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) 
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") def test_jumpstart_cache_get_header(): @@ -252,14 +276,14 @@ def test_jumpstart_cache_handles_boto3_issues(mock_boto3_client): @patch("boto3.client") def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client): cache = JumpStartModelsCache( - s3_bucket_name="some_bucket", region="some_region", manifest_file_s3_key="some_key" + s3_bucket_name="some_bucket", region=REGION, manifest_file_s3_key="some_key" ) cache.clear = MagicMock() cache.set_s3_bucket_name("some_bucket") cache.clear.assert_not_called() cache.clear.reset_mock() - cache.set_region("some_region") + cache.set_region(REGION) cache.clear.assert_not_called() cache.clear.reset_mock() cache.set_manifest_file_s3_key("some_key") @@ -270,7 +294,7 @@ def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client): cache.set_s3_bucket_name("some_bucket1") cache.clear.assert_called_once() cache.clear.reset_mock() - cache.set_region("some_region1") + cache.set_region(REGION2) cache.clear.assert_called_once() cache.clear.reset_mock() cache.set_manifest_file_s3_key("some_key1") @@ -399,7 +423,6 @@ def test_jumpstart_cache_handles_boto3_client_errors(): def test_jumpstart_cache_accepts_input_parameters(): - region = "us-east-1" max_s3_cache_items = 1 s3_cache_expiration_horizon = datetime.timedelta(weeks=2) max_semantic_version_cache_items = 3 @@ -408,7 +431,7 @@ def test_jumpstart_cache_accepts_input_parameters(): manifest_file_key = "some_s3_key" cache = JumpStartModelsCache( - region=region, + region=REGION, max_s3_cache_items=max_s3_cache_items, s3_cache_expiration_horizon=s3_cache_expiration_horizon, max_semantic_version_cache_items=max_semantic_version_cache_items, @@ -418,7 +441,7 @@ def test_jumpstart_cache_accepts_input_parameters(): ) assert cache.get_manifest_file_s3_key() == manifest_file_key - assert cache.get_region() == region + assert cache.get_region() == REGION assert cache.get_bucket() == bucket 
assert cache._content_cache._max_cache_items == max_s3_cache_items assert cache._content_cache._expiration_horizon == s3_cache_expiration_horizon @@ -741,7 +764,10 @@ def test_jumpstart_cache_get_specs(): @patch("sagemaker.jumpstart.cache.os.path.isdir") @patch("builtins.open") def test_jumpstart_local_metadata_override_header( - mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock + mocked_open: Mock, + mocked_is_dir: Mock, + mocked_get_json_file_and_etag_from_s3: Mock, + sagemaker_session: Mock, ): mocked_open.side_effect = mock_open(read_data=json.dumps(BASE_MANIFEST)) mocked_is_dir.return_value = True @@ -760,7 +786,7 @@ def test_jumpstart_local_metadata_override_header( mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root") assert mocked_is_dir.call_count == 2 - mocked_open.assert_called_once_with( + mocked_open.assert_called_with( "/some/directory/metadata/manifest/root/models_manifest.json", "r" ) mocked_get_json_file_and_etag_from_s3.assert_not_called() @@ -783,6 +809,7 @@ def test_jumpstart_local_metadata_override_specs( mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock, mock_emit_logs_based_on_model_specs, + sagemaker_session, ): mocked_open.side_effect = [ @@ -791,7 +818,9 @@ def test_jumpstart_local_metadata_override_specs( ] mocked_is_dir.return_value = True - cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + cache = JumpStartModelsCache( + s3_bucket_name="some_bucket", s3_client=Mock(), sagemaker_session=sagemaker_session + ) model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs( @@ -845,7 +874,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") assert mocked_is_dir.call_count == 2 - mocked_open.assert_not_called() + assert 
mocked_open.call_count == 2 mocked_get_json_file_and_etag_from_s3.assert_has_calls( calls=[ call("models_manifest.json"), diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 81526485f9..a809b32a24 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -22,13 +22,15 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - HubContentType, JumpStartCachedContentKey, JumpStartCachedContentValue, JumpStartModelSpecs, JumpStartS3FileType, JumpStartModelHeader, + HubType, + HubContentType, ) + from sagemaker.jumpstart.utils import get_formatted_manifest from tests.unit.sagemaker.jumpstart.constants import ( PROTOTYPICAL_MODEL_SPECS_DICT, @@ -193,7 +195,6 @@ def patched_retrieval_function( datatype, id_info = key.data_type, key.id_info if datatype == JumpStartS3FileType.MANIFEST: - return JumpStartCachedContentValue(formatted_content=get_formatted_manifest(BASE_MANIFEST)) if datatype == JumpStartS3FileType.SPECS: @@ -210,7 +211,7 @@ def patched_retrieval_function( ) # TODO: Implement - if datatype == HubContentType.HUB: + if datatype == HubType.HUB: return None raise ValueError(f"Bad value for filetype: {datatype}")