Skip to content

Commit d1ee7d2

Browse files
committed
linting
1 parent c844b33 commit d1ee7d2

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

src/sagemaker/jumpstart/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
validate_model_id_and_get_type,
4242
resolve_model_sagemaker_config_field,
4343
verify_model_region_and_return_specs,
44-
remove_env_var_from_estimator_kwargs_if_accept_eula_present,
44+
remove_env_var_from_estimator_kwargs_if_model_access_config_present,
4545
get_model_access_config,
4646
get_hub_access_config,
4747
)
@@ -714,7 +714,9 @@ def fit(
714714
config_name=self.config_name,
715715
hub_access_config=self.hub_access_config,
716716
)
717-
remove_env_var_from_estimator_kwargs_if_accept_eula_present(self.init_kwargs, accept_eula)
717+
remove_env_var_from_estimator_kwargs_if_model_access_config_present(
718+
self.init_kwargs, self.model_access_config
719+
)
718720

719721
return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict())
720722

src/sagemaker/jumpstart/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,16 +1632,16 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
16321632
return neo_bucket
16331633

16341634

1635-
def remove_env_var_from_estimator_kwargs_if_accept_eula_present(
1636-
init_kwargs: dict, accept_eula: Optional[bool]
1635+
def remove_env_var_from_estimator_kwargs_if_model_access_config_present(
1636+
init_kwargs: dict, model_access_config: dict | None
16371637
):
1638-
"""Remove env vars if access configs are used
1638+
"""Remove env vars if ModelAccessConfig is used
16391639
16401640
Args:
16411641
init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated.
16421642
accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit().
16431643
"""
1644-
if accept_eula is not None and init_kwargs.get("environment") is not None:
1644+
if model_access_config is not None and init_kwargs.get("environment") is not None:
16451645
if (
16461646
constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY
16471647
in init_kwargs["environment"]
@@ -1672,7 +1672,7 @@ def get_model_access_config(accept_eula: Optional[bool], environment: Optional[d
16721672
accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit().
16731673
"""
16741674
env_var_eula = environment.get("accept_eula") if environment else None
1675-
if env_var_eula and accept_eula is not None:
1675+
if env_var_eula is not None and accept_eula is not None:
16761676
raise ValueError(
16771677
"Cannot pass in both accept_eula and environment variables. "
16781678
"Please remove the environment variable and pass in the accept_eula parameter."

0 commit comments

Comments
 (0)