Skip to content

Commit c5af857

Browse files
committed
feat: instance specific jumpstart host requirements
1 parent fc11ace commit c5af857

File tree

7 files changed

+141
-2
lines changed

7 files changed

+141
-2
lines changed

src/sagemaker/jumpstart/artifacts/resource_requirements.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""This module contains functions for obtaining JumpStart resoure requirements."""
1414
from __future__ import absolute_import
1515

16-
from typing import Optional
16+
from typing import Dict, Optional
1717

1818
from sagemaker.jumpstart.constants import (
1919
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -37,6 +37,7 @@ def _retrieve_default_resources(
3737
tolerate_vulnerable_model: bool = False,
3838
tolerate_deprecated_model: bool = False,
3939
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
40+
instance_type: Optional[str] = None,
4041
) -> ResourceRequirements:
4142
"""Retrieves the default resource requirements for the model.
4243
@@ -60,6 +61,8 @@ def _retrieve_default_resources(
6061
object, used for SageMaker interactions. If not
6162
specified, one is created using the default AWS configuration
6263
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
64+
instance_type (str): An instance type to optionally supply in order to get
65+
host requirements specific for the instance type.
6366
Returns:
6467
ResourceRequirements: The default resource requirements to use for the model or None.
6568
@@ -87,12 +90,28 @@ def _retrieve_default_resources(
8790
is_dynamic_container_deployment_supported = (
8891
model_specs.dynamic_container_deployment_supported
8992
)
90-
default_resource_requirements = model_specs.hosting_resource_requirements
93+
default_resource_requirements: Dict[str, int] = (
94+
model_specs.hosting_resource_requirements or {}
95+
)
9196
else:
9297
raise NotImplementedError(
9398
f"Unsupported script scope for retrieving default resource requirements: '{scope}'"
9499
)
95100

101+
instance_specific_resource_requirements: Dict[str, int] = (
102+
model_specs.hosting_instance_type_variants.get_instance_specific_resource_requirements(
103+
instance_type
104+
)
105+
if instance_type
106+
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
107+
else {}
108+
)
109+
110+
default_resource_requirements = {
111+
**default_resource_requirements,
112+
**instance_specific_resource_requirements,
113+
}
114+
96115
if is_dynamic_container_deployment_supported:
97116
requests = {}
98117
if "num_accelerators" in default_resource_requirements:

src/sagemaker/jumpstart/factory/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
481481
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
482482
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
483483
sagemaker_session=kwargs.sagemaker_session,
484+
instance_type=kwargs.instance_type,
484485
)
485486

486487
return kwargs

src/sagemaker/jumpstart/types.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,29 @@ def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str
478478
instance_type=instance_type, property_name="artifact_key"
479479
)
480480

481+
def get_instance_specific_resource_requirements(self, instance_type: str) -> dict:
    """Returns instance specific resource requirements.

    Requirements declared for the instance type family are merged with those
    declared for the exact instance type. If a value exists for both the
    instance family and instance type, the instance type value is chosen.
    The merge is shallow: a nested value is replaced as a whole, not combined.

    Args:
        instance_type (str): Instance type to look up, e.g. ``ml.p2.12xlarge``.

    Returns:
        dict: The merged resource requirements, or an empty dict when neither
            the instance type nor its family declares any.
    """
    # Tolerate a missing variants mapping instead of raising AttributeError.
    variants: dict = self.variants or {}

    def _requirements_for(variant_key: str) -> dict:
        """Return the ``resource_requirements`` of one variant, or an empty dict."""
        return (
            variants.get(variant_key, {})
            .get("properties", {})
            .get("resource_requirements", {})
        )

    instance_type_family = get_instance_type_family(instance_type)

    # Family values go first so that exact instance type values override them.
    return {
        **_requirements_for(instance_type_family),
        **_requirements_for(instance_type),
    }
503+
481504
def _get_instance_specific_property(
482505
self, instance_type: str, property_name: str
483506
) -> Optional[str]:

src/sagemaker/resource_requirements.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def retrieve_default(
3333
tolerate_vulnerable_model: bool = False,
3434
tolerate_deprecated_model: bool = False,
3535
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
36+
instance_type: Optional[str] = None,
3637
) -> str:
3738
"""Retrieves the default resource requirements for the model matching the given arguments.
3839
@@ -56,6 +57,8 @@ def retrieve_default(
5657
object, used for SageMaker interactions. If not
5758
specified, one is created using the default AWS configuration
5859
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
60+
instance_type (str): An instance type to optionally supply in order to get
61+
host requirements specific for the instance type.
5962
Returns:
6063
str: The default resource requirements to use for the model.
6164
@@ -79,4 +82,5 @@ def retrieve_default(
7982
tolerate_vulnerable_model,
8083
tolerate_deprecated_model,
8184
sagemaker_session=sagemaker_session,
85+
instance_type=instance_type,
8286
)

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,8 +840,22 @@
840840
"model_package_arn": "$gpu_model_package_arn",
841841
}
842842
},
843+
"g5": {
844+
"properties": {
845+
"resource_requirements": {
846+
"num_accelerators": 888810,
847+
"randon-field-2": 2222,
848+
}
849+
}
850+
},
843851
"m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
844852
"c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
853+
"ml.g5.xlarge": {
854+
"properties": {
855+
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"},
856+
"resource_requirements": {"num_accelerators": 10},
857+
}
858+
},
845859
"ml.g5.48xlarge": {
846860
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}}
847861
},
@@ -857,6 +871,12 @@
857871
"framework_version": "1.5.0",
858872
"py_version": "py3",
859873
},
874+
"dynamic_container_deployment_supported": True,
875+
"hosting_resource_requirements": {
876+
"min_memory_mb": 81999,
877+
"num_accelerators": 1,
878+
"random_field_1": 1,
879+
},
860880
"hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
861881
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
862882
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",

tests/unit/sagemaker/jumpstart/test_types.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"variants": {
3535
"ml.p2.12xlarge": {
3636
"properties": {
37+
"resource_requirements": {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9},
3738
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"},
3839
"supported_inference_instance_types": ["ml.p5.xlarge"],
3940
"default_inference_instance_type": "ml.p5.xlarge",
@@ -60,6 +61,11 @@
6061
"p2": {
6162
"regional_properties": {"image_uri": "$gpu_image_uri"},
6263
"properties": {
64+
"resource_requirements": {
65+
"req2": {"2": 5, "9": 999},
66+
"req3": 999,
67+
"req4": "blah",
68+
},
6369
"supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"],
6470
"default_inference_instance_type": "ml.p2.xlarge",
6571
"metrics": [
@@ -879,3 +885,20 @@ def test_jumpstart_training_artifact_key_instance_variants():
879885
)
880886
is None
881887
)
888+
889+
890+
def test_jumpstart_resource_requirements_instance_variants():
    """Check the merged resource requirements returned for several instance types."""
    # Maps each queried instance type to the requirements expected back.
    expected_by_instance_type = {
        "ml.p2.xlarge": {"req2": {"2": 5, "9": 999}, "req3": 999, "req4": "blah"},
        "ml.p2.12xlarge": {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9, "req4": "blah"},
        "ml.p99.12xlarge": {},
    }

    for instance_type, expected in expected_by_instance_type.items():
        actual = INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
            instance_type=instance_type
        )
        assert actual == expected

tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,55 @@ def test_jumpstart_resource_requirements(patched_get_model_specs):
5050
patched_get_model_specs.reset_mock()
5151

5252

53+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_resource_requirements_instance_type_variants(patched_get_model_specs):
    """Check the default inference requests retrieved for several instance types."""
    patched_get_model_specs.side_effect = get_special_model_spec
    region = "us-west-2"
    mock_client = boto3.client("s3")
    mock_session = Mock(s3_client=mock_client)

    model_id, model_version = "variant-model", "*"

    # (instance type, expected num_accelerators); memory is the same in every case.
    # NOTE(review): presumably these cover an exact instance type match, an
    # instance family match and no match at all -- confirm against the spec fixture.
    cases = (
        ("ml.g5.xlarge", 10),
        ("ml.g5.555xlarge", 888810),
        ("ml.f9.555xlarge", 1),
    )

    for instance_type, expected_num_accelerators in cases:
        default_inference_resource_requirements = resource_requirements.retrieve_default(
            region=region,
            model_id=model_id,
            model_version=model_version,
            scope="inference",
            sagemaker_session=mock_session,
            instance_type=instance_type,
        )
        assert default_inference_resource_requirements.requests == {
            "memory": 81999,
            "num_accelerators": expected_num_accelerators,
        }
100+
101+
53102
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
54103
def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs):
55104
patched_get_model_specs.side_effect = get_special_model_spec

0 commit comments

Comments
 (0)