Skip to content

Commit 20b51c7

Browse files
Captainiabenieric
authored and committed
fix: populate default config name to model (aws#4617)
* fix: populate default config name to model * update condition * fix * format * flake8 * fix tests * fix coverage * temporarily skip integ test vulnerability * fix tolerate attach method * format * fix predictor * format
1 parent 1dabf41 commit 20b51c7

File tree

8 files changed

+63
-10
lines changed

8 files changed

+63
-10
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,31 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
543543
return kwargs
544544

545545

546+
def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
547+
"""Sets default config name to the kwargs. Returns full kwargs."""
548+
549+
specs = verify_model_region_and_return_specs(
550+
model_id=kwargs.model_id,
551+
version=kwargs.model_version,
552+
scope=JumpStartScriptScope.INFERENCE,
553+
region=kwargs.region,
554+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
555+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
556+
sagemaker_session=kwargs.sagemaker_session,
557+
model_type=kwargs.model_type,
558+
config_name=kwargs.config_name,
559+
)
560+
if (
561+
specs.inference_configs
562+
and specs.inference_configs.get_top_config_from_ranking().config_name
563+
):
564+
kwargs.config_name = (
565+
kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name
566+
)
567+
568+
return kwargs
569+
570+
546571
def get_deploy_kwargs(
547572
model_id: str,
548573
model_version: Optional[str] = None,
@@ -808,5 +833,6 @@ def get_init_kwargs(
808833
model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs)
809834

810835
model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
836+
model_init_kwargs = _add_config_name_to_kwargs(kwargs=model_init_kwargs)
811837

812838
return model_init_kwargs

src/sagemaker/jumpstart/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def _validate_model_id_and_type():
351351
self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model
352352
self.region = model_init_kwargs.region
353353
self.sagemaker_session = model_init_kwargs.sagemaker_session
354-
self.config_name = config_name
354+
self.config_name = model_init_kwargs.config_name
355355

356356
if self.model_type == JumpStartModelType.PROPRIETARY:
357357
self.log_subscription_warning()

src/sagemaker/jumpstart/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,10 +1076,12 @@ class JumpStartMetadataConfig(JumpStartDataHolderType):
10761076
"benchmark_metrics",
10771077
"config_components",
10781078
"resolved_metadata_config",
1079+
"config_name",
10791080
]
10801081

10811082
def __init__(
10821083
self,
1084+
config_name: str,
10831085
base_fields: Dict[str, Any],
10841086
config_components: Dict[str, JumpStartConfigComponent],
10851087
benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]],
@@ -1098,6 +1100,7 @@ def __init__(
10981100
self.config_components: Dict[str, JumpStartConfigComponent] = config_components
10991101
self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = benchmark_metrics
11001102
self.resolved_metadata_config: Optional[Dict[str, Any]] = None
1103+
self.config_name: Optional[str] = config_name
11011104

11021105
def to_json(self) -> Dict[str, Any]:
11031106
"""Returns json representation of JumpStartMetadataConfig object."""
@@ -1251,6 +1254,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12511254
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
12521255
{
12531256
alias: JumpStartMetadataConfig(
1257+
alias,
12541258
json_obj,
12551259
(
12561260
{
@@ -1303,6 +1307,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
13031307
training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
13041308
{
13051309
alias: JumpStartMetadataConfig(
1310+
alias,
13061311
json_obj,
13071312
(
13081313
{

src/sagemaker/predictor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def retrieve_default(
4343
tolerate_vulnerable_model: bool = False,
4444
tolerate_deprecated_model: bool = False,
4545
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
46+
config_name: Optional[str] = None,
4647
) -> Predictor:
4748
"""Retrieves the default predictor for the model matching the given arguments.
4849
@@ -65,6 +66,8 @@ def retrieve_default(
6566
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
6667
(exception not raised). False if these models should raise an exception.
6768
(Default: False).
69+
config_name (Optional[str]): The name of the configuration to use for the
70+
predictor. (Default: None)
6871
Returns:
6972
Predictor: The default predictor to use for the model.
7073
@@ -91,10 +94,9 @@ def retrieve_default(
9194
model_id = inferred_model_id
9295
model_version = model_version or inferred_model_version or "*"
9396
inference_component_name = inference_component_name or inferred_inference_component_name
94-
config_name = inferred_config_name or None
97+
config_name = config_name or inferred_config_name or None
9598
else:
9699
model_version = model_version or "*"
97-
config_name = None
98100

99101
predictor = Predictor(
100102
endpoint_name=endpoint_name,

tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def test_gated_model_training_v2(setup):
150150
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
151151
environment={"accept_eula": "true"},
152152
max_run=259200, # avoid exceeding resource limits
153-
tolerate_vulnerable_model=True, # tolerate old version of model
153+
tolerate_vulnerable_model=True, # TODO: remove once vulnerability is patched
154154
)
155155

156156
# uses ml.g5.12xlarge instance

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,8 @@ def test_jumpstart_estimator_attach_eula_model(
10111011
additional_kwargs={
10121012
"model_id": "gemma-model",
10131013
"model_version": "*",
1014+
"tolerate_vulnerable_model": True,
1015+
"tolerate_deprecated_model": True,
10141016
"environment": {"accept_eula": "true"},
10151017
"tolerate_vulnerable_model": True,
10161018
"tolerate_deprecated_model": True,

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1552,6 +1552,8 @@ def test_model_initialization_with_config_name(
15521552

15531553
model = JumpStartModel(model_id=model_id, config_name="neuron-inference")
15541554

1555+
assert model.config_name == "neuron-inference"
1556+
15551557
model.deploy()
15561558

15571559
mock_model_deploy.assert_called_once_with(
@@ -1594,6 +1596,8 @@ def test_model_set_deployment_config(
15941596

15951597
model = JumpStartModel(model_id=model_id)
15961598

1599+
assert model.config_name is None
1600+
15971601
model.deploy()
15981602

15991603
mock_model_deploy.assert_called_once_with(
@@ -1612,6 +1616,8 @@ def test_model_set_deployment_config(
16121616
mock_get_model_specs.side_effect = get_prototype_spec_with_configs
16131617
model.set_deployment_config("neuron-inference")
16141618

1619+
assert model.config_name == "neuron-inference"
1620+
16151621
model.deploy()
16161622

16171623
mock_model_deploy.assert_called_once_with(
@@ -1654,6 +1660,8 @@ def test_model_unset_deployment_config(
16541660

16551661
model = JumpStartModel(model_id=model_id, config_name="neuron-inference")
16561662

1663+
assert model.config_name == "neuron-inference"
1664+
16571665
model.deploy()
16581666

16591667
mock_model_deploy.assert_called_once_with(
@@ -1789,7 +1797,6 @@ def test_model_retrieve_deployment_config(
17891797
):
17901798
model_id, _ = "pytorch-eqa-bert-base-cased", "*"
17911799

1792-
mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)
17931800
mock_verify_model_region_and_return_specs.side_effect = (
17941801
lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks()
17951802
)
@@ -1804,15 +1811,23 @@ def test_model_retrieve_deployment_config(
18041811
)
18051812
mock_model_deploy.return_value = default_predictor
18061813

1814+
expected = get_base_deployment_configs()[0]
1815+
config_name = expected.get("DeploymentConfigName")
1816+
mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(
1817+
model_id, config_name
1818+
)
1819+
18071820
mock_session.return_value = sagemaker_session
18081821

18091822
model = JumpStartModel(model_id=model_id)
18101823

1811-
expected = get_base_deployment_configs()[0]
1812-
model.set_deployment_config(expected.get("DeploymentConfigName"))
1824+
model.set_deployment_config(config_name)
18131825

18141826
self.assertEqual(model.deployment_config, expected)
18151827

1828+
mock_get_init_kwargs.reset_mock()
1829+
mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)
1830+
18161831
# Unset
18171832
model.set_deployment_config(None)
18181833
self.assertIsNone(model.deployment_config)

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414
import copy
15-
from typing import List, Dict, Any
15+
from typing import List, Dict, Any, Optional
1616
import boto3
1717

1818
from sagemaker.compute_resource_requirements import ResourceRequirements
@@ -237,7 +237,7 @@ def get_base_spec_with_prototype_configs_with_missing_benchmarks(
237237
copy_inference_configs = copy.deepcopy(INFERENCE_CONFIGS)
238238
copy_inference_configs["inference_configs"]["neuron-inference"]["benchmark_metrics"] = None
239239

240-
inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS}
240+
inference_configs = {**copy_inference_configs, **INFERENCE_CONFIG_RANKINGS}
241241
training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS}
242242

243243
spec.update(inference_configs)
@@ -335,7 +335,9 @@ def get_base_deployment_configs_with_acceleration_configs() -> List[Dict[str, An
335335
return configs
336336

337337

338-
def get_mock_init_kwargs(model_id) -> JumpStartModelInitKwargs:
338+
def get_mock_init_kwargs(
339+
model_id: str, config_name: Optional[str] = None
340+
) -> JumpStartModelInitKwargs:
339341
return JumpStartModelInitKwargs(
340342
model_id=model_id,
341343
model_type=JumpStartModelType.OPEN_WEIGHTS,
@@ -344,4 +346,5 @@ def get_mock_init_kwargs(model_id) -> JumpStartModelInitKwargs:
344346
instance_type=INIT_KWARGS.get("instance_type"),
345347
env=INIT_KWARGS.get("env"),
346348
resources=ResourceRequirements(),
349+
config_name=config_name,
347350
)

0 commit comments

Comments
 (0)