Skip to content

Commit a10f00c

Browse files
committed
Feature: register proprietary models from jumpstart
Feature: register proprietary models from jumpstart
1 parent 3f9acac commit a10f00c

File tree

6 files changed

+113
-18
lines changed

6 files changed

+113
-18
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,7 @@ def get_register_kwargs(
693693
model_id: str,
694694
model_version: Optional[str] = None,
695695
hub_arn: Optional[str] = None,
696+
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
696697
region: Optional[str] = None,
697698
tolerate_deprecated_model: Optional[bool] = None,
698699
tolerate_vulnerable_model: Optional[bool] = None,
@@ -720,13 +721,15 @@ def get_register_kwargs(
720721
skip_model_validation: Optional[str] = None,
721722
source_uri: Optional[str] = None,
722723
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
724+
accept_eula: Optional[bool] = None,
723725
) -> JumpStartModelRegisterKwargs:
724726
"""Returns kwargs required to call `register` on `sagemaker.estimator.Model` object."""
725727

726728
register_kwargs = JumpStartModelRegisterKwargs(
727729
model_id=model_id,
728730
model_version=model_version,
729731
hub_arn=hub_arn,
732+
model_type=model_type,
730733
region=region,
731734
tolerate_deprecated_model=tolerate_deprecated_model,
732735
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -754,12 +757,14 @@ def get_register_kwargs(
754757
skip_model_validation=skip_model_validation,
755758
source_uri=source_uri,
756759
model_card=model_card,
760+
accept_eula=accept_eula,
757761
)
758762

759763
model_specs = verify_model_region_and_return_specs(
760764
model_id=model_id,
761765
version=model_version,
762766
hub_arn=hub_arn,
767+
model_type=model_type,
763768
region=region,
764769
scope=JumpStartScriptScope.INFERENCE,
765770
sagemaker_session=sagemaker_session,

src/sagemaker/jumpstart/model.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,8 @@ def register(
760760
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
761761
source_uri: Optional[Union[str, PipelineVariable]] = None,
762762
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
763+
accept_eula: Optional[bool] = None,
764+
763765
):
764766
"""Creates a model package for creating SageMaker models or listing on Marketplace.
765767
@@ -809,15 +811,24 @@ def register(
809811
(default: None).
810812
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
811813
quantitative information about a model (default: None).
812-
814+
accept_eula (bool): For models that require a Model Access Config, specify True or
815+
False to indicate whether model terms of use have been accepted.
816+
The `accept_eula` value must be explicitly defined as `True` in order to
817+
accept the end-user license agreement (EULA) that some
818+
models require. (Default: None).
813819
Returns:
814820
A `sagemaker.model.ModelPackage` instance.
815821
"""
816822

823+
if model_package_group_name is None and self.model_type is JumpStartModelType.PROPRIETARY:
824+
model_package_group_name = self.model_id
825+
source_uri = self.model_package_arn
826+
817827
register_kwargs = get_register_kwargs(
818828
model_id=self.model_id,
819829
model_version=self.model_version,
820830
hub_arn=self.hub_arn,
831+
model_type=self.model_type,
821832
region=self.region,
822833
tolerate_deprecated_model=self.tolerate_deprecated_model,
823834
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
@@ -845,6 +856,7 @@ def register(
845856
skip_model_validation=skip_model_validation,
846857
source_uri=source_uri,
847858
model_card=model_card,
859+
accept_eula=accept_eula,
848860
)
849861

850862
model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())

src/sagemaker/jumpstart/types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2372,6 +2372,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
23722372
"tolerate_deprecated_model",
23732373
"region",
23742374
"model_id",
2375+
"model_type",
23752376
"model_version",
23762377
"hub_arn",
23772378
"sagemaker_session",
@@ -2398,6 +2399,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
23982399
"skip_model_validation",
23992400
"source_uri",
24002401
"model_card",
2402+
"accept_eula",
24012403
]
24022404

24032405
SERIALIZATION_EXCLUSION_SET = {
@@ -2408,6 +2410,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
24082410
"model_version",
24092411
"hub_arn",
24102412
"sagemaker_session",
2413+
"model_type",
24112414
}
24122415

24132416
def __init__(
@@ -2416,6 +2419,7 @@ def __init__(
24162419
model_version: Optional[str] = None,
24172420
hub_arn: Optional[str] = None,
24182421
region: Optional[str] = None,
2422+
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
24192423
tolerate_deprecated_model: Optional[bool] = None,
24202424
tolerate_vulnerable_model: Optional[bool] = None,
24212425
sagemaker_session: Optional[Any] = None,
@@ -2442,12 +2446,14 @@ def __init__(
24422446
skip_model_validation: Optional[str] = None,
24432447
source_uri: Optional[str] = None,
24442448
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
2449+
accept_eula: Optional[bool] = None,
24452450
) -> None:
24462451
"""Instantiates JumpStartModelRegisterKwargs object."""
24472452

24482453
self.model_id = model_id
24492454
self.model_version = model_version
24502455
self.hub_arn = hub_arn
2456+
self.model_type = model_type
24512457
self.region = region
24522458
self.image_uri = image_uri
24532459
self.sagemaker_session = sagemaker_session
@@ -2476,3 +2482,4 @@ def __init__(
24762482
self.skip_model_validation = skip_model_validation
24772483
self.source_uri = source_uri
24782484
self.model_card = model_card
2485+
self.accept_eula = accept_eula

src/sagemaker/model.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH,
4545
load_sagemaker_config,
4646
)
47+
from sagemaker.jumpstart.enums import JumpStartModelType
4748
from sagemaker.model_card import (
4849
ModelCard,
4950
ModelPackageModelCard,
@@ -448,6 +449,7 @@ def register(
448449
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
449450
source_uri: Optional[Union[str, PipelineVariable]] = None,
450451
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
452+
accept_eula: Optional[bool] = None,
451453
):
452454
"""Creates a model package for creating SageMaker models or listing on Marketplace.
453455
@@ -515,23 +517,22 @@ def register(
515517

516518
if image_uri is not None:
517519
self.image_uri = image_uri
518-
519-
if model_package_group_name is None and model_package_name is None:
520-
# If model package group and model package name is not set
521-
# then register to auto-generated model package group
522-
model_package_group_name = utils.base_name_from_image(
523-
self.image_uri, default_base_name=ModelPackage.__name__
524-
)
525-
526-
if model_package_group_name is not None:
527-
container_def = self.prepare_container_def()
528-
container_def = update_container_with_inference_params(
529-
framework=framework,
530-
framework_version=framework_version,
531-
nearest_model_name=nearest_model_name,
532-
data_input_configuration=data_input_configuration,
533-
container_def=container_def,
534-
)
520+
if self.model_type is not JumpStartModelType.PROPRIETARY:
521+
if model_package_group_name is None and model_package_name is None:
522+
# If model package group and model package name is not set
523+
# then register to auto-generated model package group
524+
model_package_group_name = utils.base_name_from_image(
525+
self.image_uri, default_base_name=ModelPackage.__name__
526+
)
527+
if model_package_group_name is not None:
528+
container_def = self.prepare_container_def(accept_eula=accept_eula)
529+
container_def = update_container_with_inference_params(
530+
framework=framework,
531+
framework_version=framework_version,
532+
nearest_model_name=nearest_model_name,
533+
data_input_configuration=data_input_configuration,
534+
container_def=container_def,
535+
)
535536
else:
536537
container_def = {
537538
"Image": self.image_uri,
@@ -546,6 +547,10 @@ def register(
546547
if self.model_data is not None:
547548
container_def["ModelDataUrl"] = self.model_data
548549

550+
if self.model_type is JumpStartModelType.PROPRIETARY:
551+
source_uri = self.model_package_arn
552+
model_package_group_name = self.model_id
553+
549554
model_pkg_args = sagemaker.get_model_package_args(
550555
self.content_types,
551556
self.response_types,

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,8 @@ def test_jumpstart_model_register(setup):
265265

266266
response = predictor.predict("hello world!")
267267

268+
predictor.delete_predictor()
269+
268270
assert response is not None
269271

270272

@@ -291,3 +293,59 @@ def test_proprietary_jumpstart_model(setup):
291293
response = predictor.predict(payload)
292294

293295
assert response is not None
296+
297+
@pytest.mark.skipif(
298+
True,
299+
reason="Only enable if test account is subscribed to the proprietary model",
300+
)
301+
def test_register_proprietary_jumpstart_model(setup):
302+
303+
model_id = "ai21-jurassic-2-light"
304+
305+
model = JumpStartModel(
306+
model_id=model_id,
307+
model_version="2.0.004",
308+
role=get_sm_session().get_caller_identity_arn(),
309+
sagemaker_session=get_sm_session(),
310+
)
311+
model_package = model.register()
312+
313+
314+
predictor = model_package.deploy(
315+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}]
316+
)
317+
payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1}
318+
319+
response = predictor.predict(payload)
320+
321+
predictor.delete_predictor()
322+
323+
assert response is not None
324+
325+
326+
@pytest.mark.skipif(
327+
True,
328+
reason="Only enable if test account is subscribed to the proprietary model",
329+
)
330+
def test_register_gated_jumpstart_model(setup):
331+
332+
model_id="meta-textgenerationneuron-llama-2-7b"
333+
model = JumpStartModel(
334+
model_id=model_id,
335+
model_version="1.1.0",
336+
role=get_sm_session().get_caller_identity_arn(),
337+
sagemaker_session=get_sm_session(),
338+
)
339+
model_package = model.register(accept_eula=True)
340+
341+
predictor = model_package.deploy(
342+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], accept_eula=True
343+
)
344+
payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1}
345+
346+
response = predictor.predict(payload)
347+
348+
predictor.delete_predictor()
349+
350+
assert response is not None
351+

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,9 +473,11 @@ def test_eula_gated_conditional_s3_prefix_metadata_model(
473473
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
474474
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
475475
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
476+
@mock.patch("sagemaker.jumpstart.model.Model.register")
476477
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
477478
def test_proprietary_model_endpoint(
478479
self,
480+
mock_model_register: mock.Mock,
479481
mock_model_deploy: mock.Mock,
480482
mock_model_init: mock.Mock,
481483
mock_get_model_specs: mock.Mock,
@@ -507,8 +509,14 @@ def test_proprietary_model_endpoint(
507509
enable_network_isolation=False,
508510
)
509511

512+
model.register()
510513
model.deploy()
511514

515+
mock_model_register.assert_called_once_with(
516+
content_types=["application/json"],
517+
response_types=["application/json"],
518+
)
519+
512520
mock_model_deploy.assert_called_once_with(
513521
initial_instance_count=1,
514522
instance_type="ml.p4de.24xlarge",

0 commit comments

Comments
 (0)