23
23
from sagemaker import model_uris
24
24
from sagemaker .serve .model_server .djl_serving .prepare import prepare_djl_js_resources
25
25
from sagemaker .serve .model_server .djl_serving .utils import _get_admissible_tensor_parallel_degrees
26
+ from sagemaker .serve .model_server .multi_model_server .prepare import prepare_mms_js_resources
26
27
from sagemaker .serve .model_server .tgi .prepare import prepare_tgi_js_resources , _create_dir_structure
27
28
from sagemaker .serve .mode .function_pointers import Mode
28
29
from sagemaker .serve .utils .exceptions import (
35
36
from sagemaker .serve .utils .predictors import (
36
37
DjlLocalModePredictor ,
37
38
TgiLocalModePredictor ,
39
+ TransformersLocalModePredictor ,
38
40
)
39
41
from sagemaker .serve .utils .local_hardware import (
40
42
_get_nb_instance ,
@@ -90,6 +92,7 @@ def __init__(self):
90
92
self .existing_properties = None
91
93
self .prepared_for_tgi = None
92
94
self .prepared_for_djl = None
95
+ self .prepared_for_mms = None
93
96
self .schema_builder = None
94
97
self .nb_instance_type = None
95
98
self .ram_usage_model_load = None
@@ -137,7 +140,11 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
137
140
138
141
if overwrite_mode == Mode .SAGEMAKER_ENDPOINT :
139
142
self .mode = self .pysdk_model .mode = Mode .SAGEMAKER_ENDPOINT
140
- if not hasattr (self , "prepared_for_djl" ) or not hasattr (self , "prepared_for_tgi" ):
143
+ if (
144
+ not hasattr (self , "prepared_for_djl" )
145
+ or not hasattr (self , "prepared_for_tgi" )
146
+ or not hasattr (self , "prepared_for_mms" )
147
+ ):
141
148
self .pysdk_model .model_data , env = self ._prepare_for_mode ()
142
149
elif overwrite_mode == Mode .LOCAL_CONTAINER :
143
150
self .mode = self .pysdk_model .mode = Mode .LOCAL_CONTAINER
@@ -160,6 +167,13 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
160
167
dependencies = self .dependencies ,
161
168
model_data = self .pysdk_model .model_data ,
162
169
)
170
+ elif not hasattr (self , "prepared_for_mms" ):
171
+ self .js_model_config , self .prepared_for_mms = prepare_mms_js_resources (
172
+ model_path = self .model_path ,
173
+ js_id = self .model ,
174
+ dependencies = self .dependencies ,
175
+ model_data = self .pysdk_model .model_data ,
176
+ )
163
177
164
178
self ._prepare_for_mode ()
165
179
env = {}
@@ -179,6 +193,10 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
179
193
predictor = TgiLocalModePredictor (
180
194
self .modes [str (Mode .LOCAL_CONTAINER )], serializer , deserializer
181
195
)
196
+ elif self .model_server == ModelServer .MMS :
197
+ predictor = TransformersLocalModePredictor (
198
+ self .modes [str (Mode .LOCAL_CONTAINER )], serializer , deserializer
199
+ )
182
200
183
201
ram_usage_before = _get_ram_usage_mb ()
184
202
self .modes [str (Mode .LOCAL_CONTAINER )].create_server (
@@ -254,6 +272,24 @@ def _build_for_tgi_jumpstart(self):
254
272
255
273
self .pysdk_model .env .update (env )
256
274
275
+ def _build_for_mms_jumpstart (self ):
276
+ """Placeholder docstring"""
277
+
278
+ env = {}
279
+ if self .mode == Mode .LOCAL_CONTAINER :
280
+ if not hasattr (self , "prepared_for_mms" ):
281
+ self .js_model_config , self .prepared_for_mms = prepare_mms_js_resources (
282
+ model_path = self .model_path ,
283
+ js_id = self .model ,
284
+ dependencies = self .dependencies ,
285
+ model_data = self .pysdk_model .model_data ,
286
+ )
287
+ self ._prepare_for_mode ()
288
+ elif self .mode == Mode .SAGEMAKER_ENDPOINT and hasattr (self , "prepared_for_mms" ):
289
+ self .pysdk_model .model_data , env = self ._prepare_for_mode ()
290
+
291
+ self .pysdk_model .env .update (env )
292
+
257
293
def _tune_for_js (self , sharded_supported : bool , max_tuning_duration : int = 1800 ):
258
294
"""Tune for Jumpstart Models in Local Mode.
259
295
@@ -264,11 +300,6 @@ def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800)
264
300
returns:
265
301
Tuned Model.
266
302
"""
267
- if self .mode != Mode .LOCAL_CONTAINER :
268
- logger .warning (
269
- "Tuning is only a %s capability. Returning original model." , Mode .LOCAL_CONTAINER
270
- )
271
- return self .pysdk_model
272
303
273
304
num_shard_env_var_name = "SM_NUM_GPUS"
274
305
if "OPTION_TENSOR_PARALLEL_DEGREE" in self .pysdk_model .env .keys ():
@@ -437,42 +468,58 @@ def _build_for_jumpstart(self):
437
468
self .secret_key = None
438
469
self .jumpstart = True
439
470
440
- pysdk_model = self ._create_pre_trained_js_model ()
471
+ self .pysdk_model = self ._create_pre_trained_js_model ()
472
+ self .pysdk_model .tune = lambda * args , ** kwargs : self ._default_tune ()
441
473
442
- image_uri = pysdk_model .image_uri
474
+ logger .info (
475
+ "JumpStart ID %s is packaged with Image URI: %s" , self .model , self .pysdk_model .image_uri
476
+ )
443
477
444
- logger .info ("JumpStart ID %s is packaged with Image URI: %s" , self .model , image_uri )
478
+ if self .mode != Mode .SAGEMAKER_ENDPOINT :
479
+ if self ._is_gated_model (self .pysdk_model ):
480
+ raise ValueError (
481
+ "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
482
+ )
445
483
446
- if self . _is_gated_model ( pysdk_model ) and self .mode != Mode . SAGEMAKER_ENDPOINT :
447
- raise ValueError (
448
- "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
449
- )
484
+ if "djl-inference" in self .pysdk_model . image_uri :
485
+ logger . info ( "Building for DJL JumpStart Model ID..." )
486
+ self . model_server = ModelServer . DJL_SERVING
487
+ self . image_uri = self . pysdk_model . image_uri
450
488
451
- if "djl-inference" in image_uri :
452
- logger .info ("Building for DJL JumpStart Model ID..." )
453
- self .model_server = ModelServer .DJL_SERVING
489
+ self ._build_for_djl_jumpstart ()
454
490
455
- self .pysdk_model = pysdk_model
456
- self .image_uri = self .pysdk_model .image_uri
491
+ self .pysdk_model .tune = self .tune_for_djl_jumpstart
492
+ elif "tgi-inference" in self .pysdk_model .image_uri :
493
+ logger .info ("Building for TGI JumpStart Model ID..." )
494
+ self .model_server = ModelServer .TGI
495
+ self .image_uri = self .pysdk_model .image_uri
457
496
458
- self ._build_for_djl_jumpstart ()
497
+ self ._build_for_tgi_jumpstart ()
459
498
460
- self .pysdk_model .tune = self .tune_for_djl_jumpstart
461
- elif "tgi-inference" in image_uri :
462
- logger .info ("Building for TGI JumpStart Model ID..." )
463
- self .model_server = ModelServer .TGI
499
+ self .pysdk_model .tune = self .tune_for_tgi_jumpstart
500
+ elif "huggingface-pytorch-inference:" in self .pysdk_model .image_uri :
501
+ logger .info ("Building for MMS JumpStart Model ID..." )
502
+ self .model_server = ModelServer .MMS
503
+ self .image_uri = self .pysdk_model .image_uri
464
504
465
- self .pysdk_model = pysdk_model
466
- self .image_uri = self .pysdk_model .image_uri
505
+ self ._build_for_mms_jumpstart ()
506
+ else :
507
+ raise ValueError (
508
+ "JumpStart Model ID was not packaged "
509
+ "with djl-inference, tgi-inference, or mms-inference container."
510
+ )
467
511
468
- self ._build_for_tgi_jumpstart ()
512
+ return self .pysdk_model
469
513
470
- self .pysdk_model .tune = self .tune_for_tgi_jumpstart
471
- else :
472
- raise ValueError (
473
- "JumpStart Model ID was not packaged with djl-inference or tgi-inference container."
474
- )
514
+ def _default_tune (self ):
515
+ """Logs a warning message if tune is invoked on endpoint mode.
475
516
517
+ Returns:
518
+ Jumpstart Model: ``This`` model
519
+ """
520
+ logger .warning (
521
+ "Tuning is only a %s capability. Returning original model." , Mode .LOCAL_CONTAINER
522
+ )
476
523
return self .pysdk_model
477
524
478
525
def _is_gated_model (self , model ) -> bool :
0 commit comments