
Commit 4656cb8

chore: raise client exception if accept_eula flag is not set for gated models
1 parent c554a9f commit 4656cb8

File tree: 6 files changed, 54 additions and 49 deletions


src/sagemaker/jumpstart/artifacts/environment_variables.py

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ def _retrieve_default_environment_variables(
             instance_type=instance_type,
         )

-        if gated_model_env_var is None and model_specs.gated_bucket:
+        if gated_model_env_var is None and model_specs.is_gated_model():
             raise ValueError(
                 f"'{model_id}' does not support {instance_type} instance type for training. "
                 "Please use one of the following instance types: "

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 20 additions & 0 deletions
@@ -62,6 +62,7 @@
 )
 from sagemaker.jumpstart.utils import (
     add_jumpstart_model_id_version_tags,
+    get_eula_message,
     update_dict_if_key_not_present,
     resolve_estimator_sagemaker_config_field,
     verify_model_region_and_return_specs,
@@ -595,6 +596,25 @@ def _add_env_to_kwargs(
             value,
         )

+    environment = getattr(kwargs, "environment", {}) or {}
+    if (
+        environment.get(SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY)
+        and str(environment.get("accept_eula", "")).lower() != "true"
+    ):
+        model_specs = verify_model_region_and_return_specs(
+            model_id=kwargs.model_id,
+            version=kwargs.model_version,
+            region=kwargs.region,
+            scope=JumpStartScriptScope.TRAINING,
+            tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
+            tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
+            sagemaker_session=kwargs.sagemaker_session,
+        )
+        if model_specs.is_gated_model():
+            raise ValueError(
+                f"Need to define ‘accept_eula'='true' within Environment. {get_eula_message(model_specs, kwargs.region)}"
+            )
+
     return kwargs

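From the caller's side, the check above means constructing a JumpStartEstimator for a gated model now fails fast unless the EULA is accepted through the environment. A minimal sketch of that behaviour, assuming valid AWS credentials, a configured SageMaker session, and a hypothetical gated model id whose resolved training environment carries SageMakerGatedModelS3Uri:

from sagemaker.jumpstart.estimator import JumpStartEstimator

# "example-gated-model-id" is a placeholder; any gated JumpStart model id applies.
try:
    JumpStartEstimator(
        model_id="example-gated-model-id",
        environment={"accept_eula": "false"},
    )
except ValueError as err:
    # "Need to define ‘accept_eula'='true' within Environment. Model '...' requires
    # accepting end-user license agreement (EULA). See https://... for terms of use."
    print(err)

# Explicitly accepting the EULA lets kwarg construction continue past the new check.
estimator = JumpStartEstimator(
    model_id="example-gated-model-id",
    environment={"accept_eula": "true"},
)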
src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
@@ -956,6 +956,10 @@ def use_training_model_artifact(self) -> bool:
         # otherwise, return true is a training model package is not set
         return len(self.training_model_package_artifact_uris or {}) == 0

+    def is_gated_model(self) -> bool:
+        """Returns True if the model has a EULA key or the model bucket is gated."""
+        return self.gated_bucket or self.hosting_eula_key is not None
+
     def supports_incremental_training(self) -> bool:
         """Returns True if the model supports incremental training."""
         return self.incremental_training_supported

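The new predicate treats either signal as gating. A tiny standalone sketch of the same truth table, using a stand-in object rather than a real JumpStartModelSpecs:

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeSpecs:
    # Stand-in carrying only the two fields the new helper reads; not a real JumpStartModelSpecs.
    gated_bucket: bool
    hosting_eula_key: Optional[str]

def is_gated(specs: FakeSpecs) -> bool:
    # Mirrors the committed one-liner: gated bucket OR a hosting EULA key present.
    return specs.gated_bucket or specs.hosting_eula_key is not None

assert is_gated(FakeSpecs(gated_bucket=True, hosting_eula_key=None))
assert is_gated(FakeSpecs(gated_bucket=False, hosting_eula_key="some/eula.txt"))
assert not is_gated(FakeSpecs(gated_bucket=False, hosting_eula_key=None))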
src/sagemaker/jumpstart/utils.py

Lines changed: 13 additions & 9 deletions
@@ -466,21 +466,25 @@ def update_inference_tags_with_jumpstart_training_tags(
     return inference_tags


+def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
+    """Returns EULA message to display to customers if one is available, else empty string."""
+    if model_specs.hosting_eula_key is None:
+        return ""
+    return (
+        f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). "
+        f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
+        f"amazonaws.com{'.cn' if region.startswith('cn-') else ''}"
+        f"/{model_specs.hosting_eula_key} for terms of use."
+    )
+
+
 def emit_logs_based_on_model_specs(
     model_specs: JumpStartModelSpecs, region: str, s3_client: boto3.client
 ) -> None:
     """Emits logs based on model specs and region."""

     if model_specs.hosting_eula_key:
-        constants.JUMPSTART_LOGGER.info(
-            "Model '%s' requires accepting end-user license agreement (EULA). "
-            "See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.",
-            model_specs.model_id,
-            get_jumpstart_content_bucket(region=region),
-            region,
-            ".cn" if region.startswith("cn-") else "",
-            model_specs.hosting_eula_key,
-        )
+        constants.JUMPSTART_LOGGER.info(get_eula_message(model_specs, region))

     full_version: str = model_specs.version

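For a non-China region the helper assembles a message like the one below. This snippet re-creates the f-string locally with example values (mirroring the unit test further down) rather than calling into the SDK:

# Example values only; in the SDK the bucket comes from get_jumpstart_content_bucket()
# and the key from model_specs.hosting_eula_key.
model_id = "pytorch-eqa-bert-base-cased"
region = "us-east-1"
content_bucket = f"jumpstart-cache-prod-{region}"
hosting_eula_key = "read/the/fine/print.txt"

message = (
    f"Model '{model_id}' requires accepting end-user license agreement (EULA). "
    f"See https://{content_bucket}.s3.{region}."
    f"amazonaws.com{'.cn' if region.startswith('cn-') else ''}"
    f"/{hosting_eula_key} for terms of use."
)
print(message)
# Model 'pytorch-eqa-bert-base-cased' requires accepting end-user license agreement (EULA).
# See https://jumpstart-cache-prod-us-east-1.s3.us-east-1.amazonaws.com/read/the/fine/print.txt for terms of use.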
tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 14 additions & 32 deletions
@@ -315,39 +315,21 @@ def test_gated_model_s3_uri(
         mock_session_estimator.return_value = sagemaker_session
         mock_session_model.return_value = sagemaker_session

-        JumpStartEstimator(
-            model_id=model_id,
-            environment={
-                "accept_eula": "false",
-                "what am i": "doing",
-                "SageMakerGatedModelS3Uri": "none of your business",
-            },
-        )
-
-        mock_estimator_init.assert_called_once_with(
-            instance_type="ml.p3.2xlarge",
-            instance_count=1,
-            image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117",
-            source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/"
-            "meta/transfer_learning/textgeneration/v1.0.0/sourcedir.tar.gz",
-            entry_point="transfer_learning.py",
-            role=execution_role,
-            sagemaker_session=sagemaker_session,
-            max_run=360000,
-            enable_network_isolation=True,
-            encrypt_inter_container_traffic=True,
-            environment={
-                "accept_eula": "false",
-                "what am i": "doing",
-                "SageMakerGatedModelS3Uri": "none of your business",
-            },
-            tags=[
-                {
-                    "Key": "sagemaker-sdk:jumpstart-model-id",
-                    "Value": "js-gated-artifact-trainable-model",
+        with pytest.raises(ValueError) as e:
+            JumpStartEstimator(
+                model_id=model_id,
+                environment={
+                    "accept_eula": "false",
+                    "what am i": "doing",
+                    "SageMakerGatedModelS3Uri": "none of your business",
                 },
-                {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.0.0"},
-            ],
+            )
+        assert str(e.value) == (
+            "Need to define ‘accept_eula'='true' within Environment. "
+            "Model 'meta-textgeneration-llama-2-7b-f' requires accepting end-user "
+            "license agreement (EULA). See "
+            "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com/fmhMetadata/eula/llamaEula.txt"
+            " for terms of use."
         )

         mock_estimator_init.reset_mock()

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 2 additions & 7 deletions
@@ -905,13 +905,8 @@ def make_accept_eula_inference_spec(*largs, **kwargs):
            make_accept_eula_inference_spec(), "us-east-1", MOCK_CLIENT
        )
        mocked_info_log.assert_any_call(
-           "Model '%s' requires accepting end-user license agreement (EULA). "
-           "See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.",
-           "pytorch-eqa-bert-base-cased",
-           "jumpstart-cache-prod-us-east-1",
-           "us-east-1",
-           "",
-           "read/the/fine/print.txt",
+           "Model 'pytorch-eqa-bert-base-cased' requires accepting end-user license agreement (EULA). "
+           "See https://jumpstart-cache-prod-us-east-1.s3.us-east-1.amazonaws.com/read/the/fine/print.txt for terms of use.",
        )
917912
