@@ -110,7 +110,7 @@ def test_build_deploy_for_transformers_local_container_and_remote_container(
110
110
return_value = "ml.g5.24xlarge" ,
111
111
)
112
112
@patch ("sagemaker.serve.builder.transformers_builder._capture_telemetry" , side_effect = None )
113
- def test_image_uri (
113
+ def test_image_uri_override (
114
114
self ,
115
115
mock_get_nb_instance ,
116
116
mock_telemetry ,
@@ -144,3 +144,52 @@ def test_image_uri(
144
144
145
145
with self .assertRaises (ValueError ) as _ :
146
146
model .deploy (mode = Mode .IN_PROCESS )
147
+
148
# NOTE: stacked ``patch`` decorators inject their mocks bottom-up (the decorator
# closest to ``def`` supplies the first mock argument).  They are ordered here
# so that each parameter name matches the mock it actually receives.
@patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None)
@patch(
    # The target must be a plain dotted import path; a leading "from " makes
    # mock.patch fail while importing the target module.
    # NOTE(review): if the builder looks this function up under a different
    # module name, patch it there instead -- confirm against the builder.
    "sagemaker.huggingface.get_huggingface_llm_image_uri",
    return_value=MOCK_IMAGE_CONFIG,
)
@patch(
    "sagemaker.huggingface.llm_utils.get_huggingface_model_metadata",
    return_value="sentence-similarity",
)
@patch(
    "sagemaker.serve.builder.transformers_builder._get_nb_instance",
    return_value="ml.g5.24xlarge",
)
def test_sentence_similarity_support(
    self,
    mock_get_nb_instance,
    mock_task,
    mock_image,
    mock_telemetry,
):
    """Exercise build/deploy with the model metadata lookup mocked to "sentence-similarity".

    The test builds a ``ModelBuilder`` in ``LOCAL_CONTAINER`` mode, deploys it
    locally, redeploys it with ``Mode.SAGEMAKER_ENDPOINT`` and finally checks
    that deploying with ``Mode.IN_PROCESS`` raises ``ValueError``.

    The mock parameters are injected by the ``patch`` decorators and are not
    used directly; they only keep the patched collaborators inert.
    """
    builder = ModelBuilder(
        model=mock_model_id,
        schema_builder=mock_schema_builder,
        mode=Mode.LOCAL_CONTAINER,
    )

    # Stub out mode preparation so that build() performs no real setup work.
    builder._prepare_for_mode = MagicMock()

    model = builder.build()
    builder.serve_settings.telemetry_opt_out = True

    # Replace the local-container mode handler so deploy() starts nothing real.
    builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
    predictor = model.deploy(model_data_download_timeout=1800)

    assert builder.image_uri == MOCK_IMAGE_CONFIG
    assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800"
    assert isinstance(predictor, TransformersLocalModePredictor)

    assert builder.nb_instance_type == "ml.g5.24xlarge"

    # Switch to endpoint deployment with the underlying deploy call mocked out.
    builder._original_deploy = MagicMock()
    builder._prepare_for_mode.return_value = (None, {})
    model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")
    assert "HF_MODEL_ID" in model.env

    # This test expects deploying with IN_PROCESS mode to be rejected.
    with self.assertRaises(ValueError):
        model.deploy(mode=Mode.IN_PROCESS)
0 commit comments