diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 8cb42689fe..bc31e8d323 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -300,6 +300,11 @@ def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800) returns: Tuned Model. """ + if self.mode == Mode.SAGEMAKER_ENDPOINT: + logger.warning( + "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER + ) + return self.pysdk_model num_shard_env_var_name = "SM_NUM_GPUS" if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env.keys(): @@ -468,58 +473,47 @@ def _build_for_jumpstart(self): self.secret_key = None self.jumpstart = True - self.pysdk_model = self._create_pre_trained_js_model() - self.pysdk_model.tune = lambda *args, **kwargs: self._default_tune() - - logger.info( - "JumpStart ID %s is packaged with Image URI: %s", self.model, self.pysdk_model.image_uri - ) - - if self.mode != Mode.SAGEMAKER_ENDPOINT: - if self._is_gated_model(self.pysdk_model): - raise ValueError( - "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode." - ) - - if "djl-inference" in self.pysdk_model.image_uri: - logger.info("Building for DJL JumpStart Model ID...") - self.model_server = ModelServer.DJL_SERVING - self.image_uri = self.pysdk_model.image_uri - - self._build_for_djl_jumpstart() - - self.pysdk_model.tune = self.tune_for_djl_jumpstart - elif "tgi-inference" in self.pysdk_model.image_uri: - logger.info("Building for TGI JumpStart Model ID...") - self.model_server = ModelServer.TGI - self.image_uri = self.pysdk_model.image_uri - - self._build_for_tgi_jumpstart() + pysdk_model = self._create_pre_trained_js_model() + image_uri = pysdk_model.image_uri - self.pysdk_model.tune = self.tune_for_tgi_jumpstart - elif "huggingface-pytorch-inference:" in self.pysdk_model.image_uri: - logger.info("Building for MMS JumpStart Model ID...") - self.model_server = ModelServer.MMS - self.image_uri = self.pysdk_model.image_uri + logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri) - self._build_for_mms_jumpstart() - else: - raise ValueError( - "JumpStart Model ID was not packaged " - "with djl-inference, tgi-inference, or mms-inference container." - ) - - return self.pysdk_model + if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT: + raise ValueError( + "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode." + ) - def _default_tune(self): - """Logs a warning message if tune is invoked on endpoint mode. + if "djl-inference" in image_uri: + logger.info("Building for DJL JumpStart Model ID...") + self.model_server = ModelServer.DJL_SERVING + self.pysdk_model = pysdk_model + self.image_uri = self.pysdk_model.image_uri + + self._build_for_djl_jumpstart() + + self.pysdk_model.tune = self.tune_for_djl_jumpstart + elif "tgi-inference" in image_uri: + logger.info("Building for TGI JumpStart Model ID...") + self.model_server = ModelServer.TGI + self.pysdk_model = pysdk_model + self.image_uri = self.pysdk_model.image_uri + + self._build_for_tgi_jumpstart() + + self.pysdk_model.tune = self.tune_for_tgi_jumpstart + elif "huggingface-pytorch-inference:" in image_uri: + logger.info("Building for MMS JumpStart Model ID...") + self.model_server = ModelServer.MMS + self.pysdk_model = pysdk_model + self.image_uri = self.pysdk_model.image_uri + + self._build_for_mms_jumpstart() + elif self.mode != Mode.SAGEMAKER_ENDPOINT: + raise ValueError( + "JumpStart Model ID was not packaged " + "with djl-inference, tgi-inference, or mms-inference container." + ) - Returns: - Jumpstart Model: ``This`` model - """ - logger.warning( - "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER - ) return self.pysdk_model def _is_gated_model(self, model) -> bool: diff --git a/tests/integ/sagemaker/serve/test_serve_js_happy.py b/tests/integ/sagemaker/serve/test_serve_js_happy.py index 7835c8ae3c..ad0527fcc0 100644 --- a/tests/integ/sagemaker/serve/test_serve_js_happy.py +++ b/tests/integ/sagemaker/serve/test_serve_js_happy.py @@ -34,6 +34,14 @@ JS_MODEL_ID = "huggingface-textgeneration1-gpt-neo-125m-fp16" ROLE_NAME = "SageMakerRole" +SAMPLE_MMS_PROMPT = [ + "How cute your dog is!", + "Your dog is so cute.", + "The mitochondria is the powerhouse of the cell.", +] +SAMPLE_MMS_RESPONSE = {"embedding": []} +JS_MMS_MODEL_ID = "huggingface-sentencesimilarity-bge-m3" + @pytest.fixture def happy_model_builder(sagemaker_session): @@ -46,6 +54,17 @@ def happy_model_builder(sagemaker_session): ) +@pytest.fixture +def happy_mms_model_builder(sagemaker_session): + iam_client = sagemaker_session.boto_session.client("iam") + return ModelBuilder( + model=JS_MMS_MODEL_ID, + schema_builder=SchemaBuilder(SAMPLE_MMS_PROMPT, SAMPLE_MMS_RESPONSE), + role_arn=iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"], + sagemaker_session=sagemaker_session, + ) + + @pytest.mark.skipif( PYTHON_VERSION_IS_NOT_310, reason="The goal of these test are to test the serving components of our feature", @@ -75,3 +94,34 @@ def test_happy_tgi_sagemaker_endpoint(happy_model_builder, gpu_instance_type): ) if caught_ex: raise caught_ex + + +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="The goal of these test are to test the serving components of our feature", +) +@pytest.mark.slow_test +def test_happy_mms_sagemaker_endpoint(happy_mms_model_builder, gpu_instance_type): + logger.info("Running in SAGEMAKER_ENDPOINT mode...") + caught_ex = None + model = happy_mms_model_builder.build() + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + try: + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + predictor = model.deploy(instance_type=gpu_instance_type, endpoint_logging=False) + logger.info("Endpoint successfully deployed.") + + updated_sample_input = happy_mms_model_builder.schema_builder.sample_input + + predictor.predict(updated_sample_input) + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=happy_mms_model_builder.sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + raise caught_ex