Skip to content

Commit e8c867e

Browse files
evakravijiapinw
authored and committed
chore: emit warning when no instance specific gated training env var is available, and raise exception when accept_eula flag is not supplied (aws#4485)
* fix: raise exception when no instance specific gated training env var available
* chore: raise client exception if accept_eula flag is not set for gated models
* chore: address flake8 errors
* chore: emit warning when instance type is chosen with no gated training artifacts
1 parent d32067d commit e8c867e

File tree

8 files changed

+1388
-50
lines changed

8 files changed

+1388
-50
lines changed

src/sagemaker/jumpstart/artifacts/environment_variables.py

Lines changed: 32 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,11 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains functions for obtaining JumpStart environment variables."""
1414
from __future__ import absolute_import
15-
from typing import Dict, Optional
15+
from typing import Callable, Dict, Optional, Set
1616
from sagemaker.jumpstart.constants import (
1717
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
1818
JUMPSTART_DEFAULT_REGION_NAME,
19+
JUMPSTART_LOGGER,
1920
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
2021
)
2122
from sagemaker.jumpstart.enums import (
@@ -110,7 +111,9 @@ def _retrieve_default_environment_variables(
110111

111112
default_environment_variables.update(instance_specific_environment_variables)
112113

113-
gated_model_env_var: Optional[str] = _retrieve_gated_model_uri_env_var_value(
114+
retrieve_gated_env_var_for_instance_type: Callable[
115+
[str], Optional[str]
116+
] = lambda instance_type: _retrieve_gated_model_uri_env_var_value(
114117
model_id=model_id,
115118
model_version=model_version,
116119
region=region,
@@ -120,6 +123,33 @@ def _retrieve_default_environment_variables(
120123
instance_type=instance_type,
121124
)
122125

126+
gated_model_env_var: Optional[str] = retrieve_gated_env_var_for_instance_type(
127+
instance_type
128+
)
129+
130+
if gated_model_env_var is None and model_specs.is_gated_model():
131+
132+
possible_env_vars: Set[str] = {
133+
retrieve_gated_env_var_for_instance_type(instance_type)
134+
for instance_type in model_specs.supported_training_instance_types
135+
}
136+
137+
# If all officially supported instance types have the same underlying artifact,
138+
# we can use this artifact with high confidence that it'll succeed with
139+
# an arbitrary instance.
140+
if len(possible_env_vars) == 1:
141+
gated_model_env_var = list(possible_env_vars)[0]
142+
143+
# If this model does not have 1 artifact for all supported instance types,
144+
# we cannot determine which artifact to use for an arbitrary instance.
145+
else:
146+
log_msg = (
147+
f"'{model_id}' does not support {instance_type} instance type"
148+
" for training. Please use one of the following instance types: "
149+
f"{', '.join(model_specs.supported_training_instance_types)}."
150+
)
151+
JUMPSTART_LOGGER.warning(log_msg)
152+
123153
if gated_model_env_var is not None:
124154
default_environment_variables.update(
125155
{SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: gated_model_env_var}

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,7 @@
6262
)
6363
from sagemaker.jumpstart.utils import (
6464
add_jumpstart_model_id_version_tags,
65+
get_eula_message,
6566
update_dict_if_key_not_present,
6667
resolve_estimator_sagemaker_config_field,
6768
verify_model_region_and_return_specs,
@@ -597,6 +598,26 @@ def _add_env_to_kwargs(
597598
value,
598599
)
599600

601+
environment = getattr(kwargs, "environment", {}) or {}
602+
if (
603+
environment.get(SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY)
604+
and str(environment.get("accept_eula", "")).lower() != "true"
605+
):
606+
model_specs = verify_model_region_and_return_specs(
607+
model_id=kwargs.model_id,
608+
version=kwargs.model_version,
609+
region=kwargs.region,
610+
scope=JumpStartScriptScope.TRAINING,
611+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
612+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
613+
sagemaker_session=kwargs.sagemaker_session,
614+
)
615+
if model_specs.is_gated_model():
616+
raise ValueError(
617+
"Need to define ‘accept_eula'='true' within Environment. "
618+
f"{get_eula_message(model_specs, kwargs.region)}"
619+
)
620+
600621
return kwargs
601622

602623

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -963,6 +963,10 @@ def use_training_model_artifact(self) -> bool:
963963
# otherwise, return true is a training model package is not set
964964
return len(self.training_model_package_artifact_uris or {}) == 0
965965

966+
def is_gated_model(self) -> bool:
    """Returns True if the model has a EULA key or the model bucket is gated.

    Returns:
        bool: ``True`` when ``self.gated_bucket`` is truthy or
        ``self.hosting_eula_key`` is not ``None``; ``False`` otherwise.
    """
    # Coerce explicitly: a bare ``a or b`` returns the operand itself, so a
    # truthy non-bool ``gated_bucket`` would otherwise leak out of a method
    # that is annotated as returning ``bool``.
    return bool(self.gated_bucket) or self.hosting_eula_key is not None
969+
966970
def supports_incremental_training(self) -> bool:
967971
"""Returns True if the model supports incremental training."""
968972
return self.incremental_training_supported

src/sagemaker/jumpstart/utils.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -476,21 +476,25 @@ def update_inference_tags_with_jumpstart_training_tags(
476476
return inference_tags
477477

478478

479+
def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
    """Builds the EULA notice for a model, or an empty string if it has no EULA key.

    Args:
        model_specs (JumpStartModelSpecs): Specs whose ``hosting_eula_key`` and
            ``model_id`` are used to build the message.
        region (str): Region used for the content bucket lookup and the S3 host name.

    Returns:
        str: Message pointing at the EULA object, or ``""`` when
        ``hosting_eula_key`` is ``None``.
    """
    eula_key = model_specs.hosting_eula_key
    if eula_key is None:
        return ""

    model_id = model_specs.model_id
    content_bucket = get_jumpstart_content_bucket(region=region)
    # Regions prefixed with "cn-" get an extra ".cn" after amazonaws.com.
    domain_suffix = ".cn" if region.startswith("cn-") else ""
    eula_url = f"https://{content_bucket}.s3.{region}.amazonaws.com{domain_suffix}/{eula_key}"

    return (
        f"Model '{model_id}' requires accepting end-user license agreement (EULA). "
        f"See {eula_url} for terms of use."
    )
489+
490+
479491
def emit_logs_based_on_model_specs(
480492
model_specs: JumpStartModelSpecs, region: str, s3_client: boto3.client
481493
) -> None:
482494
"""Emits logs based on model specs and region."""
483495

484496
if model_specs.hosting_eula_key:
485-
constants.JUMPSTART_LOGGER.info(
486-
"Model '%s' requires accepting end-user license agreement (EULA). "
487-
"See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.",
488-
model_specs.model_id,
489-
get_jumpstart_content_bucket(region=region),
490-
region,
491-
".cn" if region.startswith("cn-") else "",
492-
model_specs.hosting_eula_key,
493-
)
497+
constants.JUMPSTART_LOGGER.info(get_eula_message(model_specs, region))
494498

495499
full_version: str = model_specs.version
496500

tests/unit/sagemaker/environment_variables/jumpstart/test_default.py

Lines changed: 65 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919

2020
from sagemaker import environment_variables
21+
from sagemaker.jumpstart.utils import get_jumpstart_gated_content_bucket
2122
from sagemaker.jumpstart.enums import JumpStartModelType
2223

2324
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
@@ -203,6 +204,70 @@ def test_jumpstart_sdk_environment_variables(
203204
)
204205

205206

207+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_sdk_environment_variables_1_artifact_all_variants(patched_get_model_specs):
    """An instance type outside the spec still gets the gated artifact env var.

    Uses the "gemma-model-1-artifact" special spec and requests "ml.p3.2xlarge";
    the single gated training artifact URI is expected back.
    """
    patched_get_model_specs.side_effect = get_special_model_spec

    region = "us-west-2"
    expected_env = {
        "SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(region)}/"
        "huggingface-training/train-huggingface-llm-gemma-7b-instruct.tar.gz"
    }

    actual_env = environment_variables.retrieve_default(
        region=region,
        model_id="gemma-model-1-artifact",
        model_version="*",
        include_aws_sdk_env_vars=False,
        sagemaker_session=mock_session,
        instance_type="ml.p3.2xlarge",
        script="training",
    )

    assert expected_env == actual_env
227+
228+
229+
@patch("sagemaker.jumpstart.artifacts.environment_variables.JUMPSTART_LOGGER")
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_sdk_environment_variables_no_gated_env_var_available(
    patched_get_model_specs, patched_jumpstart_logger
):
    """No gated env var plus one warning for an unsupported instance type.

    Also checks that a supported instance type for the same model still
    resolves to its gated training artifact.
    """
    patched_get_model_specs.side_effect = get_special_model_spec

    region = "us-west-2"
    # Arguments shared by both retrieve_default calls; only instance_type varies.
    shared_kwargs = {
        "region": region,
        "model_id": "gemma-model",
        "model_version": "*",
        "include_aws_sdk_env_vars": False,
        "sagemaker_session": mock_session,
        "script": "training",
    }

    unsupported_env = environment_variables.retrieve_default(
        instance_type="ml.p3.2xlarge", **shared_kwargs
    )
    assert {} == unsupported_env

    patched_jumpstart_logger.warning.assert_called_once_with(
        "'gemma-model' does not support ml.p3.2xlarge instance type for "
        "training. Please use one of the following instance types: "
        "ml.g5.12xlarge, ml.g5.24xlarge, ml.g5.48xlarge, ml.p4d.24xlarge."
    )

    # assert that supported instance types succeed
    expected_supported_env = {
        "SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(region)}/"
        "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-7b-instruct.tar.gz"
    }
    supported_env = environment_variables.retrieve_default(
        instance_type="ml.g5.24xlarge", **shared_kwargs
    )
    assert expected_supported_env == supported_env
269+
270+
206271
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
207272
def test_jumpstart_sdk_environment_variables_instance_type_overrides(patched_get_model_specs):
208273

0 commit comments

Comments
 (0)