Skip to content

Commit cacb977

Browse files
committed
fix jumpstart curated hub bugs
1 parent e262db9 commit cacb977

File tree

4 files changed

+61
-6
lines changed

4 files changed

+61
-6
lines changed

src/sagemaker/jumpstart/accessors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sagemaker.jumpstart.hub.utils import (
2626
construct_hub_model_arn_from_inputs,
2727
construct_hub_model_reference_arn_from_inputs,
28+
generate_hub_arn_for_init_kwargs,
2829
)
2930
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
3031
from sagemaker.session import Session
@@ -291,6 +292,10 @@ def get_model_specs(
291292
# Users only input model id, not contentType, so first try to describe with ModelReference, then with Model
292293
if hub_arn:
293294
try:
295+
hub_arn = generate_hub_arn_for_init_kwargs(
296+
hub_name=hub_arn, region=region, session=sagemaker_session
297+
)
298+
294299
hub_model_arn = construct_hub_model_reference_arn_from_inputs(
295300
hub_arn=hub_arn, model_name=model_id, version=version
296301
)

src/sagemaker/jumpstart/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ def fit(
693693
accept the end-user license agreement (EULA) that some
694694
models require. (Default: None).
695695
"""
696-
self.model_access_config = get_model_access_config(accept_eula)
696+
self.model_access_config = get_model_access_config(accept_eula, self.environment)
697697
self.hub_access_config = get_hub_access_config(
698698
hub_content_arn=self.init_kwargs.get("model_reference_arn", None)
699699
)

src/sagemaker/jumpstart/utils.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1641,8 +1641,14 @@ def remove_env_var_from_estimator_kwargs_if_accept_eula_present(
16411641
init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated.
16421642
accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit().
16431643
"""
1644-
if accept_eula is not None and init_kwargs["environment"]:
1645-
del init_kwargs["environment"][constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY]
1644+
if accept_eula is not None and init_kwargs.get("environment") is not None:
1645+
if (
1646+
constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY
1647+
in init_kwargs["environment"]
1648+
):
1649+
del init_kwargs["environment"][
1650+
constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY
1651+
]
16461652

16471653

16481654
def get_hub_access_config(hub_content_arn: Optional[str]):
@@ -1659,16 +1665,24 @@ def get_hub_access_config(hub_content_arn: Optional[str]):
16591665
return hub_access_config
16601666

16611667

1662-
def get_model_access_config(accept_eula: Optional[bool]):
1668+
def get_model_access_config(accept_eula: Optional[bool], environment: Optional[dict]):
16631669
"""Get access configs
16641670
16651671
Args:
16661672
accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit().
16671673
"""
1674+
env_var_eula = environment.get("accept_eula")
1675+
if env_var_eula and accept_eula is not None:
1676+
raise ValueError(
1677+
"Cannot pass in both accept_eula and environment variables. "
1678+
"Please remove the environment variable and pass in the accept_eula parameter."
1679+
)
1680+
1681+
model_access_config = None
1682+
if env_var_eula is not None:
1683+
model_access_config = {"AcceptEula": True if env_var_eula == "true" else False}
16681684
if accept_eula is not None:
16691685
model_access_config = {"AcceptEula": accept_eula}
1670-
else:
1671-
model_access_config = None
16721686

16731687
return model_access_config
16741688

tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,42 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):
170170
assert response is not None
171171

172172

173+
def test_jumpstart_hub_gated_estimator_with_eula_env_var(setup, add_model_references):
174+
175+
model_id, model_version = "meta-textgeneration-llama-2-7b", "*"
176+
177+
estimator = JumpStartEstimator(
178+
model_id=model_id,
179+
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
180+
environment={
181+
"accept_eula": "true",
182+
},
183+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
184+
)
185+
186+
estimator.fit(
187+
inputs={
188+
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
189+
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
190+
},
191+
)
192+
193+
predictor = estimator.deploy(
194+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
195+
role=get_sm_session().get_caller_identity_arn(),
196+
sagemaker_session=get_sm_session(),
197+
)
198+
199+
payload = {
200+
"inputs": "some-payload",
201+
"parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
202+
}
203+
204+
response = predictor.predict(payload, custom_attributes="accept_eula=true")
205+
206+
assert response is not None
207+
208+
173209
def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references):
174210

175211
model_id, model_version = "meta-textgeneration-llama-2-7b", "*"

0 commit comments

Comments
 (0)