diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py
index 3d84e314df..ead9b7425f 100644
--- a/src/sagemaker/serve/builder/transformers_builder.py
+++ b/src/sagemaker/serve/builder/transformers_builder.py
@@ -132,17 +132,20 @@ def _create_transformers_model(self) -> Type[Model]:
             vpc_config=self.vpc_config,
         )
 
-        if self.mode == Mode.LOCAL_CONTAINER:
+        if not self.image_uri and self.mode == Mode.LOCAL_CONTAINER:
             self.image_uri = pysdk_model.serving_image_uri(
                 self.sagemaker_session.boto_region_name, "local"
             )
-        else:
+        elif not self.image_uri:
             self.image_uri = pysdk_model.serving_image_uri(
                 self.sagemaker_session.boto_region_name, self.instance_type
             )
 
         logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)
 
+        if not pysdk_model.image_uri:
+            pysdk_model.image_uri = self.image_uri
+
         self._original_deploy = pysdk_model.deploy
         pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper
         return pysdk_model
@@ -251,13 +254,14 @@ def _set_instance(self, **kwargs):
         if self.mode == Mode.SAGEMAKER_ENDPOINT:
             if self.nb_instance_type and "instance_type" not in kwargs:
                 kwargs.update({"instance_type": self.nb_instance_type})
+                logger.info("Setting instance type to %s", self.nb_instance_type)
             elif self.instance_type and "instance_type" not in kwargs:
                 kwargs.update({"instance_type": self.instance_type})
+                logger.info("Setting instance type to %s", self.instance_type)
             else:
                 raise ValueError(
                     "Instance type must be provided when deploying to SageMaker Endpoint mode."
                 )
-            logger.info("Setting instance type to %s", self.instance_type)
 
     def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
         """Uses the hugging face json config to pick supported versions"""
diff --git a/tests/unit/sagemaker/serve/builder/test_transformers_builder.py b/tests/unit/sagemaker/serve/builder/test_transformers_builder.py
index e17364f22d..b7e3db79d6 100644
--- a/tests/unit/sagemaker/serve/builder/test_transformers_builder.py
+++ b/tests/unit/sagemaker/serve/builder/test_transformers_builder.py
@@ -58,6 +58,10 @@
 mock_schema_builder = MagicMock()
 mock_schema_builder.sample_input = mock_sample_input
 mock_schema_builder.sample_output = mock_sample_output
+MOCK_IMAGE_CONFIG = (
+    "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
+    "huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0"
+)
 
 
 class TestTransformersBuilder(unittest.TestCase):
@@ -100,3 +104,46 @@ def test_build_deploy_for_transformers_local_container_and_remote_container(
 
         with self.assertRaises(ValueError) as _:
             model.deploy(mode=Mode.IN_PROCESS)
+
+    @patch(
+        "sagemaker.serve.builder.transformers_builder._get_nb_instance",
+        return_value="ml.g5.24xlarge",
+    )
+    @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None)
+    def test_image_uri(
+        self,
+        mock_telemetry,
+        mock_get_nb_instance,
+    ):
+        """A caller-supplied image_uri is kept as-is when building and deploying."""
+        # NOTE: stacked @patch decorators inject their mocks bottom-up, so the
+        # _capture_telemetry mock arrives first and _get_nb_instance second.
+        builder = ModelBuilder(
+            model=mock_model_id,
+            schema_builder=mock_schema_builder,
+            mode=Mode.LOCAL_CONTAINER,
+            image_uri=MOCK_IMAGE_CONFIG,
+        )
+
+        builder._prepare_for_mode = MagicMock()
+        builder._prepare_for_mode.side_effect = None
+
+        model = builder.build()
+        builder.serve_settings.telemetry_opt_out = True
+
+        builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
+        predictor = model.deploy(model_data_download_timeout=1800)
+
+        assert builder.image_uri == MOCK_IMAGE_CONFIG
+        assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800"
+        assert isinstance(predictor, TransformersLocalModePredictor)
+
+        assert builder.nb_instance_type == "ml.g5.24xlarge"
+
+        builder._original_deploy = MagicMock()
+        builder._prepare_for_mode.return_value = (None, {})
+        model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")
+        assert "HF_MODEL_ID" in model.env
+
+        with self.assertRaises(ValueError) as _:
+            model.deploy(mode=Mode.IN_PROCESS)