13
13
"""This module provides the JumpStart Curated Hub class."""
14
14
from __future__ import absolute_import
15
15
16
- from typing import Optional , Dict , Any
16
+ from typing import Any , Dict , Optional
17
17
import boto3
18
18
from sagemaker .session import Session
19
19
from sagemaker .jumpstart .constants import (
20
20
JUMPSTART_DEFAULT_REGION_NAME ,
21
21
)
22
22
23
- from sagemaker .jumpstart .types import HubDataType
24
- import sagemaker .jumpstart .curated_hub . utils as hubutils
23
+ from sagemaker .jumpstart .types import HubDescription , HubContentType , HubContentDescription
24
+ import sagemaker .jumpstart .session_utils as session_utils
25
25
26
26
27
27
class CuratedHub :
28
28
"""Class for creating and managing a curated JumpStart hub"""
29
29
30
30
def __init__ (
31
31
self ,
32
- name : str ,
32
+ hub_name : str ,
33
33
region : str = JUMPSTART_DEFAULT_REGION_NAME ,
34
- session : Optional [Session ] = None ,
34
+ sagemaker_session : Optional [Session ] = None ,
35
35
):
36
- self .name = name
37
- if session .boto_region_name != region :
36
+ self .hub_name = hub_name
37
+ if sagemaker_session .boto_region_name != region :
38
38
# TODO: Handle error
39
39
pass
40
40
self .region = region
41
- self ._session = session or Session (boto3 .Session (region_name = region ))
41
+ self ._sagemaker_session = sagemaker_session or Session (boto3 .Session (region_name = region ))
42
42
43
43
def create (
44
44
self ,
@@ -50,32 +50,60 @@ def create(
50
50
) -> Dict [str , str ]:
51
51
"""Creates a hub with the given description"""
52
52
53
- return hubutils .create_hub (
54
- hub_name = self .name ,
53
+ bucket_name = session_utils .create_hub_bucket_if_it_does_not_exist (
54
+ bucket_name , self ._sagemaker_session
55
+ )
56
+
57
+ return self ._sagemaker_session .create_hub (
58
+ hub_name = self .hub_name ,
55
59
hub_description = description ,
56
60
hub_display_name = display_name ,
57
61
hub_search_keywords = search_keywords ,
58
62
hub_bucket_name = bucket_name ,
59
63
tags = tags ,
60
- sagemaker_session = self ._session ,
61
64
)
62
65
63
- def describe_model (self , model_name : str , model_version : str = "*" ) -> Dict [str , Any ]:
64
- """Returns descriptive information about the Hub Model"""
66
+ def describe (self ) -> HubDescription :
67
+ """Returns descriptive information about the Hub"""
68
+
69
+ hub_description = self ._sagemaker_session .describe_hub (hub_name = self .hub_name )
70
+
71
+ return HubDescription (hub_description )
72
+
73
+ def list_models (self , ** kwargs ) -> Dict [str , Any ]:
74
+ """Lists the models in this Curated Hub
65
75
66
- hub_content = hubutils . describe_hub_content (
67
- hub_name = self . name ,
68
- content_name = model_name ,
69
- content_type = HubDataType . MODEL ,
70
- content_version = model_version ,
71
- sagemaker_session = self ._session ,
76
+ **kwargs: Passed to invocation of ``Session:list_hub_contents``.
77
+ """
78
+ # TODO: Validate kwargs and fast-fail?
79
+
80
+ hub_content_summaries = self . _sagemaker_session . list_hub_contents (
81
+ hub_name = self .hub_name , hub_content_type = HubContentType . MODEL , ** kwargs
72
82
)
83
+ # TODO: Handle pagination
84
+ return hub_content_summaries
73
85
74
- return hub_content
86
+ def describe_model (self , model_name : str , model_version : str = "*" ) -> HubContentDescription :
87
+ """Returns descriptive information about the Hub Model"""
75
88
76
- def describe (self ) -> Dict [str , Any ]:
77
- """Returns descriptive information about the Hub"""
89
+ hub_content_description : Dict [str , Any ] = self ._sagemaker_session .describe_hub_content (
90
+ hub_name = self .hub_name ,
91
+ hub_content_name = model_name ,
92
+ hub_content_version = model_version ,
93
+ hub_content_type = HubContentType .MODEL ,
94
+ )
95
+
96
+ return HubContentDescription (hub_content_description )
78
97
79
- hub_info = hubutils .describe_hub (hub_name = self .name , sagemaker_session = self ._session )
98
+ def delete_model (self , model_name : str , model_version : str = "*" ) -> None :
99
+ """Deletes a model from this CuratedHub."""
100
+ return self ._sagemaker_session .delete_hub_content (
101
+ hub_content_name = model_name ,
102
+ hub_content_version = model_version ,
103
+ hub_content_type = HubContentType .MODEL ,
104
+ hub_name = self .hub_name ,
105
+ )
80
106
81
- return hub_info
107
+ def delete (self ) -> None :
108
+ """Deletes this Curated Hub"""
109
+ return self ._sagemaker_session .delete_hub (self .hub_name )
0 commit comments