@@ -110,6 +110,54 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode
110
110
}
111
111
112
112
113
+ @patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs" )
114
+ def test_jumpstart_no_supported_resource_requirements (patched_get_model_specs ):
115
+ patched_get_model_specs .side_effect = get_special_model_spec
116
+ region = "us-west-2"
117
+ mock_client = boto3 .client ("s3" )
118
+ mock_session = Mock (s3_client = mock_client )
119
+
120
+ model_id , model_version = "variant-model" , "*"
121
+ default_inference_resource_requirements = resource_requirements .retrieve_default (
122
+ region = region ,
123
+ model_id = model_id ,
124
+ model_version = model_version ,
125
+ scope = "inference" ,
126
+ sagemaker_session = mock_session ,
127
+ instance_type = "ml.g5.xlarge" ,
128
+ )
129
+ assert default_inference_resource_requirements .requests == {
130
+ "memory" : 81999 ,
131
+ "num_accelerators" : 10 ,
132
+ }
133
+
134
+ default_inference_resource_requirements = resource_requirements .retrieve_default (
135
+ region = region ,
136
+ model_id = model_id ,
137
+ model_version = model_version ,
138
+ scope = "inference" ,
139
+ sagemaker_session = mock_session ,
140
+ instance_type = "ml.g5.555xlarge" ,
141
+ )
142
+ assert default_inference_resource_requirements .requests == {
143
+ "memory" : 81999 ,
144
+ "num_accelerators" : 888810 ,
145
+ }
146
+
147
+ default_inference_resource_requirements = resource_requirements .retrieve_default (
148
+ region = region ,
149
+ model_id = model_id ,
150
+ model_version = model_version ,
151
+ scope = "inference" ,
152
+ sagemaker_session = mock_session ,
153
+ instance_type = "ml.f9.555xlarge" ,
154
+ )
155
+ assert default_inference_resource_requirements .requests == {
156
+ "memory" : 81999 ,
157
+ "num_accelerators" : 1 ,
158
+ }
159
+
160
+
113
161
@patch ("sagemaker.jumpstart.utils.validate_model_id_and_get_type" )
114
162
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs" )
115
163
def test_jumpstart_no_supported_resource_requirements (
0 commit comments