Skip to content

Commit f4da2ad

Browse files
committed
Refactor: add types for Hub and HubContent descriptions, and add helpers.
1 parent 4922511 commit f4da2ad

File tree

6 files changed

+344
-224
lines changed

6 files changed

+344
-224
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
JUMPSTART_LOGGER,
3030
MODEL_ID_LIST_WEB_URL,
3131
)
32-
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub
3332
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
3433
from sagemaker.jumpstart.parameters import (
3534
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
@@ -44,9 +43,12 @@
4443
JumpStartModelSpecs,
4544
JumpStartS3FileType,
4645
JumpStartVersionedModelId,
47-
HubDataType,
46+
HubContentType,
47+
HubDescription,
48+
HubContentDescription,
4849
)
4950
from sagemaker.jumpstart import utils
51+
from sagemaker.jumpstart.curated_hub import utils as hub_utils
5052
from sagemaker.utilities.cache import LRUCache
5153

5254

@@ -338,29 +340,33 @@ def _retrieval_function(
338340
return JumpStartCachedContentValue(
339341
formatted_content=model_specs
340342
)
341-
if data_type == HubDataType.MODEL:
343+
if data_type == HubContentType.MODEL:
342344
hub_name, region, model_name, model_version = utils.extract_info_from_hub_content_arn(
343345
id_info
344346
)
345-
hub = CuratedHub(hub_name=hub_name, region=region)
346-
hub_content = hub.describe_model(model_name=model_name, model_version=model_version)
347+
hub_model_description: HubContentDescription = hub_utils.describe_model(
348+
hub_name=hub_name,
349+
region=region,
350+
model_name=model_name,
351+
model_version=model_version
352+
)
353+
model_specs = JumpStartModelSpecs(hub_model_description, is_hub_content=True)
347354
utils.emit_logs_based_on_model_specs(
348-
hub_content.content_document,
355+
model_specs,
349356
self.get_region(),
350357
self._s3_client
351358
)
352-
model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True)
359+
# TODO: Parse HubContentDescription
353360
return JumpStartCachedContentValue(
354361
formatted_content=model_specs
355362
)
356-
if data_type == HubDataType.HUB:
363+
if data_type == HubContentType.HUB:
357364
hub_name, region, _, _ = utils.extract_info_from_hub_content_arn(id_info)
358-
hub = CuratedHub(hub_name=hub_name, region=region)
359-
hub_info = hub.describe()
360-
return JumpStartCachedContentValue(formatted_content=hub_info)
365+
hub_description: HubDescription = hub_utils.describe(hub_name=hub_name, region=region)
366+
return JumpStartCachedContentValue(formatted_content=hub_description)
361367
raise ValueError(
362368
f"Bad value for key '{key}': must be in",
363-
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubDataType.HUB, HubDataType.MODEL]}"
369+
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubContentType.HUB, HubContentType.MODEL]}"
364370
)
365371

366372
def get_manifest(self) -> List[JumpStartModelHeader]:
@@ -474,7 +480,7 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
474480
"""
475481

476482
details, _ = self._content_cache.get(
477-
JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)
483+
JumpStartCachedContentKey(HubContentType.MODEL, hub_model_arn)
478484
)
479485
return details.formatted_content
480486

@@ -485,7 +491,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]:
485491
hub_arn (str): Arn for the Hub to get info for
486492
"""
487493

488-
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn))
494+
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubContentType.HUB, hub_arn))
489495
return details.formatted_content
490496

491497
def clear(self) -> None:

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,32 +13,32 @@
1313
"""This module provides the JumpStart Curated Hub class."""
1414
from __future__ import absolute_import
1515

16-
from typing import Optional, Dict, Any
16+
from typing import Any, Dict, Optional
1717
import boto3
1818
from sagemaker.session import Session
1919
from sagemaker.jumpstart.constants import (
2020
JUMPSTART_DEFAULT_REGION_NAME,
2121
)
2222

23-
from sagemaker.jumpstart.types import HubDataType
24-
import sagemaker.jumpstart.curated_hub.utils as hubutils
23+
from sagemaker.jumpstart.types import HubDescription, HubContentType, HubContentDescription
24+
import sagemaker.jumpstart.session_utils as session_utils
2525

2626

2727
class CuratedHub:
2828
"""Class for creating and managing a curated JumpStart hub"""
2929

3030
def __init__(
3131
self,
32-
name: str,
32+
hub_name: str,
3333
region: str = JUMPSTART_DEFAULT_REGION_NAME,
34-
session: Optional[Session] = None,
34+
sagemaker_session: Optional[Session] = None,
3535
):
36-
self.name = name
37-
if session.boto_region_name != region:
36+
self.hub_name = hub_name
37+
if sagemaker_session.boto_region_name != region:
3838
# TODO: Handle error
3939
pass
4040
self.region = region
41-
self._session = session or Session(boto3.Session(region_name=region))
41+
self._sagemaker_session = sagemaker_session or Session(boto3.Session(region_name=region))
4242

4343
def create(
4444
self,
@@ -50,32 +50,60 @@ def create(
5050
) -> Dict[str, str]:
5151
"""Creates a hub with the given description"""
5252

53-
return hubutils.create_hub(
54-
hub_name=self.name,
53+
bucket_name = session_utils.create_hub_bucket_if_it_does_not_exist(
54+
bucket_name, self._sagemaker_session
55+
)
56+
57+
return self._sagemaker_session.create_hub(
58+
hub_name=self.hub_name,
5559
hub_description=description,
5660
hub_display_name=display_name,
5761
hub_search_keywords=search_keywords,
5862
hub_bucket_name=bucket_name,
5963
tags=tags,
60-
sagemaker_session=self._session,
6164
)
6265

63-
def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]:
64-
"""Returns descriptive information about the Hub Model"""
66+
def describe(self) -> HubDescription:
67+
"""Returns descriptive information about the Hub"""
68+
69+
hub_description = self._sagemaker_session.describe_hub(hub_name=self.hub_name)
70+
71+
return HubDescription(hub_description)
72+
73+
def list_models(self, **kwargs) -> Dict[str, Any]:
74+
"""Lists the models in this Curated Hub
6575
66-
hub_content = hubutils.describe_hub_content(
67-
hub_name=self.name,
68-
content_name=model_name,
69-
content_type=HubDataType.MODEL,
70-
content_version=model_version,
71-
sagemaker_session=self._session,
76+
**kwargs: Passed to invocation of ``Session:list_hub_contents``.
77+
"""
78+
# TODO: Validate kwargs and fast-fail?
79+
80+
hub_content_summaries = self._sagemaker_session.list_hub_contents(
81+
hub_name=self.hub_name, hub_content_type=HubContentType.MODEL, **kwargs
7282
)
83+
# TODO: Handle pagination
84+
return hub_content_summaries
7385

74-
return hub_content
86+
def describe_model(self, model_name: str, model_version: str = "*") -> HubContentDescription:
87+
"""Returns descriptive information about the Hub Model"""
7588

76-
def describe(self) -> Dict[str, Any]:
77-
"""Returns descriptive information about the Hub"""
89+
hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
90+
hub_name=self.hub_name,
91+
hub_content_name=model_name,
92+
hub_content_version=model_version,
93+
hub_content_type=HubContentType.MODEL,
94+
)
95+
96+
return HubContentDescription(hub_content_description)
7897

79-
hub_info = hubutils.describe_hub(hub_name=self.name, sagemaker_session=self._session)
98+
def delete_model(self, model_name: str, model_version: str = "*") -> None:
99+
"""Deletes a model from this CuratedHub."""
100+
return self._sagemaker_session.delete_hub_content(
101+
hub_content_name=model_name,
102+
hub_content_version=model_version,
103+
hub_content_type=HubContentType.MODEL,
104+
hub_name=self.hub_name,
105+
)
80106

81-
return hub_info
107+
def delete(self) -> None:
108+
"""Deletes this Curated Hub"""
109+
return self._sagemaker_session.delete_hub(self.hub_name)

0 commit comments

Comments
 (0)