Skip to content

Commit 8ba576a

Browse files
committed
pass hub_arn into all estimator utils/artifacts
1 parent bb7a9fb commit 8ba576a

24 files changed

+120
-15
lines changed

src/sagemaker/environment_variables.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def retrieve_default(
3030
region: Optional[str] = None,
3131
model_id: Optional[str] = None,
3232
model_version: Optional[str] = None,
33+
hub_arn: Optional[str] = None,
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
3536
include_aws_sdk_env_vars: bool = True,
@@ -46,6 +47,8 @@ def retrieve_default(
4647
retrieve the default environment variables. (Default: None).
4748
model_version (str): Optional. The version of the model for which to retrieve the
4849
default environment variables. (Default: None).
50+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51+
model details from (default: None).
4952
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5053
specifications should be tolerated (exception not raised). If False, raises an
5154
exception if the script used by this version of the model has dependencies with known
@@ -80,6 +83,7 @@ def retrieve_default(
8083
return artifacts._retrieve_default_environment_variables(
8184
model_id,
8285
model_version,
86+
hub_arn,
8387
region,
8488
tolerate_vulnerable_model,
8589
tolerate_deprecated_model,

src/sagemaker/hyperparameters.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def retrieve_default(
3131
region: Optional[str] = None,
3232
model_id: Optional[str] = None,
3333
model_version: Optional[str] = None,
34+
hub_arn: Optional[str] = None,
3435
instance_type: Optional[str] = None,
3536
include_container_hyperparameters: bool = False,
3637
tolerate_vulnerable_model: bool = False,
@@ -46,6 +47,8 @@ def retrieve_default(
4647
retrieve the default hyperparameters. (Default: None).
4748
model_version (str): The version of the model for which to retrieve the
4849
default hyperparameters. (Default: None).
50+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51+
model details from (default: None).
4952
instance_type (str): An instance type to optionally supply in order to get hyperparameters
5053
specific for the instance type.
5154
include_container_hyperparameters (bool): ``True`` if the container hyperparameters
@@ -80,6 +83,7 @@ def retrieve_default(
8083
return artifacts._retrieve_default_hyperparameters(
8184
model_id=model_id,
8285
model_version=model_version,
86+
hub_arn=hub_arn,
8387
instance_type=instance_type,
8488
region=region,
8589
include_container_hyperparameters=include_container_hyperparameters,
@@ -93,6 +97,7 @@ def validate(
9397
region: Optional[str] = None,
9498
model_id: Optional[str] = None,
9599
model_version: Optional[str] = None,
100+
hub_arn: Optional[str] = None,
96101
hyperparameters: Optional[dict] = None,
97102
validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
98103
tolerate_vulnerable_model: bool = False,
@@ -107,6 +112,8 @@ def validate(
107112
(Default: None).
108113
model_version (str): The version of the model for which to validate hyperparameters.
109114
(Default: None).
115+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
116+
model details from (default: None).
110117
hyperparameters (dict): Hyperparameters to validate.
111118
(Default: None).
112119
validation_mode (HyperparameterValidationMode): Method of validation to use with
@@ -148,6 +155,7 @@ def validate(
148155
return validate_hyperparameters(
149156
model_id=model_id,
150157
model_version=model_version,
158+
hub_arn=hub_arn,
151159
hyperparameters=hyperparameters,
152160
validation_mode=validation_mode,
153161
region=region,

src/sagemaker/image_uris.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def retrieve(
6161
training_compiler_config=None,
6262
model_id=None,
6363
model_version=None,
64+
hub_arn=None,
6465
tolerate_vulnerable_model=False,
6566
tolerate_deprecated_model=False,
6667
sdk_version=None,
@@ -101,6 +102,8 @@ def retrieve(
101102
(default: None).
102103
model_version (str): The version of the JumpStart model for which to retrieve the
103104
image URI (default: None).
105+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
106+
model details from (default: None).
104107
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications
105108
should be tolerated without an exception raised. If ``False``, raises an exception if
106109
the script used by this version of the model has dependencies with known security
@@ -146,6 +149,7 @@ def retrieve(
146149
model_id,
147150
model_version,
148151
image_scope,
152+
hub_arn,
149153
framework,
150154
region,
151155
version,

src/sagemaker/instance_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ def retrieve_default(
4545
retrieve the default instance type. (Default: None).
4646
model_version (str): The version of the model for which to retrieve the
4747
default instance type. (Default: None).
48+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
49+
model details from (default: None).
4850
scope (str): The model type, i.e. what it is used for.
4951
Valid values: "training" and "inference".
5052
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -110,6 +112,8 @@ def retrieve(
110112
retrieve the supported instance types. (Default: None).
111113
model_version (str): The version of the model for which to retrieve the
112114
supported instance types. (Default: None).
115+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
116+
model details from (default: None).
113117
tolerate_vulnerable_model (bool): True if vulnerable versions of model
114118
specifications should be tolerated (exception not raised). If False, raises an
115119
exception if the script used by this version of the model has dependencies with known

src/sagemaker/jumpstart/artifacts/environment_variables.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
def _retrieve_default_environment_variables(
3232
model_id: str,
3333
model_version: str,
34+
hub_arn: Optional[str] = None,
3435
region: Optional[str] = None,
3536
tolerate_vulnerable_model: bool = False,
3637
tolerate_deprecated_model: bool = False,
@@ -46,6 +47,8 @@ def _retrieve_default_environment_variables(
4647
retrieve the default environment variables.
4748
model_version (str): Version of the JumpStart model for which to retrieve the
4849
default environment variables.
50+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51+
model details from (default: None).
4952
region (Optional[str]): Region for which to retrieve default environment variables.
5053
(Default: None).
5154
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -77,6 +80,7 @@ def _retrieve_default_environment_variables(
7780
model_specs = verify_model_region_and_return_specs(
7881
model_id=model_id,
7982
version=model_version,
83+
hub_arn=hub_arn,
8084
scope=script,
8185
region=region,
8286
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -113,6 +117,7 @@ def _retrieve_default_environment_variables(
113117
gated_model_env_var: Optional[str] = _retrieve_gated_model_uri_env_var_value(
114118
model_id=model_id,
115119
model_version=model_version,
120+
hub_arn=hub_arn,
116121
region=region,
117122
tolerate_vulnerable_model=tolerate_vulnerable_model,
118123
tolerate_deprecated_model=tolerate_deprecated_model,
@@ -131,6 +136,7 @@ def _retrieve_default_environment_variables(
131136
def _retrieve_gated_model_uri_env_var_value(
132137
model_id: str,
133138
model_version: str,
139+
hub_arn: Optional[str] = None,
134140
region: Optional[str] = None,
135141
tolerate_vulnerable_model: bool = False,
136142
tolerate_deprecated_model: bool = False,
@@ -144,6 +150,8 @@ def _retrieve_gated_model_uri_env_var_value(
144150
retrieve the gated model env var URI.
145151
model_version (str): Version of the JumpStart model for which to retrieve the
146152
gated model env var URI.
153+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
154+
model details from (default: None).
147155
region (Optional[str]): Region for which to retrieve the gated model env var URI.
148156
(Default: None).
149157
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -174,6 +182,7 @@ def _retrieve_gated_model_uri_env_var_value(
174182
model_specs = verify_model_region_and_return_specs(
175183
model_id=model_id,
176184
version=model_version,
185+
hub_arn=hub_arn,
177186
scope=JumpStartScriptScope.TRAINING,
178187
region=region,
179188
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/hyperparameters.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
def _retrieve_default_hyperparameters(
3131
model_id: str,
3232
model_version: str,
33+
hub_arn: Optional[str] = None,
3334
region: Optional[str] = None,
3435
include_container_hyperparameters: bool = False,
3536
tolerate_vulnerable_model: bool = False,
@@ -44,6 +45,8 @@ def _retrieve_default_hyperparameters(
4445
retrieve the default hyperparameters.
4546
model_version (str): Version of the JumpStart model for which to retrieve the
4647
default hyperparameters.
48+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
49+
model details from (default: None).
4750
region (str): Region for which to retrieve default hyperparameters.
4851
(Default: None).
4952
include_container_hyperparameters (bool): True if container hyperparameters
@@ -76,6 +79,7 @@ def _retrieve_default_hyperparameters(
7679
model_specs = verify_model_region_and_return_specs(
7780
model_id=model_id,
7881
version=model_version,
82+
hub_arn=hub_arn,
7983
scope=JumpStartScriptScope.TRAINING,
8084
region=region,
8185
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/image_uris.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def _retrieve_image_uri(
3333
model_id: str,
3434
model_version: str,
3535
image_scope: str,
36+
hub_arn: Optional[str] = None,
3637
framework: Optional[str] = None,
3738
region: Optional[str] = None,
3839
version: Optional[str] = None,
@@ -57,6 +58,8 @@ def _retrieve_image_uri(
5758
model_id (str): JumpStart model ID for which to retrieve image URI.
5859
model_version (str): Version of the JumpStart model for which to retrieve
5960
the image URI.
61+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
62+
model details from (default: None).
6063
image_scope (str): The image type, i.e. what it is used for.
6164
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
6265
``image_scope`` is ignored.
@@ -110,6 +113,7 @@ def _retrieve_image_uri(
110113
model_specs = verify_model_region_and_return_specs(
111114
model_id=model_id,
112115
version=model_version,
116+
hub_arn=hub_arn,
113117
scope=image_scope,
114118
region=region,
115119
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/incremental_training.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def _model_supports_incremental_training(
3030
model_id: str,
3131
model_version: str,
3232
region: Optional[str],
33+
hub_arn: Optional[str] = None,
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
3536
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -43,6 +44,8 @@ def _model_supports_incremental_training(
4344
support status for incremental training.
4445
region (Optional[str]): Region for which to retrieve the
4546
support status for incremental training.
47+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
48+
model details from (default: None).
4649
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4750
specifications should be tolerated (exception not raised). If False, raises an
4851
exception if the script used by this version of the model has dependencies with known
@@ -64,6 +67,7 @@ def _model_supports_incremental_training(
6467
model_specs = verify_model_region_and_return_specs(
6568
model_id=model_id,
6669
version=model_version,
70+
hub_arn=hub_arn,
6771
scope=JumpStartScriptScope.TRAINING,
6872
region=region,
6973
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/instance_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def _retrieve_default_instance_type(
4949
default instance type.
5050
scope (str): The script type, i.e. what it is used for.
5151
Valid values: "training" and "inference".
52+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
53+
model details from (default: None).
5254
region (Optional[str]): Region for which to retrieve default instance type.
5355
(Default: None).
5456
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -137,6 +139,8 @@ def _retrieve_instance_types(
137139
supported instance types.
138140
scope (str): The script type, i.e. what it is used for.
139141
Valid values: "training" and "inference".
142+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
143+
model details from (default: None).
140144
region (Optional[str]): Region for which to retrieve supported instance types.
141145
(Default: None).
142146
tolerate_vulnerable_model (bool): True if vulnerable versions of model

src/sagemaker/jumpstart/artifacts/kwargs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def _retrieve_estimator_init_kwargs(
140140
model_id: str,
141141
model_version: str,
142142
instance_type: str,
143+
hub_arn: Optional[str] = None,
143144
region: Optional[str] = None,
144145
tolerate_vulnerable_model: bool = False,
145146
tolerate_deprecated_model: bool = False,
@@ -154,6 +155,8 @@ def _retrieve_estimator_init_kwargs(
154155
kwargs.
155156
instance_type (str): Instance type of the training job, to determine if volume size is
156157
supported.
158+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
159+
model details from (default: None).
157160
region (Optional[str]): Region for which to retrieve kwargs.
158161
(Default: None).
159162
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -177,6 +180,7 @@ def _retrieve_estimator_init_kwargs(
177180
model_specs = verify_model_region_and_return_specs(
178181
model_id=model_id,
179182
version=model_version,
183+
hub_arn=hub_arn,
180184
scope=JumpStartScriptScope.TRAINING,
181185
region=region,
182186
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/metric_definitions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def _retrieve_default_training_metric_definitions(
3131
model_id: str,
3232
model_version: str,
3333
region: Optional[str],
34+
hub_arn: Optional[str] = None,
3435
tolerate_vulnerable_model: bool = False,
3536
tolerate_deprecated_model: bool = False,
3637
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -45,6 +46,8 @@ def _retrieve_default_training_metric_definitions(
4546
default training metric definitions.
4647
region (Optional[str]): Region for which to retrieve default training metric
4748
definitions.
49+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
50+
model details from (default: None).
4851
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4952
specifications should be tolerated (exception not raised). If False, raises an
5053
exception if the script used by this version of the model has dependencies with known
@@ -68,6 +71,7 @@ def _retrieve_default_training_metric_definitions(
6871
model_specs = verify_model_region_and_return_specs(
6972
model_id=model_id,
7073
version=model_version,
74+
hub_arn=hub_arn,
7175
scope=JumpStartScriptScope.TRAINING,
7276
region=region,
7377
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def _retrieve_model_package_model_artifact_s3_uri(
110110
model_id: str,
111111
model_version: str,
112112
region: Optional[str],
113+
hub_arn: Optional[str] = None,
113114
scope: Optional[str] = None,
114115
tolerate_vulnerable_model: bool = False,
115116
tolerate_deprecated_model: bool = False,
@@ -124,6 +125,8 @@ def _retrieve_model_package_model_artifact_s3_uri(
124125
model package artifact.
125126
region (Optional[str]): Region for which to retrieve the model package artifact.
126127
(Default: None).
128+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
129+
model details from (default: None).
127130
scope (Optional[str]): Scope for which to retrieve the model package artifact.
128131
(Default: None).
129132
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -152,6 +155,7 @@ def _retrieve_model_package_model_artifact_s3_uri(
152155
model_specs = verify_model_region_and_return_specs(
153156
model_id=model_id,
154157
version=model_version,
158+
hub_arn=hub_arn,
155159
scope=scope,
156160
region=region,
157161
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/model_uris.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def _model_supports_training_model_uri(
178178
model_id: str,
179179
model_version: str,
180180
region: Optional[str],
181+
hub_arn: Optional[str] = None,
181182
tolerate_vulnerable_model: bool = False,
182183
tolerate_deprecated_model: bool = False,
183184
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -191,6 +192,8 @@ def _model_supports_training_model_uri(
191192
support status for model uri with training.
192193
region (Optional[str]): Region for which to retrieve the
193194
support status for model uri with training.
195+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
196+
model details from (default: None).
194197
tolerate_vulnerable_model (bool): True if vulnerable versions of model
195198
specifications should be tolerated (exception not raised). If False, raises an
196199
exception if the script used by this version of the model has dependencies with known
@@ -212,6 +215,7 @@ def _model_supports_training_model_uri(
212215
model_specs = verify_model_region_and_return_specs(
213216
model_id=model_id,
214217
version=model_version,
218+
hub_arn=hub_arn,
215219
scope=JumpStartScriptScope.TRAINING,
216220
region=region,
217221
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/payloads.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def _retrieve_example_payloads(
3232
model_id: str,
3333
model_version: str,
3434
region: Optional[str],
35+
hub_arn: Optional[str] = None,
3536
tolerate_vulnerable_model: bool = False,
3637
tolerate_deprecated_model: bool = False,
3738
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -45,6 +46,8 @@ def _retrieve_example_payloads(
4546
example payloads.
4647
region (Optional[str]): Region for which to retrieve the
4748
example payloads.
49+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
50+
model details from (default: None).
4851
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4952
specifications should be tolerated (exception not raised). If False, raises an
5053
exception if the script used by this version of the model has dependencies with known
@@ -67,6 +70,7 @@ def _retrieve_example_payloads(
6770
model_specs = verify_model_region_and_return_specs(
6871
model_id=model_id,
6972
version=model_version,
73+
hub_arn=hub_arn,
7074
scope=JumpStartScriptScope.INFERENCE,
7175
region=region,
7276
tolerate_vulnerable_model=tolerate_vulnerable_model,

0 commit comments

Comments
 (0)