Skip to content

Commit 8002d7f

Browse files
makungaj1Jonathan Makunga
and
Jonathan Makunga
authored
fix: model builder limited container support for endpoint mode. (#4683)
* Allow ModelBuilder's endpoint mode for Jumpstart models packaged with containers other than TGI and DJL * increase coverage * Add JS Support for MMS Serving * Add JS Support for MMS Serving * Unit tests * Refactoring * Refactoring * Refactoring --------- Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 63a9ac3 commit 8002d7f

File tree

3 files changed

+191
-31
lines changed

3 files changed

+191
-31
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 78 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sagemaker import model_uris
2424
from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources
2525
from sagemaker.serve.model_server.djl_serving.utils import _get_admissible_tensor_parallel_degrees
26+
from sagemaker.serve.model_server.multi_model_server.prepare import prepare_mms_js_resources
2627
from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources, _create_dir_structure
2728
from sagemaker.serve.mode.function_pointers import Mode
2829
from sagemaker.serve.utils.exceptions import (
@@ -35,6 +36,7 @@
3536
from sagemaker.serve.utils.predictors import (
3637
DjlLocalModePredictor,
3738
TgiLocalModePredictor,
39+
TransformersLocalModePredictor,
3840
)
3941
from sagemaker.serve.utils.local_hardware import (
4042
_get_nb_instance,
@@ -90,6 +92,7 @@ def __init__(self):
9092
self.existing_properties = None
9193
self.prepared_for_tgi = None
9294
self.prepared_for_djl = None
95+
self.prepared_for_mms = None
9396
self.schema_builder = None
9497
self.nb_instance_type = None
9598
self.ram_usage_model_load = None
@@ -137,7 +140,11 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
137140

138141
if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
139142
self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
140-
if not hasattr(self, "prepared_for_djl") or not hasattr(self, "prepared_for_tgi"):
143+
if (
144+
not hasattr(self, "prepared_for_djl")
145+
or not hasattr(self, "prepared_for_tgi")
146+
or not hasattr(self, "prepared_for_mms")
147+
):
141148
self.pysdk_model.model_data, env = self._prepare_for_mode()
142149
elif overwrite_mode == Mode.LOCAL_CONTAINER:
143150
self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
@@ -160,6 +167,13 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
160167
dependencies=self.dependencies,
161168
model_data=self.pysdk_model.model_data,
162169
)
170+
elif not hasattr(self, "prepared_for_mms"):
171+
self.js_model_config, self.prepared_for_mms = prepare_mms_js_resources(
172+
model_path=self.model_path,
173+
js_id=self.model,
174+
dependencies=self.dependencies,
175+
model_data=self.pysdk_model.model_data,
176+
)
163177

164178
self._prepare_for_mode()
165179
env = {}
@@ -179,6 +193,10 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
179193
predictor = TgiLocalModePredictor(
180194
self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
181195
)
196+
elif self.model_server == ModelServer.MMS:
197+
predictor = TransformersLocalModePredictor(
198+
self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
199+
)
182200

183201
ram_usage_before = _get_ram_usage_mb()
184202
self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
@@ -254,6 +272,24 @@ def _build_for_tgi_jumpstart(self):
254272

255273
self.pysdk_model.env.update(env)
256274

275+
def _build_for_mms_jumpstart(self):
276+
"""Placeholder docstring"""
277+
278+
env = {}
279+
if self.mode == Mode.LOCAL_CONTAINER:
280+
if not hasattr(self, "prepared_for_mms"):
281+
self.js_model_config, self.prepared_for_mms = prepare_mms_js_resources(
282+
model_path=self.model_path,
283+
js_id=self.model,
284+
dependencies=self.dependencies,
285+
model_data=self.pysdk_model.model_data,
286+
)
287+
self._prepare_for_mode()
288+
elif self.mode == Mode.SAGEMAKER_ENDPOINT and hasattr(self, "prepared_for_mms"):
289+
self.pysdk_model.model_data, env = self._prepare_for_mode()
290+
291+
self.pysdk_model.env.update(env)
292+
257293
def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800):
258294
"""Tune for Jumpstart Models in Local Mode.
259295
@@ -264,11 +300,6 @@ def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800)
264300
returns:
265301
Tuned Model.
266302
"""
267-
if self.mode != Mode.LOCAL_CONTAINER:
268-
logger.warning(
269-
"Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
270-
)
271-
return self.pysdk_model
272303

273304
num_shard_env_var_name = "SM_NUM_GPUS"
274305
if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env.keys():
@@ -437,42 +468,58 @@ def _build_for_jumpstart(self):
437468
self.secret_key = None
438469
self.jumpstart = True
439470

440-
pysdk_model = self._create_pre_trained_js_model()
471+
self.pysdk_model = self._create_pre_trained_js_model()
472+
self.pysdk_model.tune = lambda *args, **kwargs: self._default_tune()
441473

442-
image_uri = pysdk_model.image_uri
474+
logger.info(
475+
"JumpStart ID %s is packaged with Image URI: %s", self.model, self.pysdk_model.image_uri
476+
)
443477

444-
logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri)
478+
if self.mode != Mode.SAGEMAKER_ENDPOINT:
479+
if self._is_gated_model(self.pysdk_model):
480+
raise ValueError(
481+
"JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
482+
)
445483

446-
if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT:
447-
raise ValueError(
448-
"JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
449-
)
484+
if "djl-inference" in self.pysdk_model.image_uri:
485+
logger.info("Building for DJL JumpStart Model ID...")
486+
self.model_server = ModelServer.DJL_SERVING
487+
self.image_uri = self.pysdk_model.image_uri
450488

451-
if "djl-inference" in image_uri:
452-
logger.info("Building for DJL JumpStart Model ID...")
453-
self.model_server = ModelServer.DJL_SERVING
489+
self._build_for_djl_jumpstart()
454490

455-
self.pysdk_model = pysdk_model
456-
self.image_uri = self.pysdk_model.image_uri
491+
self.pysdk_model.tune = self.tune_for_djl_jumpstart
492+
elif "tgi-inference" in self.pysdk_model.image_uri:
493+
logger.info("Building for TGI JumpStart Model ID...")
494+
self.model_server = ModelServer.TGI
495+
self.image_uri = self.pysdk_model.image_uri
457496

458-
self._build_for_djl_jumpstart()
497+
self._build_for_tgi_jumpstart()
459498

460-
self.pysdk_model.tune = self.tune_for_djl_jumpstart
461-
elif "tgi-inference" in image_uri:
462-
logger.info("Building for TGI JumpStart Model ID...")
463-
self.model_server = ModelServer.TGI
499+
self.pysdk_model.tune = self.tune_for_tgi_jumpstart
500+
elif "huggingface-pytorch-inference:" in self.pysdk_model.image_uri:
501+
logger.info("Building for MMS JumpStart Model ID...")
502+
self.model_server = ModelServer.MMS
503+
self.image_uri = self.pysdk_model.image_uri
464504

465-
self.pysdk_model = pysdk_model
466-
self.image_uri = self.pysdk_model.image_uri
505+
self._build_for_mms_jumpstart()
506+
else:
507+
raise ValueError(
508+
"JumpStart Model ID was not packaged "
509+
"with djl-inference, tgi-inference, or mms-inference container."
510+
)
467511

468-
self._build_for_tgi_jumpstart()
512+
return self.pysdk_model
469513

470-
self.pysdk_model.tune = self.tune_for_tgi_jumpstart
471-
else:
472-
raise ValueError(
473-
"JumpStart Model ID was not packaged with djl-inference or tgi-inference container."
474-
)
514+
def _default_tune(self):
515+
"""Logs a warning message if tune is invoked on endpoint mode.
475516
517+
Returns:
518+
Jumpstart Model: ``This`` model
519+
"""
520+
logger.warning(
521+
"Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
522+
)
476523
return self.pysdk_model
477524

478525
def _is_gated_model(self, model) -> bool:

src/sagemaker/serve/model_server/multi_model_server/prepare.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
from __future__ import absolute_import
1616
import logging
1717
from pathlib import Path
18+
from typing import List
1819

20+
from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts
1921
from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
2022

2123
logger = logging.getLogger(__name__)
@@ -36,3 +38,28 @@ def _create_dir_structure(model_path: str) -> tuple:
3638
_check_docker_disk_usage()
3739

3840
return model_path, code_dir
41+
42+
43+
def prepare_mms_js_resources(
44+
model_path: str,
45+
js_id: str,
46+
shared_libs: List[str] = None,
47+
dependencies: str = None,
48+
model_data: str = None,
49+
) -> tuple:
50+
"""Prepare serving when a JumpStart model id is given
51+
52+
Args:
53+
model_path (str) : Argument
54+
js_id (str): Argument
55+
shared_libs (List[]) : Argument
56+
dependencies (str) : Argument
57+
model_data (str) : Argument
58+
59+
Returns:
60+
( str ) :
61+
62+
"""
63+
model_path, code_dir = _create_dir_structure(model_path)
64+
65+
return _copy_jumpstart_artifacts(model_data, js_id, code_dir)

tests/unit/sagemaker/serve/builder/test_js_builder.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@
6363
"123456789712.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi"
6464
"-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04"
6565
)
66+
mock_invalid_image_uri = (
67+
"123456789712.dkr.ecr.us-west-2.amazonaws.com/invalid"
68+
"-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04"
69+
)
6670
mock_djl_image_uri = (
6771
"123456789712.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.24.0-neuronx-sdk2.14.1"
6872
)
@@ -82,6 +86,88 @@
8286

8387

8488
class TestJumpStartBuilder(unittest.TestCase):
89+
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
90+
@patch(
91+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
92+
return_value=True,
93+
)
94+
@patch(
95+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
96+
return_value=MagicMock(),
97+
)
98+
@patch(
99+
"sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
100+
return_value=({"model_type": "t5", "n_head": 71}, True),
101+
)
102+
@patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
103+
@patch(
104+
"sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
105+
)
106+
def test__build_for_jumpstart_value_error(
107+
self,
108+
mock_get_nb_instance,
109+
mock_get_ram_usage_mb,
110+
mock_prepare_for_tgi,
111+
mock_pre_trained_model,
112+
mock_is_jumpstart_model,
113+
mock_telemetry,
114+
):
115+
builder = ModelBuilder(
116+
model="facebook/invalid",
117+
schema_builder=mock_schema_builder,
118+
mode=Mode.LOCAL_CONTAINER,
119+
)
120+
121+
mock_pre_trained_model.return_value.image_uri = mock_invalid_image_uri
122+
123+
self.assertRaises(
124+
ValueError,
125+
lambda: builder.build(),
126+
)
127+
128+
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
129+
@patch(
130+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
131+
return_value=True,
132+
)
133+
@patch(
134+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
135+
return_value=MagicMock(),
136+
)
137+
@patch(
138+
"sagemaker.serve.builder.jumpstart_builder.prepare_mms_js_resources",
139+
return_value=({"model_type": "t5", "n_head": 71}, True),
140+
)
141+
@patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
142+
@patch(
143+
"sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
144+
)
145+
def test__build_for_mms_jumpstart(
146+
self,
147+
mock_get_nb_instance,
148+
mock_get_ram_usage_mb,
149+
mock_prepare_for_mms,
150+
mock_pre_trained_model,
151+
mock_is_jumpstart_model,
152+
mock_telemetry,
153+
):
154+
builder = ModelBuilder(
155+
model="facebook/galactica-mock-model-id",
156+
schema_builder=mock_schema_builder,
157+
mode=Mode.LOCAL_CONTAINER,
158+
)
159+
160+
mock_pre_trained_model.return_value.image_uri = (
161+
"763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface"
162+
"-pytorch-inference:2.1.0-transformers4.37.0-gpu-py310-cu118"
163+
"-ubuntu20.04"
164+
)
165+
166+
builder.build()
167+
builder.serve_settings.telemetry_opt_out = True
168+
169+
mock_prepare_for_mms.assert_called()
170+
85171
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
86172
@patch(
87173
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",

0 commit comments

Comments
 (0)