Skip to content

Commit ce43606

Browse files
authored
Feature: register proprietary models from jumpstart (#4753)
* Feature: register proprietary models from jumpstart Feature: register proprietary models from jumpstart * fix: register jumpstart models on model registry
1 parent f799f15 commit ce43606

File tree

6 files changed

+99
-3
lines changed

6 files changed

+99
-3
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,7 @@ def get_register_kwargs(
693693
model_id: str,
694694
model_version: Optional[str] = None,
695695
hub_arn: Optional[str] = None,
696+
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
696697
region: Optional[str] = None,
697698
tolerate_deprecated_model: Optional[bool] = None,
698699
tolerate_vulnerable_model: Optional[bool] = None,
@@ -720,13 +721,15 @@ def get_register_kwargs(
720721
skip_model_validation: Optional[str] = None,
721722
source_uri: Optional[str] = None,
722723
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
724+
accept_eula: Optional[bool] = None,
723725
) -> JumpStartModelRegisterKwargs:
724726
"""Returns kwargs required to call `register` on `sagemaker.estimator.Model` object."""
725727

726728
register_kwargs = JumpStartModelRegisterKwargs(
727729
model_id=model_id,
728730
model_version=model_version,
729731
hub_arn=hub_arn,
732+
model_type=model_type,
730733
region=region,
731734
tolerate_deprecated_model=tolerate_deprecated_model,
732735
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -754,12 +757,14 @@ def get_register_kwargs(
754757
skip_model_validation=skip_model_validation,
755758
source_uri=source_uri,
756759
model_card=model_card,
760+
accept_eula=accept_eula,
757761
)
758762

759763
model_specs = verify_model_region_and_return_specs(
760764
model_id=model_id,
761765
version=model_version,
762766
hub_arn=hub_arn,
767+
model_type=model_type,
763768
region=region,
764769
scope=JumpStartScriptScope.INFERENCE,
765770
sagemaker_session=sagemaker_session,

src/sagemaker/jumpstart/model.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,7 @@ def register(
760760
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
761761
source_uri: Optional[Union[str, PipelineVariable]] = None,
762762
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
763+
accept_eula: Optional[bool] = None,
763764
):
764765
"""Creates a model package for creating SageMaker models or listing on Marketplace.
765766
@@ -809,15 +810,25 @@ def register(
809810
(default: None).
810811
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
811812
quantitative information about a model (default: None).
812-
813+
accept_eula (bool): For models that require a Model Access Config, specify True or
814+
False to indicate whether model terms of use have been accepted.
815+
The `accept_eula` value must be explicitly defined as `True` in order to
816+
accept the end-user license agreement (EULA) that some
817+
models require. (Default: None).
813818
Returns:
814819
A `sagemaker.model.ModelPackage` instance.
815820
"""
816821

822+
if model_package_group_name is None:
823+
model_package_group_name = self.model_id
824+
if self.model_type is JumpStartModelType.PROPRIETARY:
825+
source_uri = self.model_package_arn
826+
817827
register_kwargs = get_register_kwargs(
818828
model_id=self.model_id,
819829
model_version=self.model_version,
820830
hub_arn=self.hub_arn,
831+
model_type=self.model_type,
821832
region=self.region,
822833
tolerate_deprecated_model=self.tolerate_deprecated_model,
823834
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
@@ -845,6 +856,7 @@ def register(
845856
skip_model_validation=skip_model_validation,
846857
source_uri=source_uri,
847858
model_card=model_card,
859+
accept_eula=accept_eula,
848860
)
849861

850862
model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())

src/sagemaker/jumpstart/types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2372,6 +2372,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
23722372
"tolerate_deprecated_model",
23732373
"region",
23742374
"model_id",
2375+
"model_type",
23752376
"model_version",
23762377
"hub_arn",
23772378
"sagemaker_session",
@@ -2398,6 +2399,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
23982399
"skip_model_validation",
23992400
"source_uri",
24002401
"model_card",
2402+
"accept_eula",
24012403
]
24022404

24032405
SERIALIZATION_EXCLUSION_SET = {
@@ -2416,6 +2418,7 @@ def __init__(
24162418
model_version: Optional[str] = None,
24172419
hub_arn: Optional[str] = None,
24182420
region: Optional[str] = None,
2421+
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
24192422
tolerate_deprecated_model: Optional[bool] = None,
24202423
tolerate_vulnerable_model: Optional[bool] = None,
24212424
sagemaker_session: Optional[Any] = None,
@@ -2442,12 +2445,14 @@ def __init__(
24422445
skip_model_validation: Optional[str] = None,
24432446
source_uri: Optional[str] = None,
24442447
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
2448+
accept_eula: Optional[bool] = None,
24452449
) -> None:
24462450
"""Instantiates JumpStartModelRegisterKwargs object."""
24472451

24482452
self.model_id = model_id
24492453
self.model_version = model_version
24502454
self.hub_arn = hub_arn
2455+
self.model_type = model_type
24512456
self.region = region
24522457
self.image_uri = image_uri
24532458
self.sagemaker_session = sagemaker_session
@@ -2476,3 +2481,4 @@ def __init__(
24762481
self.skip_model_validation = skip_model_validation
24772482
self.source_uri = source_uri
24782483
self.model_card = model_card
2484+
self.accept_eula = accept_eula

src/sagemaker/model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH,
4545
load_sagemaker_config,
4646
)
47+
from sagemaker.jumpstart.enums import JumpStartModelType
4748
from sagemaker.model_card import (
4849
ModelCard,
4950
ModelPackageModelCard,
@@ -449,6 +450,8 @@ def register(
449450
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
450451
source_uri: Optional[Union[str, PipelineVariable]] = None,
451452
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
453+
accept_eula: Optional[bool] = None,
454+
model_type: Optional[JumpStartModelType] = None,
452455
):
453456
"""Creates a model package for creating SageMaker models or listing on Marketplace.
454457
@@ -523,9 +526,8 @@ def register(
523526
model_package_group_name = utils.base_name_from_image(
524527
self.image_uri, default_base_name=ModelPackage.__name__
525528
)
526-
527529
if model_package_group_name is not None:
528-
container_def = self.prepare_container_def()
530+
container_def = self.prepare_container_def(accept_eula=accept_eula)
529531
container_def = update_container_with_inference_params(
530532
framework=framework,
531533
framework_version=framework_version,

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,8 @@ def test_jumpstart_model_register(setup):
265265

266266
response = predictor.predict("hello world!")
267267

268+
predictor.delete_predictor()
269+
268270
assert response is not None
269271

270272

@@ -291,3 +293,59 @@ def test_proprietary_jumpstart_model(setup):
291293
response = predictor.predict(payload)
292294

293295
assert response is not None
296+
297+
298+
@pytest.mark.skipif(
299+
True,
300+
reason="Only enable if test account is subscribed to the proprietary model",
301+
)
302+
def test_register_proprietary_jumpstart_model(setup):
303+
304+
model_id = "ai21-jurassic-2-light"
305+
306+
model = JumpStartModel(
307+
model_id=model_id,
308+
model_version="2.0.004",
309+
role=get_sm_session().get_caller_identity_arn(),
310+
sagemaker_session=get_sm_session(),
311+
)
312+
model_package = model.register()
313+
314+
predictor = model_package.deploy(
315+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}]
316+
)
317+
payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1}
318+
319+
response = predictor.predict(payload)
320+
321+
predictor.delete_predictor()
322+
323+
assert response is not None
324+
325+
326+
@pytest.mark.skipif(
327+
True,
328+
reason="Only enable if test account is subscribed to the proprietary model",
329+
)
330+
def test_register_gated_jumpstart_model(setup):
331+
332+
model_id = "meta-textgenerationneuron-llama-2-7b"
333+
model = JumpStartModel(
334+
model_id=model_id,
335+
model_version="1.1.0",
336+
role=get_sm_session().get_caller_identity_arn(),
337+
sagemaker_session=get_sm_session(),
338+
)
339+
model_package = model.register(accept_eula=True)
340+
341+
predictor = model_package.deploy(
342+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
343+
accept_eula=True,
344+
)
345+
payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1}
346+
347+
response = predictor.predict(payload)
348+
349+
predictor.delete_predictor()
350+
351+
assert response is not None

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,9 +473,11 @@ def test_eula_gated_conditional_s3_prefix_metadata_model(
473473
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
474474
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
475475
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
476+
@mock.patch("sagemaker.jumpstart.model.Model.register")
476477
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
477478
def test_proprietary_model_endpoint(
478479
self,
480+
mock_model_register: mock.Mock,
479481
mock_model_deploy: mock.Mock,
480482
mock_model_init: mock.Mock,
481483
mock_get_model_specs: mock.Mock,
@@ -507,8 +509,17 @@ def test_proprietary_model_endpoint(
507509
enable_network_isolation=False,
508510
)
509511

512+
model.register()
510513
model.deploy()
511514

515+
mock_model_register.assert_called_once_with(
516+
model_type=JumpStartModelType.PROPRIETARY,
517+
content_types=["application/json"],
518+
response_types=["application/json"],
519+
model_package_group_name=model_id,
520+
source_uri=model.model_package_arn,
521+
)
522+
512523
mock_model_deploy.assert_called_once_with(
513524
initial_instance_count=1,
514525
instance_type="ml.p4de.24xlarge",
@@ -1408,8 +1419,10 @@ def test_model_registry_accept_and_response_types(
14081419
model.register()
14091420

14101421
mock_model_register.assert_called_once_with(
1422+
model_type=JumpStartModelType.OPEN_WEIGHTS,
14111423
content_types=["application/x-text"],
14121424
response_types=["application/json;verbose", "application/json"],
1425+
model_package_group_name=model.model_id,
14131426
)
14141427

14151428
@mock.patch("sagemaker.jumpstart.model.get_default_predictor")

0 commit comments

Comments
 (0)