Skip to content

Commit 43570b8

Browse files
evakravi authored and bencrabtree committed
feat: instance specific jumpstart host requirements (aws#4397)
* feat: instance specific jumpstart host requirements
* chore: add js support for copies resource requirement, enforce coupling with ResourceRequirements class
* fix: typing
* fix: pylint
1 parent 74bbb09 commit 43570b8

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

src/sagemaker/jumpstart/artifacts/resource_requirements.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -134,6 +134,7 @@ def _retrieve_default_resources(
134134
}
135135

136136
if is_dynamic_container_deployment_supported:
137+
137138
all_resource_requirement_kwargs = {}
138139

139140
for (

tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py

Lines changed: 48 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -110,6 +110,54 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode
110110
}
111111

112112

113+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_instance_specific_resource_requirements(patched_get_model_specs):
    """Check the default inference resource requirements for several instance types.

    ``get_model_specs`` is patched to return the special test specs, then
    ``resource_requirements.retrieve_default`` is called for ``variant-model``
    once per instance type and the resulting ``requests`` mapping is compared
    against the expected values.

    NOTE: this test deliberately does not reuse the name
    ``test_jumpstart_no_supported_resource_requirements``. Another test with
    that name is defined later in this module; the later ``def`` would rebind
    the module-level name, so pytest would silently never run this test.
    """
    patched_get_model_specs.side_effect = get_special_model_spec
    region = "us-west-2"
    mock_client = boto3.client("s3")
    mock_session = Mock(s3_client=mock_client)

    model_id, model_version = "variant-model", "*"

    # (instance type, expected "num_accelerators"). The expected "memory"
    # value is the same for every instance type queried here.
    expected_by_instance_type = (
        ("ml.g5.xlarge", 10),
        ("ml.g5.555xlarge", 888810),
        ("ml.f9.555xlarge", 1),
    )

    for instance_type, expected_accelerators in expected_by_instance_type:
        default_inference_resource_requirements = resource_requirements.retrieve_default(
            region=region,
            model_id=model_id,
            model_version=model_version,
            scope="inference",
            sagemaker_session=mock_session,
            instance_type=instance_type,
        )
        # The instance type is the assertion message so a failure identifies
        # which case broke.
        assert default_inference_resource_requirements.requests == {
            "memory": 81999,
            "num_accelerators": expected_accelerators,
        }, instance_type
159+
160+
113161
@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type")
114162
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
115163
def test_jumpstart_no_supported_resource_requirements(

0 commit comments

Comments
 (0)