
Commit ad8a5f6

Update: Pull latest tei container for sentence similarity models
1 parent 65cc586 commit ad8a5f6
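
In short, this commit teaches the Transformers builder to resolve the Hugging Face TEI (Text Embeddings Inference) serving container automatically when the detected task is sentence-similarity and no image_uri was supplied. Below is a minimal sketch of the lookup the builder now performs, assuming a default SageMaker session; the function name and the "huggingface-tei" backend come from the diff, everything else is illustrative:

# Sketch of the image lookup this commit wires into the builder.
from sagemaker import Session
from sagemaker.huggingface import get_huggingface_llm_image_uri

session = Session()  # region comes from the session's boto configuration
image_uri = get_huggingface_llm_image_uri("huggingface-tei", session=session)
print(image_uri)  # ECR URI of the TEI serving container for that region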

File tree

2 files changed: +72 −3 lines changed

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 22 additions & 2 deletions

@@ -23,7 +23,7 @@
     _get_nb_instance,
 )
 from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
-from sagemaker.huggingface import HuggingFaceModel
+from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
 from sagemaker.serve.model_server.multi_model_server.prepare import (
     _create_dir_structure,
 )
@@ -47,6 +47,7 @@ class Transformers(ABC):
     """Transformers build logic with ModelBuilder()"""
 
     def __init__(self):
+        self.model_metadata = None
         self.model = None
         self.serve_settings = None
         self.sagemaker_session = None
@@ -99,7 +100,26 @@ def _create_transformers_model(self) -> Type[Model]:
         if hf_model_md is None:
             raise ValueError("Could not fetch HF metadata")
 
-        if "pytorch" in hf_model_md.get("tags"):
+        model_task = None
+        if self.model_metadata:
+            model_task = self.model_metadata.get("HF_TASK")
+        else:
+            model_task = hf_model_md.get("pipeline_tag")
+
+        if model_task == "sentence-similarity" and not self.image_uri:
+            self.image_uri = \
+                get_huggingface_llm_image_uri("huggingface-tei", session=self.sagemaker_session)
+
+            logger.info("Auto detected %s. Proceeding with the deployment.", self.image_uri)
+
+            pysdk_model = HuggingFaceModel(
+                env=self.env_vars,
+                role=self.role_arn,
+                sagemaker_session=self.sagemaker_session,
+                image_uri=self.image_uri,
+                vpc_config=self.vpc_config,
+            )
+        elif "pytorch" in hf_model_md.get("tags"):
             self.pytorch_version = self._get_supported_version(
                 hf_config, base_hf_version, "pytorch"
             )
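
For context, here is a hedged sketch of how this new code path might be exercised from the user side. The model id, sample payloads, and instance type are assumptions for illustration; only the HF_TASK hint and the automatic TEI image selection come from the diff above.

# Hypothetical usage sketch; model id, payloads, and instance type are assumptions.
from sagemaker.serve import ModelBuilder, SchemaBuilder

sample_input = {"inputs": "The cat sits on the mat."}  # illustrative request
sample_output = [[0.01, -0.02, 0.03]]                  # illustrative embedding

builder = ModelBuilder(
    model="sentence-transformers/all-MiniLM-L6-v2",     # assumed Hugging Face model id
    schema_builder=SchemaBuilder(sample_input, sample_output),
    model_metadata={"HF_TASK": "sentence-similarity"},  # task hint read by the builder
)

# With no image_uri given, the builder now resolves the huggingface-tei
# container before constructing the HuggingFaceModel.
model = builder.build()
predictor = model.deploy(initial_instance_count=1, instance_type="ml.g5.2xlarge")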

tests/unit/sagemaker/serve/builder/test_transformers_builder.py

Lines changed: 50 additions & 1 deletion

@@ -110,7 +110,7 @@ def test_build_deploy_for_transformers_local_container_and_remote_container(
         return_value="ml.g5.24xlarge",
     )
     @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None)
-    def test_image_uri(
+    def test_image_uri_override(
         self,
         mock_get_nb_instance,
         mock_telemetry,
@@ -144,3 +144,52 @@ def test_image_uri(
 
         with self.assertRaises(ValueError) as _:
             model.deploy(mode=Mode.IN_PROCESS)
+
+    @patch(
+        "sagemaker.serve.builder.transformers_builder._get_nb_instance",
+        return_value="ml.g5.24xlarge",
+    )
+    @patch(
+        "sagemaker.huggingface.llm_utils.get_huggingface_model_metadata",
+        return_value={"pipeline_tag": "sentence-similarity"},
+    )
+    @patch(
+        "sagemaker.serve.builder.transformers_builder.get_huggingface_llm_image_uri",
+        return_value=MOCK_IMAGE_CONFIG,
+    )
+    @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None)
+    def test_sentence_similarity_support(
+        self,
+        mock_get_nb_instance,
+        mock_task,
+        mock_image,
+        mock_telemetry,
+    ):
+        builder = ModelBuilder(
+            model=mock_model_id,
+            schema_builder=mock_schema_builder,
+            mode=Mode.LOCAL_CONTAINER,
+        )
+
+        builder._prepare_for_mode = MagicMock()
+        builder._prepare_for_mode.side_effect = None
+
+        model = builder.build()
+        builder.serve_settings.telemetry_opt_out = True
+
+        builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
+        predictor = model.deploy(model_data_download_timeout=1800)
+
+        assert builder.image_uri == MOCK_IMAGE_CONFIG
+        assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800"
+        assert isinstance(predictor, TransformersLocalModePredictor)
+
+        assert builder.nb_instance_type == "ml.g5.24xlarge"
+
+        builder._original_deploy = MagicMock()
+        builder._prepare_for_mode.return_value = (None, {})
+        predictor = model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")
+        assert "HF_MODEL_ID" in model.env
+
+        with self.assertRaises(ValueError) as _:
+            model.deploy(mode=Mode.IN_PROCESS)
