Skip to content

Commit 3eb2dc4

Browse files
Captainialiujiaorr
authored andcommitted
feat: support JumpStart proprietary models (aws#4467)
* feat: add proprietary manifest/specs parsing add unittests for test_cache small refactoring address comments and more unittests fix linting and fix more tests fix: pylint feat: JumpStartModel class for prop models * remove unused imports and fix docstyle * fix: remove unused args * fix: remove unused args * fix: more unused vars * fix: slow tests * fix: unittests * added more tests to cover some lines * remove estimator warn check * chore: address comments re performance * fix: address comments * complete list experience and other fixes * fix: pylint * add doc utils and fix pylint * fix: docstyle * fix: doc * fix: default payloads * fix: doc and tags and enums * fix: jumpstart doc * rename to open_weights and fix filtering * update filter name * doc update * fix: black * rename to proprietary model and fix unittests * address comments * fix: docstyle and flake8 * address more comments and fix doc * put back doc utils for future refactoring * add prop model title in doc * doc update --------- Co-authored-by: liujiaor <[email protected]>
1 parent 4411f4c commit 3eb2dc4

File tree

6 files changed

+8
-4
lines changed

6 files changed

+8
-4
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,6 @@ def _retrieval_function(
411411
"""
412412

413413
data_type, id_info = key.data_type, key.id_info
414-
415414
if data_type in {
416415
JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
417416
JumpStartS3FileType.PROPRIETARY_MANIFEST,
@@ -425,7 +424,6 @@ def _retrieval_function(
425424
formatted_content=utils.get_formatted_manifest(formatted_body),
426425
md5_hash=etag,
427426
)
428-
429427
if data_type in {
430428
JumpStartS3FileType.OPEN_WEIGHT_SPECS,
431429
JumpStartS3FileType.PROPRIETARY_SPECS,

src/sagemaker/jumpstart/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2368,6 +2368,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
23682368
"model_version",
23692369
"model_type",
23702370
"hub_arn",
2371+
"model_type",
23712372
"region",
23722373
"tolerate_deprecated_model",
23732374
"tolerate_vulnerable_model",

tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode
126126
model_type=JumpStartModelType.OPEN_WEIGHTS,
127127
hub_arn=None,
128128
s3_client=mock_client,
129+
model_type=JumpStartModelType.OPEN_WEIGHTS,
129130
)
130131

131132
patched_get_model_specs.reset_mock()

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
HubContentType,
3232
)
3333
from sagemaker.jumpstart.enums import JumpStartModelType
34-
3534
from sagemaker.jumpstart.utils import get_formatted_manifest
3635
from tests.unit.sagemaker.jumpstart.constants import (
3736
PROTOTYPICAL_MODEL_SPECS_DICT,

tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,12 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode
110110
}
111111

112112

113+
@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type")
113114
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
114-
def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs):
115+
def test_jumpstart_no_supported_resource_requirements(
116+
patched_get_model_specs, patched_validate_model_id_and_get_type
117+
):
118+
115119
patched_get_model_specs.side_effect = get_special_model_spec
116120
region = "us-west-2"
117121
mock_client = boto3.client("s3")

tests/unit/sagemaker/script_uris/jumpstart/test_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def test_jumpstart_common_script_uri(
5454
s3_client=mock_client,
5555
model_type=JumpStartModelType.OPEN_WEIGHTS,
5656
hub_arn=None,
57+
model_type=JumpStartModelType.OPEN_WEIGHTS,
5758
)
5859
patched_verify_model_region_and_return_specs.assert_called_once()
5960

0 commit comments

Comments
 (0)