Skip to content

Commit 032cb80

Browse files
committed
add hub and hubcontent support in retrieval function for jumpstart model cache
1 parent b3a72fd commit 032cb80

File tree

12 files changed

+270
-55
lines changed

12 files changed

+270
-55
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 64 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import datetime
1616
from difflib import get_close_matches
1717
import os
18-
from typing import List, Optional, Tuple, Union
18+
from typing import Any, Dict, List, Optional, Tuple, Union
1919
import json
2020
import boto3
2121
import botocore
@@ -29,6 +29,7 @@
2929
JUMPSTART_LOGGER,
3030
MODEL_ID_LIST_WEB_URL,
3131
)
32+
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub
3233
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
3334
from sagemaker.jumpstart.parameters import (
3435
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
@@ -37,12 +38,13 @@
3738
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
3839
)
3940
from sagemaker.jumpstart.types import (
40-
JumpStartCachedS3ContentKey,
41-
JumpStartCachedS3ContentValue,
41+
JumpStartCachedContentKey,
42+
JumpStartCachedContentValue,
4243
JumpStartModelHeader,
4344
JumpStartModelSpecs,
4445
JumpStartS3FileType,
4546
JumpStartVersionedModelId,
47+
HubDataType,
4648
)
4749
from sagemaker.jumpstart import utils
4850
from sagemaker.utilities.cache import LRUCache
@@ -95,7 +97,7 @@ def __init__(
9597
"""
9698

9799
self._region = region
98-
self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
100+
self._content_cache = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue](
99101
max_cache_items=max_s3_cache_items,
100102
expiration_horizon=s3_cache_expiration_horizon,
101103
retrieval_function=self._retrieval_function,
@@ -172,8 +174,8 @@ def _get_manifest_key_from_model_id_semantic_version(
172174

173175
model_id, version = key.model_id, key.version
174176

175-
manifest = self._s3_cache.get(
176-
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
177+
manifest = self._content_cache.get(
178+
JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
177179
)[0].formatted_content
178180

179181
sm_version = utils.get_sagemaker_version()
@@ -301,50 +303,65 @@ def _get_json_file_from_local_override(
301303

302304
def _retrieval_function(
    self,
    key: JumpStartCachedContentKey,
    value: Optional[JumpStartCachedContentValue],
) -> JumpStartCachedContentValue:
    """Return JumpStart content for the data type and identifier in ``key``.

    If a manifest file is being fetched, we only download the object if the md5 hash in
    ``head_object`` does not match the current md5 hash for the stored value. This prevents
    unnecessarily downloading the full manifest when it hasn't changed.

    Args:
        key (JumpStartCachedContentKey): key for which to fetch JumpStart content.
            ``key.id_info`` is passed on as an s3 key for the s3 file types and
            parsed as a hub / hub content arn for the hub data types.
        value (Optional[JumpStartCachedContentValue]): Current value of old cached
            content. This is used for the manifest file, so that it is only
            downloaded when its content changes.

    Raises:
        ValueError: If ``key.data_type`` is not one of the supported data types.
    """

    data_type, id_info = key.data_type, key.id_info

    if data_type == JumpStartS3FileType.MANIFEST:
        if value is not None and not self._is_local_metadata_mode():
            # Cheap freshness check: reuse the cached value when the hash is unchanged.
            etag = self._get_json_md5_hash(id_info)
            if etag == value.md5_hash:
                return value
        formatted_body, etag = self._get_json_file(id_info, data_type)
        return JumpStartCachedContentValue(
            formatted_content=utils.get_formatted_manifest(formatted_body),
            md5_hash=etag,
        )
    if data_type == JumpStartS3FileType.SPECS:
        formatted_body, _ = self._get_json_file(id_info, data_type)
        model_specs = JumpStartModelSpecs(formatted_body)
        utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
        return JumpStartCachedContentValue(
            formatted_content=model_specs
        )
    if data_type == HubDataType.MODEL:
        hub_name, hub_region, model_id, model_version = utils.extract_info_from_hub_content_arn(
            id_info
        )
        # CuratedHub's constructor parameters are (hub_name, region, session).
        hub = CuratedHub(hub_name=hub_name, region=hub_region, session=None)
        # describe_model names its first parameter ``model_name``.
        hub_content = hub.describe_model(model_name=model_id, model_version=model_version)
        # NOTE(review): describe_model is annotated as returning a dict, so the
        # ``content_document`` attribute presumably depends on the pending
        # HubContent parsing -- confirm before relying on this branch.
        model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True)
        # Log from the parsed specs object, the same way the SPECS branch does.
        utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
        return JumpStartCachedContentValue(
            formatted_content=model_specs
        )
    if data_type == HubDataType.HUB:
        hub_name, hub_region, _, _ = utils.extract_info_from_hub_content_arn(id_info)
        hub = CuratedHub(hub_name=hub_name, region=hub_region, session=None)
        hub_info = hub.describe()
        return JumpStartCachedContentValue(formatted_content=hub_info)
    # Adjacent f-strings (no comma) so ValueError receives one message string.
    raise ValueError(
        f"Bad value for key '{key}': must be in "
        f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubDataType.HUB, HubDataType.MODEL]}"
    )
342359

343360
def get_manifest(self) -> List[JumpStartModelHeader]:
    """Return every model header stored in the JumpStart models manifest."""

    manifest_key = JumpStartCachedContentKey(
        JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key
    )
    # The cache returns a tuple; its first element is the cached value.
    cached_manifest = self._content_cache.get(manifest_key)[0]
    return list(cached_manifest.formatted_content.values())  # type: ignore
@@ -407,8 +424,8 @@ def _get_header_impl(
407424
JumpStartVersionedModelId(model_id, semantic_version_str)
408425
)[0]
409426

410-
manifest = self._s3_cache.get(
411-
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
427+
manifest = self._content_cache.get(
428+
JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
412429
)[0].formatted_content
413430
try:
414431
header = manifest[versioned_model_id] # type: ignore
@@ -430,8 +447,8 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
430447

431448
header = self.get_header(model_id, semantic_version_str)
432449
spec_key = header.spec_key
433-
specs, cache_hit = self._s3_cache.get(
434-
JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key)
450+
specs, cache_hit = self._content_cache.get(
451+
JumpStartCachedContentKey(JumpStartS3FileType.SPECS, spec_key)
435452
)
436453
if not cache_hit and "*" in semantic_version_str:
437454
JUMPSTART_LOGGER.warning(
@@ -442,8 +459,28 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
442459
)
443460
)
444461
return specs.formatted_content
462+
463+
def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
    """Return JumpStart-compatible specs for a given Hub model.

    Args:
        hub_model_arn (str): Arn for the Hub model to get specs for.
    """

    cache_key = JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)
    cached_value, _ = self._content_cache.get(cache_key)
    return cached_value.formatted_content
472+
473+
def get_hub(self, hub_arn: str) -> Dict[str, Any]:
    """Return descriptive info for a given Hub.

    Args:
        hub_arn (str): Arn for the Hub to get info for.
    """

    cache_key = JumpStartCachedContentKey(HubDataType.HUB, hub_arn)
    cached_hub, _ = self._content_cache.get(cache_key)
    return cached_hub.formatted_content
445482

446483
def clear(self) -> None:
    """Clears the content cache and the model ID/version manifest key cache."""
    for cache in (
        self._content_cache,
        self._model_id_semantic_version_manifest_key_cache,
    ):
        cache.clear()

src/sagemaker/jumpstart/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,10 @@
170170

171171
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
172172

173+
# Hub resource arn patterns. The partition is captured rather than hard-coded, so
# these work cross-partition.
# Capture groups: 1 = partition, 2 = region, 3 = account, 4 = hub name; the model
# pattern additionally captures 5 = model name and 6 = model version.
# SageMaker hub content arns use the ``hub-content`` resource type; the previous
# ``hub_content`` spelling is still accepted for backward compatibility.
HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub[-_]content/(.*?)/Model/(.*?)/(.*?)$"
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"
176+
173177
INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"
174178
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"
175179

src/sagemaker/jumpstart/curated_hub/__init__.py

Whitespace-only changes.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from botocore.config import Config
14+
15+
# Shared botocore client configuration: "standard" retry mode with max_attempts set to 10.
DEFAULT_CLIENT_CONFIG = Config(retries={"max_attempts": 10, "mode": "standard"})
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from typing import Optional, Dict, Any
14+
15+
import boto3
16+
17+
from sagemaker.session import Session
18+
19+
from sagemaker.jumpstart.curated_hub.constants import DEFAULT_CLIENT_CONFIG
20+
21+
22+
class CuratedHub:
    """Class for creating and managing a curated JumpStart hub"""

    def __init__(self, hub_name: str, region: str, session: Optional[Session] = None):
        """Instantiates a CuratedHub object.

        Args:
            hub_name (str): Name of the hub.
            region (str): Region used for the S3 client created here.
            session (Optional[Session]): SageMaker session used for the describe
                calls. A new ``Session()`` is created when omitted. (Default: None).
        """
        self.hub_name = hub_name
        self.region = region
        self.session = session
        self._s3_client = self._get_s3_client()
        self._sm_session = session or Session()

    def _get_s3_client(self) -> Any:
        """Returns an S3 client."""
        # The region is stored as ``self.region``; reading ``self._region`` here
        # raised AttributeError during construction.
        return boto3.client("s3", region_name=self.region, config=DEFAULT_CLIENT_CONFIG)

    def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]:
        """Returns descriptive information about the Hub Model

        Args:
            model_name (str): Name of the model to describe.
            model_version (str): Version of the model to describe. (Default: "*").

        Returns:
            Dict[str, Any]: The result of ``describe_hub_content``, returned unmodified.
        """

        hub_content = self._sm_session.describe_hub_content(
            model_name, "Model", self.hub_name, model_version
        )

        # TODO: Parse HubContent
        # TODO: Parse HubContentDocument

        return hub_content

    def describe(self) -> Dict[str, Any]:
        """Returns descriptive information about the Hub

        Returns:
            Dict[str, Any]: The result of ``describe_hub``, returned unmodified.
        """

        hub_info = self._sm_session.describe_hub(hub_name=self.hub_name)

        # TODO: Validations?

        return hub_info

src/sagemaker/jumpstart/types.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,17 @@ class JumpStartS3FileType(str, Enum):
106106
SPECS = "specs"
107107

108108

109+
class HubDataType(str, Enum):
    """Enum for Hub data storage objects."""

    HUB = "hub"  # the hub itself
    MODEL = "model"  # model content stored in a hub
    # presumably notebook content stored in a hub -- TODO confirm against callers
    NOTEBOOK = "notebook"


# Either kind of data type: an s3 file type or a hub data type.
JumpStartContentDataType = Union[JumpStartS3FileType, HubDataType]
118+
119+
109120
class JumpStartLaunchedRegionInfo(JumpStartDataHolderType):
110121
"""Data class for launched region info."""
111122

@@ -767,13 +778,16 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
767778
"gated_bucket",
768779
]
769780

770-
def __init__(self, spec: Dict[str, Any], is_hub_content: bool = False):
    """Initializes a JumpStartModelSpecs object from its dictionary representation.

    Args:
        spec (Dict[str, Any]): Dictionary representation of spec.
        is_hub_content (bool): When True, ``spec`` is loaded with
            ``from_hub_content_doc``; otherwise it is loaded with ``from_json``.
            (Default: False).
    """
    load_spec = self.from_hub_content_doc if is_hub_content else self.from_json
    load_spec(spec)
777791

778792
def from_json(self, json_obj: Dict[str, Any]) -> None:
779793
"""Sets fields in object based on json of header.
@@ -895,6 +909,15 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
895909
else None
896910
)
897911

912+
def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None:
    """Sets fields in object based on values in HubContentDocument.

    Not implemented yet: the body currently contains no statements, so calling
    this method sets no fields.

    Args:
        hub_content_doc (Dict[str, Any]): parsed HubContentDocument returned
            from SageMaker:DescribeHubContent
    """
    # TODO: Implement
920+
898921
def to_json(self) -> Dict[str, Any]:
899922
"""Returns json representation of JumpStartModelSpecs object."""
900923
json_obj = {}
@@ -958,27 +981,27 @@ def __init__(
958981
self.version = version
959982

960983

961-
class JumpStartCachedContentKey(JumpStartDataHolderType):
    """Data class for the cached content keys.

    A key pairs a content data type with the identifier of the content to fetch.
    """

    __slots__ = ["data_type", "id_info"]

    def __init__(
        self,
        data_type: JumpStartContentDataType,
        id_info: str,
    ) -> None:
        """Instantiates JumpStartCachedContentKey object.

        Args:
            data_type (JumpStartContentDataType): JumpStart content data type.
            id_info (str): if S3Content, object key in s3. if HubContent, hub content arn.
        """
        self.data_type = data_type
        self.id_info = id_info
9791002

9801003

981-
class JumpStartCachedS3ContentValue(JumpStartDataHolderType):
1004+
class JumpStartCachedContentValue(JumpStartDataHolderType):
9821005
"""Data class for the s3 cached content values."""
9831006

9841007
__slots__ = ["formatted_content", "md5_hash"]

src/sagemaker/jumpstart/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515
import logging
1616
import os
17+
import re
1718
from typing import Any, Dict, List, Optional, Tuple, Union
1819
from urllib.parse import urlparse
1920
import boto3
@@ -810,3 +811,26 @@ def get_jumpstart_model_id_version_from_resource_arn(
810811
model_version = model_version_from_tag
811812

812813
return model_id, model_version
814+
815+
816+
def extract_info_from_hub_content_arn(
    arn: str,
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """Extracts hub_name, hub_region, content_name, and content_version from an arn.

    The hub model arn pattern is tried first and fills all four values. The hub
    arn pattern is tried next and fills only hub_name and hub_region. If neither
    pattern matches, every value is None.
    """

    model_arn_match = re.match(constants.HUB_MODEL_ARN_REGEX, arn)
    if model_arn_match:
        # Groups in return order: hub name, region, content name, content version.
        return model_arn_match.group(4, 2, 5, 6)

    hub_arn_match = re.match(constants.HUB_ARN_REGEX, arn)
    if not hub_arn_match:
        return None, None, None, None

    hub_name, hub_region = hub_arn_match.group(4, 2)
    return hub_name, hub_region, None, None

tests/unit/sagemaker/jumpstart/curated_hub/__init__.py

Whitespace-only changes.

tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py

Whitespace-only changes.

0 commit comments

Comments
 (0)