Skip to content

Commit 68d5c8a

Browse files
malav-shastri and Malav Shastri committed
Implement CuratedHub APIs (aws#1449)
* Implement CuratedHub Admin APIs
* making some parameters optional in create_hub_content_reference as per the API design
* add describe_hub and list_hubs APIs
* implement delete_hub API
* Implement list_hub_contents API
* create CuratedHub class and supported utils
* implement list_models and address comments
* Add unit tests
* add describe_model function
* cache retrieval for describeHubContent changes
* fix curated hub class unit tests
* add utils needed for curatedHub
* Cache retrieval
* implement get_hub_model_reference()
* cleanup HUB type datatype
* cleanup constants
* rename list_public_models to list_jumpstart_service_hub_models
* implement describe_model_reference
* Rename CuratedHub to Hub
* address nit
* address nits and fix failing tests

---------

Co-authored-by: Malav Shastri <[email protected]>
1 parent b96c98e commit 68d5c8a

File tree

19 files changed

+2500
-98
lines changed

19 files changed

+2500
-98
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 94 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import datetime
1616
from difflib import get_close_matches
1717
import os
18-
from typing import List, Optional, Tuple, Union
18+
from typing import Any, Dict, List, Optional, Tuple, Union
1919
import json
2020
import boto3
2121
import botocore
@@ -42,12 +42,19 @@
4242
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
4343
)
4444
from sagemaker.jumpstart.types import (
45-
JumpStartCachedS3ContentKey,
46-
JumpStartCachedS3ContentValue,
45+
JumpStartCachedContentKey,
46+
JumpStartCachedContentValue,
4747
JumpStartModelHeader,
4848
JumpStartModelSpecs,
4949
JumpStartS3FileType,
5050
JumpStartVersionedModelId,
51+
HubType,
52+
HubContentType
53+
)
54+
from sagemaker.jumpstart.hub import utils as hub_utils
55+
from sagemaker.jumpstart.hub.interfaces import (
56+
DescribeHubResponse,
57+
DescribeHubContentResponse,
5158
)
5259
from sagemaker.jumpstart.enums import JumpStartModelType
5360
from sagemaker.jumpstart import utils
@@ -104,7 +111,7 @@ def __init__(
104111
s3_bucket_name=s3_bucket_name, s3_client=s3_client
105112
)
106113

107-
self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
114+
self._content_cache = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue](
108115
max_cache_items=max_s3_cache_items,
109116
expiration_horizon=s3_cache_expiration_horizon,
110117
retrieval_function=self._retrieval_function,
@@ -230,8 +237,8 @@ def _model_id_retrieval_function(
230237

231238
model_id, version = key.model_id, key.version
232239
sm_version = utils.get_sagemaker_version()
233-
manifest = self._s3_cache.get(
234-
JumpStartCachedS3ContentKey(
240+
manifest = self._content_cache.get(
241+
JumpStartCachedContentKey(
235242
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
236243
)
237244
)[0].formatted_content
@@ -392,53 +399,87 @@ def _get_json_file_from_local_override(
392399

393400
def _retrieval_function(
    self,
    key: JumpStartCachedContentKey,
    value: Optional[JumpStartCachedContentValue],
) -> JumpStartCachedContentValue:
    """Return the content identified by ``key`` so it can be stored in the content cache.

    ``key.data_type`` selects how ``key.id_info`` is interpreted:

    * Manifest and specs file types (``JumpStartS3FileType``): ``id_info`` is the
      s3 key of the JSON file to fetch.
    * Hub content types (``HubContentType`` notebook, model, model reference):
      ``id_info`` is a hub resource arn, which is parsed and described through
      the SageMaker session.

    If a manifest file is being fetched, we only download the object if the md5 hash in
    ``head_object`` does not match the current md5 hash for the stored value. This prevents
    unnecessarily downloading the full manifest when it hasn't changed.

    Args:
        key (JumpStartCachedContentKey): key for which to fetch content.
        value (Optional[JumpStartCachedContentValue]): Current value of old cached
            content. This is used for the manifest file, so that it is only
            downloaded when its content changes.

    Returns:
        JumpStartCachedContentValue: the formatted manifest (with its md5 hash),
            the model specs, or the hub notebook description for ``key``.

    Raises:
        ValueError: If ``key.data_type`` is not one of the supported data types.
    """

    data_type, id_info = key.data_type, key.id_info

    if data_type in {
        JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
        JumpStartS3FileType.PROPRIETARY_MANIFEST,
    }:
        if value is not None and not self._is_local_metadata_mode():
            # Compare hashes first so an unchanged manifest is not downloaded again.
            etag = self._get_json_md5_hash(id_info)
            if etag == value.md5_hash:
                return value
        formatted_body, etag = self._get_json_file(id_info, data_type)
        return JumpStartCachedContentValue(
            formatted_content=utils.get_formatted_manifest(formatted_body),
            md5_hash=etag,
        )

    if data_type in {
        JumpStartS3FileType.OPEN_WEIGHT_SPECS,
        JumpStartS3FileType.PROPRIETARY_SPECS,
    }:
        formatted_body, _ = self._get_json_file(id_info, data_type)
        model_specs = JumpStartModelSpecs(formatted_body)
        utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
        return JumpStartCachedContentValue(formatted_content=model_specs)

    if data_type in (
        HubContentType.NOTEBOOK,
        HubContentType.MODEL,
        HubContentType.MODEL_REFERENCE,
    ):
        # All hub content types share the same lookup: parse the arn, then describe
        # the content. Only the post-processing of the response differs.
        hub_name, _, content_name, content_version = hub_utils.get_info_from_hub_resource_arn(
            id_info
        )
        # NOTE(review): ``self._sagemaker_session`` is not assigned anywhere in the
        # visible part of this change -- confirm ``__init__`` sets it.
        response: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
            hub_name=hub_name,
            hub_content_name=content_name,
            hub_content_version=content_version,
            hub_content_type=data_type,
        )
        hub_content_description = DescribeHubContentResponse(response)

        if data_type == HubContentType.NOTEBOOK:
            # Notebooks are cached as the raw description; no model specs are built.
            return JumpStartCachedContentValue(formatted_content=hub_content_description)

        # NOTE(review): ``make_model_specs_from_describe_hub_content_response`` is not
        # among the imports visible in this change -- confirm it is imported.
        model_specs = make_model_specs_from_describe_hub_content_response(
            hub_content_description,
        )
        utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
        return JumpStartCachedContentValue(formatted_content=model_specs)

    raise ValueError(self._file_type_error_msg(data_type))
434475

435476
def get_manifest(
436477
self,
437478
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
438479
) -> List[JumpStartModelHeader]:
439480
"""Return entire JumpStart models manifest."""
440-
manifest_dict = self._s3_cache.get(
441-
JumpStartCachedS3ContentKey(
481+
manifest_dict = self._content_cache.get(
482+
JumpStartCachedContentKey(
442483
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
443484
)
444485
)[0].formatted_content
@@ -525,8 +566,8 @@ def _get_header_impl(
525566
JumpStartVersionedModelId(model_id, semantic_version_str)
526567
)[0]
527568

528-
manifest = self._s3_cache.get(
529-
JumpStartCachedS3ContentKey(
569+
manifest = self._content_cache.get(
570+
JumpStartCachedContentKey(
530571
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
531572
)
532573
)[0].formatted_content
@@ -556,18 +597,44 @@ def get_specs(
556597
"""
557598
header = self.get_header(model_id, version_str, model_type)
558599
spec_key = header.spec_key
559-
specs, cache_hit = self._s3_cache.get(
560-
JumpStartCachedS3ContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key)
600+
specs, cache_hit = self._content_cache.get(
601+
JumpStartCachedContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key)
561602
)
562603

563604
if not cache_hit and "*" in version_str:
564605
JUMPSTART_LOGGER.warning(
565606
get_wildcard_model_version_msg(header.model_id, version_str, header.version)
566607
)
567608
return specs.formatted_content
609+
610+
def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
    """Look up the JumpStart-compatible specs of a Hub model in the content cache.

    Args:
        hub_model_arn (str): Arn of the Hub model whose specs are requested.

    Returns:
        JumpStartModelSpecs: ``formatted_content`` of the cache entry stored under
            the ``HubContentType.MODEL`` key for ``hub_model_arn``.
    """
    cache_key = JumpStartCachedContentKey(HubContentType.MODEL, hub_model_arn)
    # The cache returns a pair; only the cached value (first item) is needed here.
    cached_value = self._content_cache.get(cache_key)[0]
    return cached_value.formatted_content
622+
623+
def get_hub_model_reference(self, hub_model_arn: str) -> JumpStartModelSpecs:
    """Look up the JumpStart-compatible specs of a Hub model reference in the content cache.

    Args:
        hub_model_arn (str): Arn of the Hub model reference whose specs are requested.

    Returns:
        JumpStartModelSpecs: ``formatted_content`` of the cache entry stored under
            the ``HubContentType.MODEL_REFERENCE`` key for ``hub_model_arn``.
    """
    cache_key = JumpStartCachedContentKey(HubContentType.MODEL_REFERENCE, hub_model_arn)
    # The cache returns a pair; only the cached value (first item) is needed here.
    cached_value = self._content_cache.get(cache_key)[0]
    return cached_value.formatted_content
568635

569636
def clear(self) -> None:
    """Empty the content cache and both model ID/version manifest-key caches."""
    for cache in (
        self._content_cache,
        self._open_weight_model_id_manifest_key_cache,
        self._proprietary_model_id_manifest_key_cache,
    ):
        cache.clear()

src/sagemaker/jumpstart/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@
188188
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
189189
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"
190190

191+
HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$"
192+
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"
193+
191194
INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"
192195
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"
193196

src/sagemaker/jumpstart/hub/__init__.py

Whitespace-only changes.
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module stores constants related to SageMaker JumpStart CuratedHub."""
14+
from __future__ import absolute_import
15+
16+
JUMPSTART_MODEL_HUB_NAME = "JumpStartServiceHub"
17+
18+
LATEST_VERSION_WILDCARD = "*"

0 commit comments

Comments
 (0)