Skip to content

Commit 0cf3191

Browse files
makungaj1 (Jonathan Makunga)
authored and committed
Feat: Add TEI support for ModelBuilder (aws#4694)
* Add TEI Serving * Add TEI Serving * Add TEI Serving * Add TEI Serving * Add TEI Serving * Add TEI Serving * Notebook testing * Notebook testing * Notebook testing * Refactoring * Refactoring * UT * UT * Refactoring * Test coverage * Refactoring * Refactoring --------- Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 279e0e6 commit 0cf3191

File tree

13 files changed

+412
-65
lines changed

13 files changed

+412
-65
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
171171
in order for model builder to build the artifacts correctly (according
172172
to the model server). Possible values for this argument are
173173
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
174-
``TRITON``, and``TGI``.
174+
``TRITON``, ``TGI``, and ``TEI``.
175175
model_metadata (Optional[Dict[str, Any]]): Dictionary used to override model metadata.
176176
Currently, ``HF_TASK`` is overridable for HuggingFace model. HF_TASK should be set for
177177
new models without task metadata in the Hub, adding unsupported task types will throw

src/sagemaker/serve/builder/tei_builder.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
_get_nb_instance,
2626
)
2727
from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure
28-
from sagemaker.serve.utils.predictors import TgiLocalModePredictor
28+
from sagemaker.serve.utils.predictors import TeiLocalModePredictor
2929
from sagemaker.serve.utils.types import ModelServer
3030
from sagemaker.serve.mode.function_pointers import Mode
3131
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
@@ -74,16 +74,16 @@ def _prepare_for_mode(self):
7474
def _get_client_translators(self):
7575
"""Placeholder docstring"""
7676

77-
def _set_to_tgi(self):
77+
def _set_to_tei(self):
7878
"""Placeholder docstring"""
79-
if self.model_server != ModelServer.TGI:
79+
if self.model_server != ModelServer.TEI:
8080
messaging = (
8181
"HuggingFace Model ID support on model server: "
8282
f"{self.model_server} is not currently supported. "
83-
f"Defaulting to {ModelServer.TGI}"
83+
f"Defaulting to {ModelServer.TEI}"
8484
)
8585
logger.warning(messaging)
86-
self.model_server = ModelServer.TGI
86+
self.model_server = ModelServer.TEI
8787

8888
def _create_tei_model(self, **kwargs) -> Type[Model]:
8989
"""Placeholder docstring"""
@@ -142,7 +142,7 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
142142
if self.mode == Mode.LOCAL_CONTAINER:
143143
timeout = kwargs.get("model_data_download_timeout")
144144

145-
predictor = TgiLocalModePredictor(
145+
predictor = TeiLocalModePredictor(
146146
self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
147147
)
148148

@@ -180,7 +180,9 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
180180
if "endpoint_logging" not in kwargs:
181181
kwargs["endpoint_logging"] = True
182182

183-
if not self.nb_instance_type and "instance_type" not in kwargs:
183+
if self.nb_instance_type and "instance_type" not in kwargs:
184+
kwargs.update({"instance_type": self.nb_instance_type})
185+
elif not self.nb_instance_type and "instance_type" not in kwargs:
184186
raise ValueError(
185187
"Instance type must be provided when deploying " "to SageMaker Endpoint mode."
186188
)
@@ -216,7 +218,7 @@ def _build_for_tei(self):
216218
"""Placeholder docstring"""
217219
self.secret_key = None
218220

219-
self._set_to_tgi()
221+
self._set_to_tei()
220222

221223
self.pysdk_model = self._build_for_hf_tei()
222224
return self.pysdk_model

src/sagemaker/serve/mode/local_container_mode.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sagemaker.serve.model_server.triton.server import LocalTritonServer
2323
from sagemaker.serve.model_server.tgi.server import LocalTgiServing
2424
from sagemaker.serve.model_server.fastapi.server import LocalFastApi
25+
from sagemaker.serve.model_server.tei.server import LocalTeiServing
2526
from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer
2627
from sagemaker.session import Session
2728

@@ -35,18 +36,15 @@
3536
)
3637

3738

38-
<<<<<<< HEAD
39-
class LocalContainerMode(LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing, LocalFastApi, LocalMultiModelServer):
40-
=======
4139
class LocalContainerMode(
4240
LocalTorchServe,
4341
LocalDJLServing,
4442
LocalTritonServer,
4543
LocalTgiServing,
4644
LocalMultiModelServer,
4745
LocalTensorflowServing,
46+
LocalFastApi
4847
):
49-
>>>>>>> a5c6229b0 (Add tensorflow_serving support for mlflow models and enable lineage tracking for mlflow models (#4662))
5048
"""A class that holds methods to deploy model to a container in local environment"""
5149

5250
def __init__(
@@ -74,6 +72,7 @@ def __init__(
7472
self.container = None
7573
self.secret_key = None
7674
self._ping_container = None
75+
self._invoke_serving = None
7776

7877
def load(self, model_path: str = None):
7978
"""Placeholder docstring"""
@@ -161,6 +160,19 @@ def create_server(
161160
env_vars=env_vars if env_vars else self.env_vars,
162161
)
163162
self._ping_container = self._tensorflow_serving_deep_ping
163+
elif self.model_server == ModelServer.TEI:
164+
tei_serving = LocalTeiServing()
165+
tei_serving._start_tei_serving(
166+
client=self.client,
167+
image=image,
168+
model_path=model_path if model_path else self.model_path,
169+
secret_key=secret_key,
170+
env_vars=env_vars if env_vars else self.env_vars,
171+
)
172+
tei_serving.schema_builder = self.schema_builder
173+
self.container = tei_serving.container
174+
self._ping_container = tei_serving._tei_deep_ping
175+
self._invoke_serving = tei_serving._invoke_tei_serving
164176
elif self.model_server == ModelServer.FASTAPI:
165177
self._start_fast_api(
166178
client=self.client,
@@ -171,7 +183,6 @@ def create_server(
171183
)
172184
self._ping_container = self._fastapi_deep_ping
173185

174-
175186
# allow some time for container to be ready
176187
time.sleep(10)
177188

src/sagemaker/serve/mode/sagemaker_endpoint_mode.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
from typing import Type
88

9+
from sagemaker.serve.model_server.tei.server import SageMakerTeiServing
910
from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing
1011
from sagemaker.session import Session
1112
from sagemaker.serve.utils.types import ModelServer
@@ -39,6 +40,8 @@ def __init__(self, inference_spec: Type[InferenceSpec], model_server: ModelServe
3940
self.inference_spec = inference_spec
4041
self.model_server = model_server
4142

43+
self._tei_serving = SageMakerTeiServing()
44+
4245
def load(self, model_path: str):
4346
"""Placeholder docstring"""
4447
path = Path(model_path)
@@ -68,8 +71,9 @@ def prepare(
6871
+ "session to be created or supply `sagemaker_session` into @serve.invoke."
6972
) from e
7073

74+
upload_artifacts = None
7175
if self.model_server == ModelServer.TORCHSERVE:
72-
return self._upload_torchserve_artifacts(
76+
upload_artifacts = self._upload_torchserve_artifacts(
7377
model_path=model_path,
7478
sagemaker_session=sagemaker_session,
7579
secret_key=secret_key,
@@ -78,7 +82,7 @@ def prepare(
7882
)
7983

8084
if self.model_server == ModelServer.TRITON:
81-
return self._upload_triton_artifacts(
85+
upload_artifacts = self._upload_triton_artifacts(
8286
model_path=model_path,
8387
sagemaker_session=sagemaker_session,
8488
secret_key=secret_key,
@@ -87,15 +91,15 @@ def prepare(
8791
)
8892

8993
if self.model_server == ModelServer.DJL_SERVING:
90-
return self._upload_djl_artifacts(
94+
upload_artifacts = self._upload_djl_artifacts(
9195
model_path=model_path,
9296
sagemaker_session=sagemaker_session,
9397
s3_model_data_url=s3_model_data_url,
9498
image=image,
9599
)
96100

97101
if self.model_server == ModelServer.TGI:
98-
return self._upload_tgi_artifacts(
102+
upload_artifacts = self._upload_tgi_artifacts(
99103
model_path=model_path,
100104
sagemaker_session=sagemaker_session,
101105
s3_model_data_url=s3_model_data_url,
@@ -113,20 +117,31 @@ def prepare(
113117
)
114118

115119
if self.model_server == ModelServer.MMS:
116-
return self._upload_server_artifacts(
120+
upload_artifacts = self._upload_server_artifacts(
117121
model_path=model_path,
118122
sagemaker_session=sagemaker_session,
119123
s3_model_data_url=s3_model_data_url,
120124
image=image,
121125
)
122126

123127
if self.model_server == ModelServer.TENSORFLOW_SERVING:
124-
return self._upload_tensorflow_serving_artifacts(
128+
upload_artifacts = self._upload_tensorflow_serving_artifacts(
125129
model_path=model_path,
126130
sagemaker_session=sagemaker_session,
127131
secret_key=secret_key,
128132
s3_model_data_url=s3_model_data_url,
129133
image=image,
130134
)
131135

136+
if self.model_server == ModelServer.TEI:
137+
upload_artifacts = self._tei_serving._upload_tei_artifacts(
138+
model_path=model_path,
139+
sagemaker_session=sagemaker_session,
140+
s3_model_data_url=s3_model_data_url,
141+
image=image,
142+
)
143+
144+
if upload_artifacts:
145+
return upload_artifacts
146+
132147
raise ValueError("%s model server is not supported" % self.model_server)

src/sagemaker/serve/model_server/tei/__init__.py

Whitespace-only changes.
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
"""Module for Local TEI Serving"""
2+
3+
from __future__ import absolute_import
4+
5+
import requests
6+
import logging
7+
from pathlib import Path
8+
from docker.types import DeviceRequest
9+
from sagemaker import Session, fw_utils
10+
from sagemaker.serve.utils.exceptions import LocalModelInvocationException
11+
from sagemaker.base_predictor import PredictorBase
12+
from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join
13+
from sagemaker.s3 import S3Uploader
14+
from sagemaker.local.utils import get_docker_host
15+
16+
17+
MODE_DIR_BINDING = "/opt/ml/model/"
18+
_SHM_SIZE = "2G"
19+
_DEFAULT_ENV_VARS = {
20+
"TRANSFORMERS_CACHE": "/opt/ml/model/",
21+
"HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
22+
}
23+
24+
logger = logging.getLogger(__name__)
25+
26+
27+
class LocalTeiServing:
28+
"""LocalTeiServing class"""
29+
30+
def _start_tei_serving(
31+
self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict
32+
):
33+
"""Starts a local tei serving container.
34+
35+
Args:
36+
client: Docker client
37+
image: Image to use
38+
model_path: Path to the model
39+
secret_key: Secret key to use for authentication
40+
env_vars: Environment variables to set
41+
"""
42+
if env_vars and secret_key:
43+
env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key
44+
45+
self.container = client.containers.run(
46+
image,
47+
shm_size=_SHM_SIZE,
48+
device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])],
49+
network_mode="host",
50+
detach=True,
51+
auto_remove=True,
52+
volumes={
53+
Path(model_path).joinpath("code"): {
54+
"bind": MODE_DIR_BINDING,
55+
"mode": "rw",
56+
},
57+
},
58+
environment=_update_env_vars(env_vars),
59+
)
60+
61+
def _invoke_tei_serving(self, request: object, content_type: str, accept: str):
62+
"""Invokes a local tei serving container.
63+
64+
Args:
65+
request: Request to send
66+
content_type: Content type to use
67+
accept: Accept to use
68+
"""
69+
try:
70+
response = requests.post(
71+
f"http://{get_docker_host()}:8080/invocations",
72+
data=request,
73+
headers={"Content-Type": content_type, "Accept": accept},
74+
timeout=600,
75+
)
76+
response.raise_for_status()
77+
return response.content
78+
except Exception as e:
79+
raise Exception("Unable to send request to the local container server") from e
80+
81+
def _tei_deep_ping(self, predictor: PredictorBase):
82+
"""Checks if the local tei serving container is up and running.
83+
84+
If the container is not up and running, it will raise an exception.
85+
"""
86+
response = None
87+
try:
88+
response = predictor.predict(self.schema_builder.sample_input)
89+
return (True, response)
90+
# pylint: disable=broad-except
91+
except Exception as e:
92+
if "422 Client Error: Unprocessable Entity for url" in str(e):
93+
raise LocalModelInvocationException(str(e))
94+
return (False, response)
95+
96+
return (True, response)
97+
98+
99+
class SageMakerTeiServing:
100+
"""SageMakerTeiServing class"""
101+
102+
def _upload_tei_artifacts(
103+
self,
104+
model_path: str,
105+
sagemaker_session: Session,
106+
s3_model_data_url: str = None,
107+
image: str = None,
108+
env_vars: dict = None,
109+
):
110+
"""Uploads the model artifacts to S3.
111+
112+
Args:
113+
model_path: Path to the model
114+
sagemaker_session: SageMaker session
115+
s3_model_data_url: S3 model data URL
116+
image: Image to use
117+
env_vars: Environment variables to set
118+
"""
119+
if s3_model_data_url:
120+
bucket, key_prefix = parse_s3_url(url=s3_model_data_url)
121+
else:
122+
bucket, key_prefix = None, None
123+
124+
code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image)
125+
126+
bucket, code_key_prefix = determine_bucket_and_prefix(
127+
bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session
128+
)
129+
130+
code_dir = Path(model_path).joinpath("code")
131+
132+
s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code")
133+
134+
logger.debug("Uploading TEI Model Resources uncompressed to: %s", s3_location)
135+
136+
model_data_url = S3Uploader.upload(
137+
str(code_dir),
138+
s3_location,
139+
None,
140+
sagemaker_session,
141+
)
142+
143+
model_data = {
144+
"S3DataSource": {
145+
"CompressionType": "None",
146+
"S3DataType": "S3Prefix",
147+
"S3Uri": model_data_url + "/",
148+
}
149+
}
150+
151+
return (model_data, _update_env_vars(env_vars))
152+
153+
154+
def _update_env_vars(env_vars: dict) -> dict:
155+
"""Placeholder docstring"""
156+
updated_env_vars = {}
157+
updated_env_vars.update(_DEFAULT_ENV_VARS)
158+
if env_vars:
159+
updated_env_vars.update(env_vars)
160+
return updated_env_vars

0 commit comments

Comments (0)