Feat: Pull latest tei container for sentence similarity models on HuggingFace hub #4686
Merged
31 commits by samruds:

2e00238 Update: Pull latest tei container for sentence similiarity models
43ce1ba Fix formatting
6211227 Address PR comments
0441436 Fix formatting
4973f8f Fix check
f8cd864 Switch sentence similarity to be deployed on tgi
a5fa0e9 Fix formatting
e524134 Fix formatting
4263a44 Fix formatting
eb3b6d3 Fix formatting
2b9ba2a Introduce TEI builder with TGI server
33d5b04 Fix formmatting
20687f0 Add integ test
d85425f Fix formatting
bbdff4c Add integ test
a526416 Add integ test
1e49f88 Add integ test
af78426 Add integ test
a5e665a Add integ test
e58f622 Fix formatting
4c336dd Merge branch 'master' into master
ea900bf Move to G5 for integ test
cffe46a Fix formatting
48205ad Integ test updates
312d837 Integ test updates
29ea1c5 Integ test updates
f6f8116 Fix formatting
166e570 Integ test updates
4bb5522 Move back to generate for ping
17645f7 Integ test updates
e8341c2 Integ test updates
@@ -0,0 +1,208 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Holds mixin logic to support deployment of Model ID"""
from __future__ import absolute_import
import logging
from typing import Type
from abc import ABC, abstractmethod

from sagemaker.model import Model
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf

from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
from sagemaker.serve.utils.local_hardware import (
    _get_nb_instance,
)
from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure
from sagemaker.serve.utils.predictors import TgiLocalModePredictor
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
from sagemaker.base_predictor import PredictorBase

logger = logging.getLogger(__name__)

_CODE_FOLDER = "code"


class TEI(ABC):
    """TEI build logic for ModelBuilder()"""

    def __init__(self):
        self.model = None
        self.serve_settings = None
        self.sagemaker_session = None
        self.model_path = None
        self.dependencies = None
        self.modes = None
        self.mode = None
        self.model_server = None
        self.image_uri = None
        self._is_custom_image_uri = False
        self.image_config = None
        self.vpc_config = None
        self._original_deploy = None
        self.hf_model_config = None
        self._default_tensor_parallel_degree = None
        self._default_data_type = None
        self._default_max_tokens = None
        self.pysdk_model = None
        self.schema_builder = None
        self.env_vars = None
        self.nb_instance_type = None
        self.ram_usage_model_load = None
        self.secret_key = None
        self.jumpstart = None
        self.role_arn = None

    @abstractmethod
    def _prepare_for_mode(self):
        """Prepare model artifacts and environment for the selected deployment mode."""

    @abstractmethod
    def _get_client_translators(self):
        """Return the serializer/deserializer pair used to translate client payloads."""

    def _set_to_tgi(self):
        """Fall back to the TGI model server when another server was configured."""
        if self.model_server != ModelServer.TGI:
            messaging = (
                "HuggingFace Model ID support on model server: "
                f"{self.model_server} is not currently supported. "
                f"Defaulting to {ModelServer.TGI}"
            )
            logger.warning(messaging)
            self.model_server = ModelServer.TGI

    def _create_tei_model(self) -> Type[Model]:
        """Create a HuggingFaceModel backed by the TEI image, with a patched deploy()."""
        if not self.image_uri:
            self.image_uri = get_huggingface_llm_image_uri(
                "huggingface-tei", session=self.sagemaker_session
            )

        pysdk_model = HuggingFaceModel(
            image_uri=self.image_uri,
            image_config=self.image_config,
            vpc_config=self.vpc_config,
            env=self.env_vars,
            role=self.role_arn,
            sagemaker_session=self.sagemaker_session,
        )

        logger.info("Detected %s. Proceeding with the deployment.", self.image_uri)

        self._original_deploy = pysdk_model.deploy
        pysdk_model.deploy = self._tei_model_builder_deploy_wrapper
        return pysdk_model

    @_capture_telemetry("tei.deploy")
    def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
        """Normalize mode, role, and instance settings before delegating to deploy()."""
        timeout = kwargs.get("model_data_download_timeout")
        if timeout:
            self.pysdk_model.env.update({"MODEL_LOADING_TIMEOUT": str(timeout)})

        if "mode" in kwargs and kwargs.get("mode") != self.mode:
            overwrite_mode = kwargs.get("mode")
            # mode overwritten by customer during model.deploy()
            logger.warning(
                "Deploying in %s Mode, overriding existing configurations set for %s mode",
                overwrite_mode,
                self.mode,
            )

            if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
                self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
            elif overwrite_mode == Mode.LOCAL_CONTAINER:
                self._prepare_for_mode()
                self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
            else:
                raise ValueError("Mode %s is not supported!" % overwrite_mode)

        serializer = self.schema_builder.input_serializer
        deserializer = self.schema_builder._output_deserializer
        if self.mode == Mode.LOCAL_CONTAINER:
            timeout = kwargs.get("model_data_download_timeout")

            predictor = TgiLocalModePredictor(
                self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
            )

            self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
                self.image_uri,
                timeout if timeout else 1800,
                None,
                predictor,
                self.pysdk_model.env,
                jumpstart=False,
            )

            return predictor

        if "mode" in kwargs:
            del kwargs["mode"]
        if "role" in kwargs:
            self.pysdk_model.role = kwargs.get("role")
            del kwargs["role"]

        # set model_data to uncompressed s3 dict
        self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
        self.env_vars.update(env_vars)
        self.pysdk_model.env.update(self.env_vars)

        if "endpoint_logging" not in kwargs:
            kwargs["endpoint_logging"] = True

        if self.nb_instance_type and "instance_type" not in kwargs:
            kwargs.update({"instance_type": self.nb_instance_type})
        elif not self.nb_instance_type and "instance_type" not in kwargs:
            raise ValueError(
                "Instance type must be provided when deploying to SageMaker Endpoint mode."
            )

        if "initial_instance_count" not in kwargs:
            kwargs.update({"initial_instance_count": 1})

        predictor = self._original_deploy(*args, **kwargs)

        predictor.serializer = serializer
        predictor.deserializer = deserializer
        return predictor

    def _build_for_hf_tei(self):
        """Build the TEI deployable for a HuggingFace Hub model ID."""
        self.nb_instance_type = _get_nb_instance()

        _create_dir_structure(self.model_path)
        if not hasattr(self, "pysdk_model"):
            self.env_vars.update({"HF_MODEL_ID": self.model})
            self.hf_model_config = _get_model_config_properties_from_hf(
                self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
            )

        self.pysdk_model = self._create_tei_model()

        if self.mode == Mode.LOCAL_CONTAINER:
            self._prepare_for_mode()

        return self.pysdk_model

    def _build_for_tei(self):
        """Entry point for building a TEI deployable through ModelBuilder."""
        self.secret_key = None

        self._set_to_tgi()

        self.pysdk_model = self._build_for_hf_tei()
        return self.pysdk_model
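The key move in `_create_tei_model` is saving the model's original `deploy` method and installing `_tei_model_builder_deploy_wrapper` in its place, so the wrapper can fill in defaults before delegating. Below is a minimal, library-free sketch of that monkey-patching pattern; `FakeModel`, `Builder`, and their methods are illustrative stand-ins, not SageMaker APIs.

```python
# Sketch of the deploy-wrapper pattern: keep a handle to the original
# bound deploy method, patch the instance with a wrapper, and have the
# wrapper normalize kwargs (here, defaulting initial_instance_count to 1,
# mirroring the wrapper above) before delegating to the original.

class FakeModel:
    def deploy(self, **kwargs):
        return f"deployed with {kwargs}"

class Builder:
    def wrap(self, model):
        # save the original bound method, then patch the instance attribute
        self._original_deploy = model.deploy
        model.deploy = self._deploy_wrapper
        return model

    def _deploy_wrapper(self, **kwargs):
        # fill in a default only when the caller did not pass one
        kwargs.setdefault("initial_instance_count", 1)
        return self._original_deploy(**kwargs)

model = Builder().wrap(FakeModel())
print(model.deploy(instance_type="ml.g5.2xlarge"))
```

Because `model.deploy` is replaced on the instance (not the class), other instances of the model class keep the unpatched method, which is the same isolation property the builder relies on.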
@@ -0,0 +1,129 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
import tests.integ
from tests.integ.sagemaker.serve.constants import (
    HF_DIR,
    PYTHON_VERSION_IS_NOT_310,
    SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
)

from tests.integ.timeout import timeout
from tests.integ.utils import cleanup_model_resources, gpu_list, retry_with_instance_list
import logging

logger = logging.getLogger(__name__)

sample_input = {
    "inputs": "The man worked as a [MASK].",
}

loaded_response = [
    {
        "score": 0.0974755585193634,
        "token": 10533,
        "token_str": "carpenter",
        "sequence": "the man worked as a carpenter.",
    },
    {
        "score": 0.052383411675691605,
        "token": 15610,
        "token_str": "waiter",
        "sequence": "the man worked as a waiter.",
    },
    {
        "score": 0.04962712526321411,
        "token": 13362,
        "token_str": "barber",
        "sequence": "the man worked as a barber.",
    },
    {
        "score": 0.0378861166536808,
        "token": 15893,
        "token_str": "mechanic",
        "sequence": "the man worked as a mechanic.",
    },
    {
        "score": 0.037680838257074356,
        "token": 18968,
        "token_str": "salesman",
        "sequence": "the man worked as a salesman.",
    },
]


@pytest.fixture
def model_input():
    return {"inputs": "The man worked as a [MASK]."}


@pytest.fixture
def model_builder_model_schema_builder():
    return ModelBuilder(
        model_path=HF_DIR,
        model="bert-base-uncased",
        schema_builder=SchemaBuilder(sample_input, loaded_response),
        model_metadata={
            "HF_TASK": "sentence-similarity",
        },
    )


@pytest.fixture
def model_builder(request):
    return request.getfixturevalue(request.param)


@pytest.mark.skipif(
    PYTHON_VERSION_IS_NOT_310,
    tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS
    and tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS,
    reason="no ml.p2 or ml.p3 instances in this region",
)
@retry_with_instance_list(gpu_list(tests.integ.test_region()))
@pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True)
def test_tei_sagemaker_endpoint(sagemaker_session, model_builder, model_input, **kwargs):
    logger.info("Running in SAGEMAKER_ENDPOINT mode...")
    caught_ex = None

    iam_client = sagemaker_session.boto_session.client("iam")
    role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]

    model = model_builder.build(
        mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session
    )

    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
        try:
            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
            predictor = model.deploy(
                instance_type=kwargs["instance_type"], initial_instance_count=2
            )
            logger.info("Endpoint successfully deployed.")
            predictor.predict(model_input)
            assert predictor is not None
        except Exception as e:
            caught_ex = e
        finally:
            cleanup_model_resources(
                sagemaker_session=model_builder.sagemaker_session,
                model_name=model.name,
                endpoint_name=model.endpoint_name,
            )
            if caught_ex:
                logger.exception(caught_ex)
                assert False, f"{caught_ex} was thrown when running tei sagemaker endpoint test"
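The test is decorated with `retry_with_instance_list(gpu_list(...))`, which re-runs the deployment across candidate GPU instance types, since capacity for a given type can be unavailable in a region. The sketch below illustrates that fallback idea in plain Python; it is an assumption about the pattern, not the repo helper's actual implementation, and `deploy_with_fallback` and `fake_deploy` are hypothetical names.

```python
# Illustrative fallback loop: try each candidate instance type in order,
# return the first successful deployment, and re-raise the last error if
# every candidate fails. Not the actual retry_with_instance_list helper.

def deploy_with_fallback(deploy_fn, instance_types):
    last_err = None
    for instance_type in instance_types:
        try:
            return deploy_fn(instance_type=instance_type)
        except RuntimeError as err:  # e.g. insufficient capacity in region
            last_err = err
    raise last_err

def fake_deploy(instance_type):
    # pretend only ml.g5.2xlarge has capacity in this region
    if instance_type != "ml.g5.2xlarge":
        raise RuntimeError(f"no capacity for {instance_type}")
    return f"endpoint on {instance_type}"

print(deploy_with_fallback(fake_deploy, ["ml.p3.2xlarge", "ml.g5.2xlarge"]))
# endpoint on ml.g5.2xlarge
```

Failures on earlier candidates are swallowed and only the last one is surfaced, which matches the test's pattern of asserting on a single caught exception after cleanup.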