Skip to content

Commit e710f1d

Browse files
authored
feat: Marketplace model support in HubService (#4916)
* feat: Marketplace model support in HubService * fix: removing field * fix: Reverting name change for code coverage * fix: Adding more code coverage * fix: linting * fix: Fixing coverage tests * fix: Fixing integration tests * fix: Minor fixes
1 parent 4e1a37e commit e710f1d

File tree

11 files changed

+350
-46
lines changed

11 files changed

+350
-46
lines changed

src/sagemaker/jumpstart/hub/interfaces.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -451,12 +451,14 @@ def from_json(self, json_obj: str) -> None:
451451
class HubModelDocument(HubDataHolderType):
452452
"""Data class for model type HubContentDocument from session.describe_hub_content()."""
453453

454-
SCHEMA_VERSION = "2.2.0"
454+
SCHEMA_VERSION = "2.3.0"
455455

456456
__slots__ = [
457457
"url",
458458
"min_sdk_version",
459459
"training_supported",
460+
"model_types",
461+
"capabilities",
460462
"incremental_training_supported",
461463
"dynamic_container_deployment_supported",
462464
"hosting_ecr_uri",
@@ -469,6 +471,7 @@ class HubModelDocument(HubDataHolderType):
469471
"hosting_use_script_uri",
470472
"hosting_eula_uri",
471473
"hosting_model_package_arn",
474+
"model_subscription_link",
472475
"inference_configs",
473476
"inference_config_components",
474477
"inference_config_rankings",
@@ -549,18 +552,22 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
549552
Args:
550553
json_obj (Dict[str, Any]): Dictionary representation of hub model document.
551554
"""
552-
self.url: str = json_obj["Url"]
553-
self.min_sdk_version: str = json_obj["MinSdkVersion"]
554-
self.hosting_ecr_uri: Optional[str] = json_obj["HostingEcrUri"]
555-
self.hosting_artifact_uri = json_obj["HostingArtifactUri"]
556-
self.hosting_script_uri = json_obj["HostingScriptUri"]
557-
self.inference_dependencies: List[str] = json_obj["InferenceDependencies"]
555+
self.url: str = json_obj.get("Url")
556+
self.min_sdk_version: str = json_obj.get("MinSdkVersion")
557+
self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri")
558+
self.hosting_artifact_uri = json_obj.get("HostingArtifactUri")
559+
self.hosting_script_uri = json_obj.get("HostingScriptUri")
560+
self.inference_dependencies: List[str] = json_obj.get("InferenceDependencies")
558561
self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [
559562
JumpStartEnvironmentVariable(env_variable, is_hub_content=True)
560-
for env_variable in json_obj["InferenceEnvironmentVariables"]
563+
for env_variable in json_obj.get("InferenceEnvironmentVariables", [])
561564
]
562-
self.training_supported: bool = bool(json_obj["TrainingSupported"])
563-
self.incremental_training_supported: bool = bool(json_obj["IncrementalTrainingSupported"])
565+
self.model_types: Optional[List[str]] = json_obj.get("ModelTypes")
566+
self.capabilities: Optional[List[str]] = json_obj.get("Capabilities")
567+
self.training_supported: bool = bool(json_obj.get("TrainingSupported"))
568+
self.incremental_training_supported: bool = bool(
569+
json_obj.get("IncrementalTrainingSupported")
570+
)
564571
self.dynamic_container_deployment_supported: Optional[bool] = (
565572
bool(json_obj.get("DynamicContainerDeploymentSupported"))
566573
if json_obj.get("DynamicContainerDeploymentSupported")
@@ -586,6 +593,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
586593
self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri")
587594
self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn")
588595

596+
self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink")
597+
589598
self.inference_config_rankings = self._get_config_rankings(json_obj)
590599
self.inference_config_components = self._get_config_components(json_obj)
591600
self.inference_configs = self._get_configs(json_obj)

src/sagemaker/jumpstart/hub/parser_utils.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@
1919

2020

2121
def camel_to_snake(camel_case_string: str) -> str:
22-
"""Converts camelCaseString or UpperCamelCaseString to snake_case_string."""
23-
snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string)
24-
if "-" in snake_case_string:
25-
# remove any hyphen from the string for accurate conversion.
26-
snake_case_string = snake_case_string.replace("-", "")
27-
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower()
22+
"""Converts PascalCase to snake_case_string using a regex.
23+
24+
This regex cannot handle whitespace ("PascalString TwoWords")
25+
"""
26+
return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_case_string).lower()
2827

2928

3029
def snake_to_upper_camel(snake_case_string: str) -> str:

src/sagemaker/jumpstart/hub/parsers.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ def make_model_specs_from_describe_hub_content_response(
137137
hub_model_document: HubModelDocument = response.hub_content_document
138138
specs["url"] = hub_model_document.url
139139
specs["min_sdk_version"] = hub_model_document.min_sdk_version
140+
specs["model_types"] = hub_model_document.model_types
141+
specs["capabilities"] = hub_model_document.capabilities
140142
specs["training_supported"] = bool(hub_model_document.training_supported)
141143
specs["incremental_training_supported"] = bool(
142144
hub_model_document.incremental_training_supported
@@ -146,15 +148,19 @@ def make_model_specs_from_describe_hub_content_response(
146148
specs["inference_config_components"] = hub_model_document.inference_config_components
147149
specs["inference_config_rankings"] = hub_model_document.inference_config_rankings
148150

149-
hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
150-
hub_model_document.hosting_artifact_uri
151-
)
152-
specs["hosting_artifact_key"] = hosting_artifact_key
153-
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
154-
hosting_script_bucket, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
155-
hub_model_document.hosting_script_uri
156-
)
157-
specs["hosting_script_key"] = hosting_script_key
151+
if hub_model_document.hosting_artifact_uri:
152+
_, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
153+
hub_model_document.hosting_artifact_uri
154+
)
155+
specs["hosting_artifact_key"] = hosting_artifact_key
156+
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
157+
158+
if hub_model_document.hosting_script_uri:
159+
_, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
160+
hub_model_document.hosting_script_uri
161+
)
162+
specs["hosting_script_key"] = hosting_script_key
163+
158164
specs["inference_environment_variables"] = hub_model_document.inference_environment_variables
159165
specs["inference_vulnerable"] = False
160166
specs["inference_dependencies"] = hub_model_document.inference_dependencies
@@ -220,6 +226,8 @@ def make_model_specs_from_describe_hub_content_response(
220226
if hub_model_document.hosting_model_package_arn:
221227
specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn}
222228

229+
specs["model_subscription_link"] = hub_model_document.model_subscription_link
230+
223231
specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri
224232

225233
specs["hosting_instance_type_variants"] = hub_model_document.hosting_instance_type_variants

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""This module contains utilities related to SageMaker JumpStart Hub."""
1515
from __future__ import absolute_import
1616
import re
17-
from typing import Optional
17+
from typing import Optional, List, Any
1818
from sagemaker.jumpstart.hub.types import S3ObjectLocation
1919
from sagemaker.s3_utils import parse_s3_url
2020
from sagemaker.session import Session
@@ -23,6 +23,14 @@
2323
from sagemaker.jumpstart import constants
2424
from packaging.specifiers import SpecifierSet, InvalidSpecifier
2525

26+
PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"
27+
28+
29+
def _convert_str_to_optional(string: str) -> Optional[str]:
30+
if string == "None":
31+
string = None
32+
return string
33+
2634

2735
def get_info_from_hub_resource_arn(
2836
arn: str,
@@ -37,7 +45,7 @@ def get_info_from_hub_resource_arn(
3745
hub_name = match.group(4)
3846
hub_content_type = match.group(5)
3947
hub_content_name = match.group(6)
40-
hub_content_version = match.group(7)
48+
hub_content_version = _convert_str_to_optional(match.group(7))
4149

4250
return HubArnExtractedInfo(
4351
partition=partition,
@@ -194,10 +202,14 @@ def get_hub_model_version(
194202
hub_model_version: Optional[str] = None,
195203
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
196204
) -> str:
197-
"""Returns available Jumpstart hub model version
205+
"""Returns available Jumpstart hub model version.
206+
207+
It will attempt both a semantic HubContent version search and Marketplace version search.
208+
If the Marketplace version is also semantic, this function will default to HubContent version.
198209
199210
Raises:
200211
ClientError: If the specified model is not found in the hub.
212+
KeyError: If the specified model version is not found.
201213
"""
202214

203215
try:
@@ -207,6 +219,22 @@ def get_hub_model_version(
207219
except Exception as ex:
208220
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")
209221

222+
try:
223+
return _get_hub_model_version_for_open_weight_version(
224+
hub_content_summaries, hub_model_version
225+
)
226+
except KeyError:
227+
marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version(
228+
hub_content_summaries, hub_model_version
229+
)
230+
if marketplace_hub_content_version:
231+
return marketplace_hub_content_version
232+
raise
233+
234+
235+
def _get_hub_model_version_for_open_weight_version(
236+
hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
237+
) -> str:
210238
available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries]
211239

212240
if hub_model_version == "*" or hub_model_version is None:
@@ -222,3 +250,37 @@ def get_hub_model_version(
222250
hub_model_version = str(max(available_versions_filtered))
223251

224252
return hub_model_version
253+
254+
255+
def _get_hub_model_version_for_marketplace_version(
256+
hub_content_summaries: List[Any], marketplace_version: str
257+
) -> Optional[str]:
258+
"""Returns the HubContent version associated with the Marketplace version.
259+
260+
This function will check within the HubContentSearchKeywords for the proprietary version.
261+
"""
262+
for model in hub_content_summaries:
263+
model_search_keywords = model.get("HubContentSearchKeywords", [])
264+
if _hub_search_keywords_contains_marketplace_version(
265+
model_search_keywords, marketplace_version
266+
):
267+
return model.get("HubContentVersion")
268+
269+
return None
270+
271+
272+
def _hub_search_keywords_contains_marketplace_version(
273+
model_search_keywords: List[str], marketplace_version: str
274+
) -> bool:
275+
proprietary_version_keyword = next(
276+
filter(lambda s: s.startswith(PROPRIETARY_VERSION_KEYWORD), model_search_keywords), None
277+
)
278+
279+
if not proprietary_version_keyword:
280+
return False
281+
282+
proprietary_version = proprietary_version_keyword.lstrip(PROPRIETARY_VERSION_KEYWORD)
283+
if proprietary_version == marketplace_version:
284+
return True
285+
286+
return False

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,6 +1200,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
12001200
"url",
12011201
"version",
12021202
"min_sdk_version",
1203+
"model_types",
1204+
"capabilities",
12031205
"incremental_training_supported",
12041206
"hosting_ecr_specs",
12051207
"hosting_ecr_uri",
@@ -1287,6 +1289,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12871289
json_obj.get("incremental_training_supported", False)
12881290
)
12891291
if self._is_hub_content:
1292+
self.capabilities: Optional[List[str]] = json_obj.get("capabilities")
1293+
self.model_types: Optional[List[str]] = json_obj.get("model_types")
12901294
self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri")
12911295
self._non_serializable_slots.append("hosting_ecr_specs")
12921296
else:

src/sagemaker/jumpstart/utils.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,16 @@ def validate_model_id_and_get_type(
856856
if not isinstance(model_id, str):
857857
return None
858858
if hub_arn:
859-
return None
859+
model_types = _validate_hub_service_model_id_and_get_type(
860+
model_id=model_id,
861+
hub_arn=hub_arn,
862+
region=region,
863+
model_version=model_version,
864+
sagemaker_session=sagemaker_session,
865+
)
866+
return (
867+
model_types[0] if model_types else None
868+
) # Currently this function only supports one model type
860869

861870
s3_client = sagemaker_session.s3_client if sagemaker_session else None
862871
region = region or constants.JUMPSTART_DEFAULT_REGION_NAME
@@ -881,6 +890,37 @@ def validate_model_id_and_get_type(
881890
return None
882891

883892

893+
def _validate_hub_service_model_id_and_get_type(
894+
model_id: Optional[str],
895+
hub_arn: str,
896+
region: Optional[str] = None,
897+
model_version: Optional[str] = None,
898+
sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
899+
) -> List[enums.JumpStartModelType]:
900+
"""Returns a list of JumpStartModelType based off the HubContent.
901+
902+
Only returns valid JumpStartModelType. Returns an empty array if none are found.
903+
"""
904+
hub_model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
905+
region=region,
906+
model_id=model_id,
907+
version=model_version,
908+
hub_arn=hub_arn,
909+
sagemaker_session=sagemaker_session,
910+
)
911+
912+
hub_content_model_types = []
913+
model_types_field: Optional[List[str]] = getattr(hub_model_specs, "model_types", [])
914+
model_types = model_types_field if model_types_field else []
915+
for model_type in model_types:
916+
try:
917+
hub_content_model_types.append(enums.JumpStartModelType[model_type])
918+
except ValueError:
919+
continue
920+
921+
return hub_content_model_types
922+
923+
884924
def _extract_value_from_list_of_tags(
885925
tag_keys: List[str],
886926
list_tags_result: List[str],

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353
"ap-southeast-2",
5454
}
5555

56+
TEST_HUB_WITH_REFERENCE = "mock-hub-name"
57+
5658

5759
def test_non_prepacked_jumpstart_model(setup):
5860

tests/integ/sagemaker/jumpstart/utils.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,23 +53,18 @@ def get_sm_session() -> Session:
5353
return Session(boto_session=boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME))
5454

5555

56-
# def get_sm_session_with_override() -> Session:
57-
# # [TODO]: Remove service endpoint override before GA
58-
# # boto3.set_stream_logger(name='botocore', level=logging.DEBUG)
59-
# boto_session = boto3.Session(region_name="us-west-2")
60-
# sagemaker = boto3.client(
61-
# service_name="sagemaker-internal",
62-
# endpoint_url="https://sagemaker.beta.us-west-2.ml-platform.aws.a2z.com",
63-
# )
64-
# sagemaker_runtime = boto3.client(
65-
# service_name="runtime.maeve",
66-
# endpoint_url="https://maeveruntime.beta.us-west-2.ml-platform.aws.a2z.com",
67-
# )
68-
# return Session(
69-
# boto_session=boto_session,
70-
# sagemaker_client=sagemaker,
71-
# sagemaker_runtime_client=sagemaker_runtime,
72-
# )
56+
def get_sm_session_with_override() -> Session:
57+
# [TODO]: Remove service endpoint override before GA
58+
# boto3.set_stream_logger(name='botocore', level=logging.DEBUG)
59+
boto_session = boto3.Session(region_name="us-west-2")
60+
sagemaker = boto3.client(
61+
service_name="sagemaker",
62+
endpoint_url="https://sagemaker.gamma.us-west-2.ml-platform.aws.a2z.com",
63+
)
64+
return Session(
65+
boto_session=boto_session,
66+
sagemaker_client=sagemaker,
67+
)
7368

7469

7570
def get_training_dataset_for_model_and_version(model_id: str, version: str) -> dict:

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9178,6 +9178,7 @@
91789178
"TrainingArtifactS3DataType": "S3Prefix",
91799179
"TrainingArtifactCompressionType": "None",
91809180
"TrainingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501
9181+
"ModelTypes": ["OPEN_WEIGHTS", "PROPRIETARY"],
91819182
"Hyperparameters": [
91829183
{
91839184
"Name": "peft_type",

0 commit comments

Comments
 (0)