Skip to content

Commit 1c0b8b3

Browse files
jinyoung-lim and bencrabtree
authored and committed
feature: JumpStart CuratedHub class creation and function definitions (aws#4448)
1 parent a434837 commit 1c0b8b3

File tree

3 files changed

+177
-1
lines changed

3 files changed

+177
-1
lines changed

src/sagemaker/jumpstart/types.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,149 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
11561156
)
11571157
self.model_subscription_link = json_obj.get("model_subscription_link")
11581158

1159+
def from_describe_hub_content_response(self, response: DescribeHubContentResponse) -> None:
    """Sets fields in object based on values in a DescribeHubContent response.

    The model id and version are read from the response itself; every other
    field is read from the response's ``hub_content_document``. For each S3
    URI in the document, ``parse_s3_url`` is called and only the second
    element of its result (the key) is stored on the object. Training-only
    fields are populated only when ``training_supported`` is truthy.

    Args:
        response (DescribeHubContentResponse): parsed response returned
            from SageMaker:DescribeHubContent.
    """
    self.model_id: str = response.hub_content_name
    self.version: str = response.hub_content_version
    hub_content_document: HubModelDocument = response.hub_content_document
    self.url: str = hub_content_document.url
    self.min_sdk_version: str = hub_content_document.min_sdk_version
    self.training_supported: bool = hub_content_document.training_supported
    # NOTE(review): this is the only field read by subscript with a PascalCase
    # key; every other field uses snake_case attribute access. Confirm that
    # HubModelDocument supports item access, or switch to the attribute form.
    self.incremental_training_supported: bool = bool(
        hub_content_document["IncrementalTrainingSupported"]
    )
    self.hosting_ecr_uri: Optional[str] = hub_content_document.hosting_ecr_uri
    self._non_serializable_slots.append("hosting_ecr_specs")

    # Only the key half of each parsed S3 URL is kept; the bucket is discarded.
    _, hosting_artifact_key = parse_s3_url(hub_content_document.hosting_artifact_uri)
    self.hosting_artifact_key: str = hosting_artifact_key
    _, hosting_script_key = parse_s3_url(hub_content_document.hosting_script_uri)
    self.hosting_script_key: str = hosting_script_key
    self.inference_environment_variables = hub_content_document.inference_environment_variables
    self.inference_vulnerable: bool = False
    self.inference_dependencies: List[str] = hub_content_document.inference_dependencies
    self.inference_vulnerabilities: List[str] = []
    self.training_vulnerable: bool = False
    self.training_dependencies: List[str] = hub_content_document.training_dependencies
    self.training_vulnerabilities: List[str] = []
    self.deprecated: bool = False
    self.deprecated_message: Optional[str] = None
    self.deprecate_warn_message: Optional[str] = None
    self.usage_info_message: Optional[str] = None
    self.default_inference_instance_type: Optional[
        str
    ] = hub_content_document.default_inference_instance_type
    self.default_training_instance_type: Optional[
        str
    ] = hub_content_document.default_training_instance_type
    self.supported_inference_instance_types: Optional[
        List[str]
    ] = hub_content_document.supported_inference_instance_types
    self.supported_training_instance_types: Optional[
        List[str]
    ] = hub_content_document.supported_training_instance_types
    self.dynamic_container_deployment_supported: Optional[
        bool
    ] = hub_content_document.dynamic_container_deployment_supported
    self.hosting_resource_requirements: Optional[
        Dict[str, int]
    ] = hub_content_document.hosting_resource_requirements
    self.metrics: Optional[List[Dict[str, str]]] = hub_content_document.training_metrics

    # Optional URIs: the key stays None unless the document provides the URI.
    self.training_prepacked_script_key: Optional[str] = None
    if hub_content_document.training_prepacked_script_uri is not None:
        _, training_prepacked_script_key = parse_s3_url(
            hub_content_document.training_prepacked_script_uri
        )
        self.training_prepacked_script_key = training_prepacked_script_key

    self.hosting_prepacked_artifact_key: Optional[str] = None
    if hub_content_document.hosting_prepacked_artifact_uri is not None:
        _, hosting_prepacked_artifact_key = parse_s3_url(
            hub_content_document.hosting_prepacked_artifact_uri
        )
        self.hosting_prepacked_artifact_key = hosting_prepacked_artifact_key

    self.fit_kwargs = get_model_spec_kwargs_from_hub_content_document(
        ModelSpecKwargType.FIT, hub_content_document
    )
    self.model_kwargs = get_model_spec_kwargs_from_hub_content_document(
        ModelSpecKwargType.MODEL, hub_content_document
    )
    self.deploy_kwargs = get_model_spec_kwargs_from_hub_content_document(
        ModelSpecKwargType.DEPLOY, hub_content_document
    )
    self.estimator_kwargs = get_model_spec_kwargs_from_hub_content_document(
        ModelSpecKwargType.ESTIMATOR, hub_content_document
    )

    self.predictor_specs: Optional[
        JumpStartPredictorSpecs
    ] = hub_content_document.sage_maker_sdk_predictor_specifications
    self.default_payloads: Optional[
        Dict[str, JumpStartSerializablePayload]
    ] = hub_content_document.default_payloads
    self.gated_bucket = hub_content_document.gated_bucket
    self.inference_volume_size: Optional[int] = hub_content_document.inference_volume_size
    self.inference_enable_network_isolation: bool = (
        hub_content_document.inference_enable_network_isolation
    )
    self.resource_name_base: Optional[str] = hub_content_document.resource_name_base

    self.hosting_eula_key: Optional[str] = None
    if hub_content_document.hosting_eula_uri is not None:
        _, hosting_eula_key = parse_s3_url(hub_content_document.hosting_eula_uri)
        self.hosting_eula_key = hosting_eula_key

    self.hosting_model_package_arns: Optional[Dict] = None  # TODO: Missing from schema?
    self.hosting_use_script_uri: bool = hub_content_document.hosting_use_script_uri

    self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
        JumpStartInstanceTypeVariants(hub_content_document.hosting_instance_type_variants)
        if hub_content_document.hosting_instance_type_variants
        else None
    )

    if self.training_supported:
        self.training_ecr_uri: Optional[str] = hub_content_document.training_ecr_uri
        self._non_serializable_slots.append("training_ecr_specs")
        _, training_artifact_key = parse_s3_url(hub_content_document.training_artifact_uri)
        self.training_artifact_key: str = training_artifact_key
        _, training_script_key = parse_s3_url(hub_content_document.training_script_uri)
        self.training_script_key: str = training_script_key

        self.hyperparameters: List[
            JumpStartHyperparameter
        ] = hub_content_document.hyperparameters
        self.training_volume_size: Optional[int] = hub_content_document.training_volume_size
        self.training_enable_network_isolation: bool = (
            hub_content_document.training_enable_network_isolation
        )
        self.training_model_package_artifact_uris: Optional[
            Dict
        ] = hub_content_document.training_model_package_artifact_uri
        # The conditional wraps the constructor call (as for the hosting
        # variants above) so that missing variants yield None rather than
        # JumpStartInstanceTypeVariants(None).
        self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
            JumpStartInstanceTypeVariants(hub_content_document.training_instance_type_variants)
            if hub_content_document.training_instance_type_variants
            else None
        )
1301+
11591302
def supports_prepacked_inference(self) -> bool:
    """Returns True if ``hosting_prepacked_artifact_key`` is set to a non-None value.

    A missing attribute is treated the same as ``None``.
    """
    prepacked_key = getattr(self, "hosting_prepacked_artifact_key", None)
    return prepacked_key is not None

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,38 @@ def test_generate_hub_arn_for_init_kwargs():
147147
utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn
148148
)
149149

150+
assert (
151+
utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session)
152+
== hub_arn
153+
)
154+
155+
156+
def test_generate_default_hub_bucket_name():
    """Checks the default hub bucket name produced for a mocked us-east-1 session."""
    session = Mock()
    session.account_id.return_value = "123456789123"
    session.boto_region_name = "us-east-1"

    expected_bucket_name = "sagemaker-hubs-us-east-1-123456789123"
    actual_bucket_name = utils.generate_default_hub_bucket_name(sagemaker_session=session)

    assert actual_bucket_name == expected_bucket_name
165+
166+
167+
def test_create_hub_bucket_if_it_does_not_exist():
    """Checks the bucket name returned, and that bucket creation is requested, for a mocked session."""
    mock_sagemaker_session = Mock()
    mock_sagemaker_session.account_id.return_value = "123456789123"
    mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
        "Account": "123456789123"
    }
    # A creation_date of None is presumably how the helper detects a missing
    # bucket; verify against the helper under test.
    mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
    mock_sagemaker_session.boto_region_name = "us-east-1"
    bucket_name = "sagemaker-hubs-us-east-1-123456789123"
    created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
        sagemaker_session=mock_sagemaker_session
    )

    # The original line read `create_bucketassert_called_once()` (missing dot),
    # which merely called an auto-created child mock and verified nothing.
    # NOTE(review): confirm the helper reaches create_bucket through this mocked
    # s3 resource; if it delegates to another session method, assert on that instead.
    mock_sagemaker_session.boto_session.resource("s3").create_bucket.assert_called_once()
    assert created_hub_bucket_name == bucket_name
150182
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn
151183

152184

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
2929
JumpStartModelsCache,
3030
)
31+
from sagemaker.session_settings import SessionSettings
3132
from sagemaker.jumpstart.constants import (
3233
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
3334
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
@@ -1133,7 +1134,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
11331134

11341135
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root")
11351136
assert mocked_is_dir.call_count == 2
1136-
mocked_open.assert_not_called()
1137+
assert mocked_open.call_count == 2
11371138
mocked_get_json_file_and_etag_from_s3.assert_has_calls(
11381139
calls=[
11391140
call("models_manifest.json"),

0 commit comments

Comments
 (0)