diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 55dfa1394a..a231ab917c 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -693,6 +693,7 @@ def get_register_kwargs( model_id: str, model_version: Optional[str] = None, hub_arn: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, @@ -720,6 +721,7 @@ def get_register_kwargs( skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None, + accept_eula: Optional[bool] = None, ) -> JumpStartModelRegisterKwargs: """Returns kwargs required to call `register` on `sagemaker.estimator.Model` object.""" @@ -727,6 +729,7 @@ def get_register_kwargs( model_id=model_id, model_version=model_version, hub_arn=hub_arn, + model_type=model_type, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -754,12 +757,14 @@ def get_register_kwargs( skip_model_validation=skip_model_validation, source_uri=source_uri, model_card=model_card, + accept_eula=accept_eula, ) model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, hub_arn=hub_arn, + model_type=model_type, region=region, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=sagemaker_session, diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index e99cbcc57a..b482d4fefd 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -760,6 +760,7 @@ def register( skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + accept_eula: Optional[bool] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -809,15 +810,25 @@ def register( (default: None). model_card (ModeCard or ModelPackageModelCard): document contains qualitative and quantitative information about a model (default: None). - + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). Returns: A `sagemaker.model.ModelPackage` instance. """ + if model_package_group_name is None: + model_package_group_name = self.model_id + if self.model_type is JumpStartModelType.PROPRIETARY: + source_uri = self.model_package_arn + register_kwargs = get_register_kwargs( model_id=self.model_id, model_version=self.model_version, hub_arn=self.hub_arn, + model_type=self.model_type, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, @@ -845,6 +856,7 @@ def register( skip_model_validation=skip_model_validation, source_uri=source_uri, model_card=model_card, + accept_eula=accept_eula, ) model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 420505c508..171d9ce8a1 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2372,6 +2372,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "tolerate_deprecated_model", "region", "model_id", + "model_type", "model_version", "hub_arn", "sagemaker_session", @@ -2398,6 +2399,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "skip_model_validation", "source_uri", "model_card", + "accept_eula", ] SERIALIZATION_EXCLUSION_SET = { @@ -2416,6 +2418,7 @@ def __init__( model_version: Optional[str] = None, hub_arn: Optional[str] = None, region: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, sagemaker_session: Optional[Any] = None, @@ -2442,12 +2445,14 @@ def __init__( skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None, + accept_eula: Optional[bool] = None, ) -> None: """Instantiates JumpStartModelRegisterKwargs object.""" self.model_id = model_id self.model_version = model_version self.hub_arn = hub_arn + self.model_type = model_type self.region = region self.image_uri = image_uri self.sagemaker_session = sagemaker_session @@ -2476,3 +2481,4 @@ def __init__( self.skip_model_validation = skip_model_validation self.source_uri = source_uri self.model_card = model_card + self.accept_eula = accept_eula diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 09e04ad840..4befd8cd96 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -44,6 +44,7 @@ ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH, load_sagemaker_config, ) +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.model_card import ( ModelCard, ModelPackageModelCard, @@ -448,6 +449,8 @@ def register( skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + accept_eula: Optional[bool] = None, + model_type: Optional[JumpStartModelType] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -522,9 +525,8 @@ def register( model_package_group_name = utils.base_name_from_image( self.image_uri, default_base_name=ModelPackage.__name__ ) - if model_package_group_name is not None: - container_def = self.prepare_container_def() + container_def = self.prepare_container_def(accept_eula=accept_eula) container_def = update_container_with_inference_params( framework=framework, framework_version=framework_version, diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 96ee82883e..6bc0a5c996 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -265,6 +265,8 @@ def test_jumpstart_model_register(setup): response = predictor.predict("hello world!") + predictor.delete_predictor() + assert response is not None @@ -291,3 +293,59 @@ def test_proprietary_jumpstart_model(setup): response = predictor.predict(payload) assert response is not None + + +@pytest.mark.skipif( + True, + reason="Only enable if test account is subscribed to the proprietary model", +) +def test_register_proprietary_jumpstart_model(setup): + + model_id = "ai21-jurassic-2-light" + + model = JumpStartModel( + model_id=model_id, + model_version="2.0.004", + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + model_package = model.register() + + predictor = model_package.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}] + ) + payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1} + + response = predictor.predict(payload) + + predictor.delete_predictor() + + assert response is not None + + +@pytest.mark.skipif( + True, + reason="Only enable if test account is subscribed to the proprietary model", +) +def test_register_gated_jumpstart_model(setup): + + model_id = "meta-textgenerationneuron-llama-2-7b" + model = JumpStartModel( + model_id=model_id, + model_version="1.1.0", + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + model_package = model.register(accept_eula=True) + + predictor = model_package.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + accept_eula=True, + ) + payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1} + + response = predictor.predict(payload) + + predictor.delete_predictor() + + assert response is not None diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 90a2c573d9..15c2c43bf0 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -473,9 +473,11 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.model.Model.register") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_proprietary_model_endpoint( self, + mock_model_register: mock.Mock, mock_model_deploy: mock.Mock, mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, @@ -507,8 +509,17 @@ def test_proprietary_model_endpoint( enable_network_isolation=False, ) + model.register() model.deploy() + mock_model_register.assert_called_once_with( + model_type=JumpStartModelType.PROPRIETARY, + content_types=["application/json"], + response_types=["application/json"], + model_package_group_name=model_id, + source_uri=model.model_package_arn, + ) + mock_model_deploy.assert_called_once_with( initial_instance_count=1, instance_type="ml.p4de.24xlarge", @@ -1408,8 +1419,10 @@ def test_model_registry_accept_and_response_types( model.register() mock_model_register.assert_called_once_with( + model_type=JumpStartModelType.OPEN_WEIGHTS, content_types=["application/x-text"], response_types=["application/json;verbose", "application/json"], + model_package_group_name=model.model_id, ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor")