Skip to content

Commit fd24cab

Browse files
authored
feat: add hub and hubcontent support in retrieval function for jumpstart model cache (#4438)
1 parent b3a72fd commit fd24cab

File tree

11 files changed

+255
-57
lines changed

11 files changed

+255
-57
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 72 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import datetime
1616
from difflib import get_close_matches
1717
import os
18-
from typing import List, Optional, Tuple, Union
18+
from typing import Any, Dict, List, Optional, Tuple, Union
1919
import json
2020
import boto3
2121
import botocore
@@ -29,6 +29,7 @@
2929
JUMPSTART_LOGGER,
3030
MODEL_ID_LIST_WEB_URL,
3131
)
32+
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub
3233
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
3334
from sagemaker.jumpstart.parameters import (
3435
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
@@ -37,12 +38,13 @@
3738
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
3839
)
3940
from sagemaker.jumpstart.types import (
40-
JumpStartCachedS3ContentKey,
41-
JumpStartCachedS3ContentValue,
41+
JumpStartCachedContentKey,
42+
JumpStartCachedContentValue,
4243
JumpStartModelHeader,
4344
JumpStartModelSpecs,
4445
JumpStartS3FileType,
4546
JumpStartVersionedModelId,
47+
HubDataType,
4648
)
4749
from sagemaker.jumpstart import utils
4850
from sagemaker.utilities.cache import LRUCache
@@ -95,7 +97,7 @@ def __init__(
9597
"""
9698

9799
self._region = region
98-
self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
100+
self._content_cache = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue](
99101
max_cache_items=max_s3_cache_items,
100102
expiration_horizon=s3_cache_expiration_horizon,
101103
retrieval_function=self._retrieval_function,
@@ -172,8 +174,8 @@ def _get_manifest_key_from_model_id_semantic_version(
172174

173175
model_id, version = key.model_id, key.version
174176

175-
manifest = self._s3_cache.get(
176-
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
177+
manifest = self._content_cache.get(
178+
JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
177179
)[0].formatted_content
178180

179181
sm_version = utils.get_sagemaker_version()
@@ -301,50 +303,71 @@ def _get_json_file_from_local_override(
301303

302304
def _retrieval_function(
303305
self,
304-
key: JumpStartCachedS3ContentKey,
305-
value: Optional[JumpStartCachedS3ContentValue],
306-
) -> JumpStartCachedS3ContentValue:
307-
"""Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``.
306+
key: JumpStartCachedContentKey,
307+
value: Optional[JumpStartCachedContentValue],
308+
) -> JumpStartCachedContentValue:
309+
"""Return s3 content given a data type and s3_key in ``JumpStartCachedContentKey``.
308310
309311
If a manifest file is being fetched, we only download the object if the md5 hash in
310312
``head_object`` does not match the current md5 hash for the stored value. This prevents
311313
unnecessarily downloading the full manifest when it hasn't changed.
312314
313315
Args:
314-
key (JumpStartCachedS3ContentKey): key for which to fetch s3 content.
316+
key (JumpStartCachedContentKey): key for which to fetch JumpStart content.
315317
value (Optional[JumpStartVersionedModelId]): Current value of old cached
316318
s3 content. This is used for the manifest file, so that it is only
317319
downloaded when its content changes.
318320
"""
319321

320-
file_type, s3_key = key.file_type, key.s3_key
322+
data_type, id_info = key.data_type, key.id_info
321323

322-
if file_type == JumpStartS3FileType.MANIFEST:
324+
if data_type == JumpStartS3FileType.MANIFEST:
323325
if value is not None and not self._is_local_metadata_mode():
324-
etag = self._get_json_md5_hash(s3_key)
326+
etag = self._get_json_md5_hash(id_info)
325327
if etag == value.md5_hash:
326328
return value
327-
formatted_body, etag = self._get_json_file(s3_key, file_type)
328-
return JumpStartCachedS3ContentValue(
329+
formatted_body, etag = self._get_json_file(id_info, data_type)
330+
return JumpStartCachedContentValue(
329331
formatted_content=utils.get_formatted_manifest(formatted_body),
330332
md5_hash=etag,
331333
)
332-
if file_type == JumpStartS3FileType.SPECS:
333-
formatted_body, _ = self._get_json_file(s3_key, file_type)
334+
if data_type == JumpStartS3FileType.SPECS:
335+
formatted_body, _ = self._get_json_file(id_info, data_type)
334336
model_specs = JumpStartModelSpecs(formatted_body)
335337
utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
336-
return JumpStartCachedS3ContentValue(
338+
return JumpStartCachedContentValue(
337339
formatted_content=model_specs
338340
)
341+
if data_type == HubDataType.MODEL:
342+
hub_name, region, model_name, model_version = utils.extract_info_from_hub_content_arn(
343+
id_info
344+
)
345+
hub = CuratedHub(hub_name=hub_name, region=region)
346+
hub_content = hub.describe_model(model_name=model_name, model_version=model_version)
347+
utils.emit_logs_based_on_model_specs(
348+
hub_content.content_document,
349+
self.get_region(),
350+
self._s3_client
351+
)
352+
model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True)
353+
return JumpStartCachedContentValue(
354+
formatted_content=model_specs
355+
)
356+
if data_type == HubDataType.HUB:
357+
hub_name, region, _, _ = utils.extract_info_from_hub_content_arn(id_info)
358+
hub = CuratedHub(hub_name=hub_name, region=region)
359+
hub_info = hub.describe()
360+
return JumpStartCachedContentValue(formatted_content=hub_info)
339361
raise ValueError(
340-
f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}"
362+
f"Bad value for key '{key}': must be in",
363+
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubDataType.HUB, HubDataType.MODEL]}"
341364
)
342365

343366
def get_manifest(self) -> List[JumpStartModelHeader]:
344367
"""Return entire JumpStart models manifest."""
345368

346-
manifest_dict = self._s3_cache.get(
347-
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
369+
manifest_dict = self._content_cache.get(
370+
JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
348371
)[0].formatted_content
349372
manifest = list(manifest_dict.values()) # type: ignore
350373
return manifest
@@ -407,8 +430,8 @@ def _get_header_impl(
407430
JumpStartVersionedModelId(model_id, semantic_version_str)
408431
)[0]
409432

410-
manifest = self._s3_cache.get(
411-
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
433+
manifest = self._content_cache.get(
434+
JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
412435
)[0].formatted_content
413436
try:
414437
header = manifest[versioned_model_id] # type: ignore
@@ -430,8 +453,8 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
430453

431454
header = self.get_header(model_id, semantic_version_str)
432455
spec_key = header.spec_key
433-
specs, cache_hit = self._s3_cache.get(
434-
JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key)
456+
specs, cache_hit = self._content_cache.get(
457+
JumpStartCachedContentKey(JumpStartS3FileType.SPECS, spec_key)
435458
)
436459
if not cache_hit and "*" in semantic_version_str:
437460
JUMPSTART_LOGGER.warning(
@@ -443,7 +466,29 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
443466
)
444467
return specs.formatted_content
445468

469+
def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
470+
"""Return JumpStart-compatible specs for a given Hub model
471+
472+
Args:
473+
hub_model_arn (str): Arn for the Hub model to get specs for
474+
"""
475+
476+
details, _ = self._content_cache.get(
477+
JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)
478+
)
479+
return details.formatted_content
480+
481+
def get_hub(self, hub_arn: str) -> Dict[str, Any]:
482+
"""Return descriptive info for a given Hub
483+
484+
Args:
485+
hub_arn (str): Arn for the Hub to get info for
486+
"""
487+
488+
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn))
489+
return details.formatted_content
490+
446491
def clear(self) -> None:
447492
"""Clears the model ID/version and s3 cache."""
448-
self._s3_cache.clear()
493+
self._content_cache.clear()
449494
self._model_id_semantic_version_manifest_key_cache.clear()

src/sagemaker/jumpstart/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,10 @@
170170

171171
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
172172

173+
# works cross-partition
174+
HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$"
175+
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"
176+
173177
INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"
174178
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"
175179

src/sagemaker/jumpstart/curated_hub/__init__.py

Whitespace-only changes.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module provides the JumpStart Curated Hub class."""
14+
from __future__ import absolute_import
15+
16+
from typing import Optional, Dict, Any
17+
18+
from sagemaker.session import Session
19+
20+
21+
class CuratedHub:
22+
"""Class for creating and managing a curated JumpStart hub"""
23+
24+
def __init__(self, hub_name: str, region: str, session: Optional[Session] = None):
25+
self.hub_name = hub_name
26+
self.region = region
27+
self.session = session
28+
self._sm_session = session or Session()
29+
30+
def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]:
31+
"""Returns descriptive information about the Hub Model"""
32+
33+
hub_content = self._sm_session.describe_hub_content(
34+
model_name, "Model", self.hub_name, model_version
35+
)
36+
37+
# TODO: Parse HubContent
38+
# TODO: Parse HubContentDocument
39+
40+
return hub_content
41+
42+
def describe(self) -> Dict[str, Any]:
43+
"""Returns descriptive information about the Hub"""
44+
45+
hub_info = self._sm_session.describe_hub(hub_name=self.hub_name)
46+
47+
# TODO: Validations?
48+
49+
return hub_info

src/sagemaker/jumpstart/types.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,17 @@ class JumpStartS3FileType(str, Enum):
106106
SPECS = "specs"
107107

108108

109+
class HubDataType(str, Enum):
110+
"""Enum for Hub data storage objects."""
111+
112+
HUB = "hub"
113+
MODEL = "model"
114+
NOTEBOOK = "notebook"
115+
116+
117+
JumpStartContentDataType = Union[JumpStartS3FileType, HubDataType]
118+
119+
109120
class JumpStartLaunchedRegionInfo(JumpStartDataHolderType):
110121
"""Data class for launched region info."""
111122

@@ -767,13 +778,16 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
767778
"gated_bucket",
768779
]
769780

770-
def __init__(self, spec: Dict[str, Any]):
781+
def __init__(self, spec: Dict[str, Any], is_hub_content: bool = False):
771782
"""Initializes a JumpStartModelSpecs object from its json representation.
772783
773784
Args:
774785
spec (Dict[str, Any]): Dictionary representation of spec.
775786
"""
776-
self.from_json(spec)
787+
if is_hub_content:
788+
self.from_hub_content_doc(spec)
789+
else:
790+
self.from_json(spec)
777791

778792
def from_json(self, json_obj: Dict[str, Any]) -> None:
779793
"""Sets fields in object based on json of header.
@@ -895,6 +909,15 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
895909
else None
896910
)
897911

912+
def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None:
913+
"""Sets fields in object based on values in HubContentDocument
914+
915+
Args:
916+
hub_content_doc (Dict[str, any]): parsed HubContentDocument returned
917+
from SageMaker:DescribeHubContent
918+
"""
919+
# TODO: Implement
920+
898921
def to_json(self) -> Dict[str, Any]:
899922
"""Returns json representation of JumpStartModelSpecs object."""
900923
json_obj = {}
@@ -958,27 +981,27 @@ def __init__(
958981
self.version = version
959982

960983

961-
class JumpStartCachedS3ContentKey(JumpStartDataHolderType):
962-
"""Data class for the s3 cached content keys."""
984+
class JumpStartCachedContentKey(JumpStartDataHolderType):
985+
"""Data class for the cached content keys."""
963986

964-
__slots__ = ["file_type", "s3_key"]
987+
__slots__ = ["data_type", "id_info"]
965988

966989
def __init__(
967990
self,
968-
file_type: JumpStartS3FileType,
969-
s3_key: str,
991+
data_type: JumpStartContentDataType,
992+
id_info: str,
970993
) -> None:
971-
"""Instantiates JumpStartCachedS3ContentKey object.
994+
"""Instantiates JumpStartCachedContentKey object.
972995
973996
Args:
974-
file_type (JumpStartS3FileType): JumpStart file type.
975-
s3_key (str): object key in s3.
997+
data_type (JumpStartContentDataType): JumpStart content data type.
998+
id_info (str): if S3Content, object key in s3. if HubContent, hub content arn.
976999
"""
977-
self.file_type = file_type
978-
self.s3_key = s3_key
1000+
self.data_type = data_type
1001+
self.id_info = id_info
9791002

9801003

981-
class JumpStartCachedS3ContentValue(JumpStartDataHolderType):
1004+
class JumpStartCachedContentValue(JumpStartDataHolderType):
9821005
"""Data class for the s3 cached content values."""
9831006

9841007
__slots__ = ["formatted_content", "md5_hash"]
@@ -991,7 +1014,7 @@ def __init__(
9911014
],
9921015
md5_hash: Optional[str] = None,
9931016
) -> None:
994-
"""Instantiates JumpStartCachedS3ContentValue object.
1017+
"""Instantiates JumpStartCachedContentValue object.
9951018
9961019
Args:
9971020
formatted_content (Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader],

src/sagemaker/jumpstart/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515
import logging
1616
import os
17+
import re
1718
from typing import Any, Dict, List, Optional, Tuple, Union
1819
from urllib.parse import urlparse
1920
import boto3
@@ -810,3 +811,26 @@ def get_jumpstart_model_id_version_from_resource_arn(
810811
model_version = model_version_from_tag
811812

812813
return model_id, model_version
814+
815+
816+
def extract_info_from_hub_content_arn(
817+
arn: str,
818+
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
819+
"""Extracts hub_name, content_name, and content_version from a HubContentArn"""
820+
821+
match = re.match(constants.HUB_MODEL_ARN_REGEX, arn)
822+
if match:
823+
hub_name = match.group(4)
824+
hub_region = match.group(2)
825+
content_name = match.group(5)
826+
content_version = match.group(6)
827+
828+
return hub_name, hub_region, content_name, content_version
829+
830+
match = re.match(constants.HUB_ARN_REGEX, arn)
831+
if match:
832+
hub_name = match.group(4)
833+
hub_region = match.group(2)
834+
return hub_name, hub_region, None, None
835+
836+
return None, None, None, None

tests/unit/sagemaker/jumpstart/curated_hub/__init__.py

Whitespace-only changes.

tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py

Whitespace-only changes.

0 commit comments

Comments
 (0)