Skip to content

Commit a434837

Browse files
committed
feat: add hub and hubcontent support in retrieval function for jumpstart model cache (aws#4438)
1 parent ef7b9f4 commit a434837

File tree

5 files changed

+34
-1
lines changed

5 files changed

+34
-1
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
JumpStartVersionedModelId,
5858
HubType,
5959
HubContentType,
60+
HubDataType,
6061
)
6162
from sagemaker.jumpstart.curated_hub import utils as hub_utils
6263
from sagemaker.jumpstart.curated_hub.interfaces import (

src/sagemaker/jumpstart/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@
172172
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
173173
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"
174174

175-
HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$"
175+
# works cross-partition
176+
HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$"
176177
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"
177178

178179
INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"

src/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import re
1818
from typing import Any, Dict, List, Set, Optional, Tuple, Union
19+
import re
1920
from urllib.parse import urlparse
2021
import boto3
2122
from packaging.version import Version

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,35 @@ def test_mime_type_enum_from_str():
12061206
assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type
12071207

12081208

1209+
def test_extract_info_from_hub_content_arn():
1210+
model_arn = (
1211+
"arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2"
1212+
)
1213+
assert utils.extract_info_from_hub_content_arn(model_arn) == (
1214+
"MockHub",
1215+
"us-west-2",
1216+
"my-mock-model",
1217+
"1.0.2",
1218+
)
1219+
1220+
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub"
1221+
assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None)
1222+
1223+
invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123"
1224+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1225+
1226+
invalid_arn = "nonsense-string"
1227+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1228+
1229+
invalid_arn = ""
1230+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1231+
1232+
invalid_arn = (
1233+
"arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0"
1234+
)
1235+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1236+
1237+
12091238
class TestIsValidModelId(TestCase):
12101239
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest")
12111240
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs")

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
JUMPSTART_REGION_NAME_SET,
2323
)
2424
from sagemaker.jumpstart.types import (
25+
HubDataType,
2526
JumpStartCachedContentKey,
2627
JumpStartCachedContentValue,
2728
JumpStartModelSpecs,

0 commit comments

Comments
 (0)