Skip to content

Feature: register proprietary models from jumpstart #4753

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ def get_register_kwargs(
model_id: str,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
region: Optional[str] = None,
tolerate_deprecated_model: Optional[bool] = None,
tolerate_vulnerable_model: Optional[bool] = None,
Expand Down Expand Up @@ -720,13 +721,15 @@ def get_register_kwargs(
skip_model_validation: Optional[str] = None,
source_uri: Optional[str] = None,
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
accept_eula: Optional[bool] = None,
) -> JumpStartModelRegisterKwargs:
"""Returns kwargs required to call `register` on `sagemaker.estimator.Model` object."""

register_kwargs = JumpStartModelRegisterKwargs(
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
model_type=model_type,
region=region,
tolerate_deprecated_model=tolerate_deprecated_model,
tolerate_vulnerable_model=tolerate_vulnerable_model,
Expand Down Expand Up @@ -754,12 +757,14 @@ def get_register_kwargs(
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
accept_eula=accept_eula,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
model_type=model_type,
region=region,
scope=JumpStartScriptScope.INFERENCE,
sagemaker_session=sagemaker_session,
Expand Down
14 changes: 13 additions & 1 deletion src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@ def register(
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
accept_eula: Optional[bool] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -809,15 +810,25 @@ def register(
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).

accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).
Returns:
A `sagemaker.model.ModelPackage` instance.
"""

if model_package_group_name is None:
model_package_group_name = self.model_id
if self.model_type is JumpStartModelType.PROPRIETARY:
source_uri = self.model_package_arn

register_kwargs = get_register_kwargs(
model_id=self.model_id,
model_version=self.model_version,
hub_arn=self.hub_arn,
model_type=self.model_type,
region=self.region,
tolerate_deprecated_model=self.tolerate_deprecated_model,
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
Expand Down Expand Up @@ -845,6 +856,7 @@ def register(
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
accept_eula=accept_eula,
)

model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())
Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2372,6 +2372,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
"tolerate_deprecated_model",
"region",
"model_id",
"model_type",
"model_version",
"hub_arn",
"sagemaker_session",
Expand All @@ -2398,6 +2399,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
"skip_model_validation",
"source_uri",
"model_card",
"accept_eula",
]

SERIALIZATION_EXCLUSION_SET = {
Expand All @@ -2416,6 +2418,7 @@ def __init__(
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
tolerate_deprecated_model: Optional[bool] = None,
tolerate_vulnerable_model: Optional[bool] = None,
sagemaker_session: Optional[Any] = None,
Expand All @@ -2442,12 +2445,14 @@ def __init__(
skip_model_validation: Optional[str] = None,
source_uri: Optional[str] = None,
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
accept_eula: Optional[bool] = None,
) -> None:
"""Instantiates JumpStartModelRegisterKwargs object."""

self.model_id = model_id
self.model_version = model_version
self.hub_arn = hub_arn
self.model_type = model_type
self.region = region
self.image_uri = image_uri
self.sagemaker_session = sagemaker_session
Expand Down Expand Up @@ -2476,3 +2481,4 @@ def __init__(
self.skip_model_validation = skip_model_validation
self.source_uri = source_uri
self.model_card = model_card
self.accept_eula = accept_eula
6 changes: 4 additions & 2 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH,
load_sagemaker_config,
)
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.model_card import (
ModelCard,
ModelPackageModelCard,
Expand Down Expand Up @@ -448,6 +449,8 @@ def register(
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
accept_eula: Optional[bool] = None,
model_type: Optional[JumpStartModelType] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -522,9 +525,8 @@ def register(
model_package_group_name = utils.base_name_from_image(
self.image_uri, default_base_name=ModelPackage.__name__
)

if model_package_group_name is not None:
container_def = self.prepare_container_def()
container_def = self.prepare_container_def(accept_eula=accept_eula)
container_def = update_container_with_inference_params(
framework=framework,
framework_version=framework_version,
Expand Down
58 changes: 58 additions & 0 deletions tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ def test_jumpstart_model_register(setup):

response = predictor.predict("hello world!")

predictor.delete_predictor()

assert response is not None


Expand All @@ -291,3 +293,59 @@ def test_proprietary_jumpstart_model(setup):
response = predictor.predict(payload)

assert response is not None


@pytest.mark.skipif(
True,
reason="Only enable if test account is subscribed to the proprietary model",
)
def test_register_proprietary_jumpstart_model(setup):

model_id = "ai21-jurassic-2-light"

model = JumpStartModel(
model_id=model_id,
model_version="2.0.004",
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
)
model_package = model.register()

predictor = model_package.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}]
)
payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1}

response = predictor.predict(payload)

predictor.delete_predictor()

assert response is not None


@pytest.mark.skipif(
True,
reason="Only enable if test account is subscribed to the proprietary model",
)
def test_register_gated_jumpstart_model(setup):

model_id = "meta-textgenerationneuron-llama-2-7b"
model = JumpStartModel(
model_id=model_id,
model_version="1.1.0",
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
)
model_package = model.register(accept_eula=True)

predictor = model_package.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
accept_eula=True,
)
payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1}

response = predictor.predict(payload)

predictor.delete_predictor()

assert response is not None
13 changes: 13 additions & 0 deletions tests/unit/sagemaker/jumpstart/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,9 +473,11 @@ def test_eula_gated_conditional_s3_prefix_metadata_model(
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
@mock.patch("sagemaker.jumpstart.model.Model.register")
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
def test_proprietary_model_endpoint(
self,
mock_model_register: mock.Mock,
mock_model_deploy: mock.Mock,
mock_model_init: mock.Mock,
mock_get_model_specs: mock.Mock,
Expand Down Expand Up @@ -507,8 +509,17 @@ def test_proprietary_model_endpoint(
enable_network_isolation=False,
)

model.register()
model.deploy()

mock_model_register.assert_called_once_with(
model_type=JumpStartModelType.PROPRIETARY,
content_types=["application/json"],
response_types=["application/json"],
model_package_group_name=model_id,
source_uri=model.model_package_arn,
)

mock_model_deploy.assert_called_once_with(
initial_instance_count=1,
instance_type="ml.p4de.24xlarge",
Expand Down Expand Up @@ -1408,8 +1419,10 @@ def test_model_registry_accept_and_response_types(
model.register()

mock_model_register.assert_called_once_with(
model_type=JumpStartModelType.OPEN_WEIGHTS,
content_types=["application/x-text"],
response_types=["application/json;verbose", "application/json"],
model_package_group_name=model.model_id,
)

@mock.patch("sagemaker.jumpstart.model.get_default_predictor")
Expand Down