diff --git a/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py b/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py
index 313f956f..e1f1dce4 100644
--- a/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py
+++ b/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py
@@ -43,12 +43,13 @@ def _is_model_file(filename):
             is_model_file = ext in [".pt", ".pth"]
         return is_model_file
 
-    def default_model_fn(self, model_dir):
+    def default_model_fn(self, model_dir, context=None):
         """Loads a model. For PyTorch, a default function to load a model only if Elastic Inference
         is used. In other cases, users should provide customized model_fn() in script.
 
         Args:
             model_dir: a directory where model is saved.
+            context: context for the request.
 
         Returns: A PyTorch model.
         """
@@ -65,7 +66,12 @@ def default_model_fn(self, model_dir):
                     "Failed to load {}. Please ensure model is saved using torchscript.".format(model_path)
                 ) from e
         else:
-            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            if context:
+                properties = context.system_properties
+                device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
+            else:
+                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
             model_path = os.path.join(model_dir, DEFAULT_MODEL_FILENAME)
             if not os.path.exists(model_path):
                 model_files = [file for file in os.listdir(model_dir) if self._is_model_file(file)]
@@ -83,29 +89,35 @@ def default_model_fn(self, model_dir):
             model = model.to(device)
         return model
 
-    def default_input_fn(self, input_data, content_type):
+    def default_input_fn(self, input_data, content_type, context=None):
         """A default input_fn that can handle JSON, CSV and NPZ formats.
 
         Args:
             input_data: the request payload serialized in the content_type format
             content_type: the request content_type
+            context: context for the request
 
         Returns: input_data deserialized into torch.FloatTensor or torch.cuda.FloatTensor,
             depending if cuda is available.
         """
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if context:
+            properties = context.system_properties
+            device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
+        else:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         np_array = decoder.decode(input_data, content_type)
         tensor = torch.FloatTensor(
             np_array) if content_type in content_types.UTF8_TYPES else torch.from_numpy(np_array)
         return tensor.to(device)
 
-    def default_predict_fn(self, data, model):
+    def default_predict_fn(self, data, model, context=None):
         """A default predict_fn for PyTorch. Calls a model on data deserialized in input_fn.
         Runs prediction on GPU if cuda is available.
 
         Args:
             data: input data (torch.Tensor) for prediction deserialized by input_fn
             model: PyTorch model loaded in memory by model_fn
+            context: context for the request
 
         Returns: a prediction
         """
@@ -118,7 +130,12 @@ def default_predict_fn(self, data, model):
                 with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
                     output = model(input_data)
             else:
-                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+                if context:
+                    properties = context.system_properties
+                    device = torch.device("cuda:" + str(properties.get("gpu_id"))
+                                          if torch.cuda.is_available() else "cpu")
+                else:
+                    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                 model = model.to(device)
                 input_data = data.to(device)
                 model.eval()
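
Note: the device-selection branch added above is the same in all three default handlers. Below is a minimal standalone sketch of its behaviour; FakeContext and its property values are illustrative stand-ins for the context object the model server passes in, not part of this change.

import torch

class FakeContext:
    # stand-in for the model server context; only system_properties is read here
    system_properties = {"model_dir": "/opt/ml/model", "gpu_id": 0}

context = FakeContext()
if context:
    properties = context.system_properties
    # bind this worker to the GPU the server assigned it, e.g. "cuda:0"
    device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)  # cuda:<gpu_id> on a GPU host, cpu otherwise
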
diff --git a/src/sagemaker_pytorch_serving_container/handler_service.py b/src/sagemaker_pytorch_serving_container/handler_service.py
index 07e81c7b..6cce4ca8 100644
--- a/src/sagemaker_pytorch_serving_container/handler_service.py
+++ b/src/sagemaker_pytorch_serving_container/handler_service.py
@@ -13,32 +13,30 @@
 from __future__ import absolute_import
 
 from sagemaker_inference.default_handler_service import DefaultHandlerService
-from sagemaker_inference.transformer import Transformer
 
-from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler
+from sagemaker_pytorch_serving_container.transformer import PyTorchTransformer
 
 import os
 import sys
 
+PYTHON_PATH_ENV = "PYTHONPATH"
 ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"
 
 
 class HandlerService(DefaultHandlerService):
+    """
+    Handler service that is executed by the model server.
 
-    """Handler service that is executed by the model server.
-
-    Determines specific default inference handlers to use based on the type MXNet model being used.
+    Determines specific default inference handlers to use based on the type of PyTorch model being used.
 
     This class extends ``DefaultHandlerService``, which define the following:
         - The ``handle`` method is invoked for all incoming inference requests to the model server.
         - The ``initialize`` method is invoked at model server start up.
-
-    Based on: https://github.com/awslabs/mxnet-model-server/blob/master/docs/custom_service.md
-
     """
+
 
     def __init__(self):
         self._initialized = False
 
-        transformer = Transformer(default_inference_handler=DefaultPytorchInferenceHandler())
+        transformer = PyTorchTransformer()
         super(HandlerService, self).__init__(transformer=transformer)
@@ -48,4 +46,14 @@ def initialize(self, context):
             sys.path.append(code_dir)
             self._initialized = True
 
-        super().initialize(context)
+        properties = context.system_properties
+        model_dir = properties.get("model_dir")
+
+        # add model_dir/code to python path
+        code_dir_path = "{}:".format(model_dir + "/code")
+        if PYTHON_PATH_ENV in os.environ:
+            os.environ[PYTHON_PATH_ENV] = code_dir_path + os.environ[PYTHON_PATH_ENV]
+        else:
+            os.environ[PYTHON_PATH_ENV] = code_dir_path
+
+        self._service.validate_and_initialize(model_dir=model_dir, context=context)
diff --git a/src/sagemaker_pytorch_serving_container/transformer.py b/src/sagemaker_pytorch_serving_container/transformer.py
new file mode 100644
index 00000000..6da834c5
--- /dev/null
+++ b/src/sagemaker_pytorch_serving_container/transformer.py
@@ -0,0 +1,141 @@
+# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+
+from __future__ import absolute_import
+
+import traceback
+
+from six.moves import http_client
+from sagemaker_inference.transformer import Transformer
+from sagemaker_inference import content_types, environment, utils
+from sagemaker_inference.errors import BaseInferenceToolkitError, GenericInferenceToolkitError
+from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler
+
+
+class PyTorchTransformer(Transformer):
+    """Represents the execution workflow for handling pytorch inference requests
+    sent to the model server.
+    """
+    def __init__(self, default_inference_handler=DefaultPytorchInferenceHandler()):
+        super().__init__(default_inference_handler)
+        self._context = None
+
+    def transform(self, data, context):
+        """Take a request with input data, deserialize it, make a prediction, and return a
+        serialized response.
+        Args:
+            data (obj): the request data.
+            context (obj): metadata on the incoming request data.
+        Returns:
+            list[obj]: The serialized prediction result wrapped in a list if
+                inference is successful. Otherwise returns an error message
+                with the context set appropriately.
+        """
+
+        try:
+            properties = context.system_properties
+            model_dir = properties.get("model_dir")
+            self.validate_and_initialize(model_dir=model_dir, context=context)
+
+            response_list = []
+            for i in range(len(data)):
+                input_data = data[i].get("body")
+
+                request_processor = context.request_processor[0]
+
+                request_property = request_processor.get_request_properties()
+                content_type = utils.retrieve_content_type_header(request_property)
+                accept = request_property.get("Accept") or request_property.get("accept")
+
+                if not accept or accept == content_types.ANY:
+                    accept = self._environment.default_accept
+
+                if content_type in content_types.UTF8_TYPES:
+                    input_data = input_data.decode("utf-8")
+
+                result = self._run_handle_function(self._transform_fn, *(self._model, input_data, content_type, accept))
+
+                response = result
+                response_content_type = accept
+
+                if isinstance(result, tuple):
+                    # handles tuple for backwards compatibility
+                    response = result[0]
+                    response_content_type = result[1]
+
+                context.set_response_content_type(0, response_content_type)
+
+                response_list.append(response)
+
+            return response_list
+        except Exception as e:  # pylint: disable=broad-except
+            trace = traceback.format_exc()
+            if isinstance(e, BaseInferenceToolkitError):
+                return super().handle_error(context, e, trace)
+            else:
+                return super().handle_error(
+                    context,
+                    GenericInferenceToolkitError(http_client.INTERNAL_SERVER_ERROR, str(e)),
+                    trace,
+                )
+
+    def validate_and_initialize(self, model_dir=environment.model_dir, context=None):
+        """Validates the user module against the SageMaker inference contract.
+        Load the model as defined by the ``model_fn`` to prepare handling predictions.
+        """
+        if not self._initialized:
+            self._context = context
+            self._environment = environment.Environment()
+            self._validate_user_module_and_set_functions()
+
+            if self._pre_model_fn is not None:
+                self._run_handle_function(self._pre_model_fn, *(model_dir, ))
+
+            self._model = self._run_handle_function(self._model_fn, *(model_dir, ))
+
+            if self._model_warmup_fn is not None:
+                self._run_handle_function(self._model_warmup_fn, *(model_dir, self._model))
+
+            self._initialized = True
+
+    def _default_transform_fn(self, model, input_data, content_type, accept):
+        """Make predictions against the model and return a serialized response.
+        This serves as the default implementation of transform_fn, used when the
+        user has not provided an implementation.
+        Args:
+            model (obj): model loaded by model_fn.
+            input_data (obj): the request data.
+            content_type (str): the request content type.
+            accept (str): accept header expected by the client.
+        Returns:
+            obj: the serialized prediction result or a tuple of the form
+                (response_data, content_type)
+        """
+        data = self._run_handle_function(self._input_fn, *(input_data, content_type))
+        prediction = self._run_handle_function(self._predict_fn, *(data, model))
+        result = self._run_handle_function(self._output_fn, *(prediction, accept))
+
+        return result
+
+    def _run_handle_function(self, func, *argv):
+        """Wrapper to call the handle function which covers 2 cases:
+        1. context passed to the handle function
+        2. context not passed to the handle function
+        """
+        try:
+            argv_context = argv + (self._context, )
+            result = func(*argv_context)
+        except TypeError:
+            result = func(*argv)
+
+        return result
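
Note: the TypeError fallback in _run_handle_function above is what keeps handlers with the old signatures working. A standalone sketch of that dispatch rule follows; old_style_input_fn is a hypothetical two-argument user handler, not part of this change.

def call_with_optional_context(func, context, *argv):
    # mirrors _run_handle_function: try with the context appended, retry without it on TypeError
    try:
        return func(*(argv + (context,)))
    except TypeError:
        return func(*argv)

def old_style_input_fn(input_data, content_type):
    # hypothetical handler that predates the context argument
    return input_data

print(call_with_optional_context(old_style_input_fn, object(), "x", "application/json"))  # prints x
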
diff --git a/test/unit/test_handler_service.py b/test/unit/test_handler_service.py
index fd3dfc60..8be732ec 100644
--- a/test/unit/test_handler_service.py
+++ b/test/unit/test_handler_service.py
@@ -16,18 +16,18 @@
 
 @patch('sagemaker_pytorch_serving_container.default_pytorch_inference_handler.DefaultPytorchInferenceHandler')
-@patch('sagemaker_inference.transformer.Transformer')
-def test_hosting_start(Transformer, DefaultPytorchInferenceHandler):
+@patch('sagemaker_pytorch_serving_container.transformer.PyTorchTransformer')
+def test_hosting_start(PyTorchTransformer, DefaultPytorchInferenceHandler):
     from sagemaker_pytorch_serving_container import handler_service
 
     handler_service.HandlerService()
 
-    Transformer.assert_called_with(default_inference_handler=DefaultPytorchInferenceHandler())
+    PyTorchTransformer.assert_called_with()
 
 
 @patch('sagemaker_pytorch_serving_container.default_pytorch_inference_handler.DefaultPytorchInferenceHandler')
-@patch('sagemaker_inference.transformer.Transformer')
-def test_hosting_start_enable_multi_model(Transformer, DefaultPytorchInferenceHandler):
+@patch('sagemaker_pytorch_serving_container.transformer.PyTorchTransformer')
+def test_hosting_start_enable_multi_model(PyTorchTransformer, DefaultPytorchInferenceHandler):
     from sagemaker_pytorch_serving_container import handler_service
 
     context = Mock()
This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from __future__ import absolute_import + +from mock import Mock, patch +import pytest + +try: + import http.client as http_client +except ImportError: + import httplib as http_client + +from sagemaker_inference import content_types, environment +from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler +from sagemaker_inference.errors import BaseInferenceToolkitError +from sagemaker_pytorch_serving_container.transformer import PyTorchTransformer + + +INPUT_DATA = "input_data" +CONTENT_TYPE = "content_type" +ACCEPT = "accept" +DEFAULT_ACCEPT = "default_accept" +RESULT = "result" +MODEL = "foo" + +PREPROCESSED_DATA = "preprocessed_data" +PREDICT_RESULT = "prediction_result" +PROCESSED_RESULT = "processed_result" + + +def test_default_transformer(): + transformer = PyTorchTransformer() + + assert isinstance(transformer._default_inference_handler, DefaultPytorchInferenceHandler) + assert transformer._initialized is False + assert transformer._environment is None + assert transformer._pre_model_fn is None + assert transformer._model_warmup_fn is None + assert transformer._model is None + assert transformer._model_fn is None + assert transformer._transform_fn is None + assert transformer._input_fn is None + assert transformer._predict_fn is None + assert transformer._output_fn is None + assert transformer._context is None + + +def test_transformer_with_custom_default_inference_handler(): + default_inference_handler = Mock() + + transformer = PyTorchTransformer(default_inference_handler) + + assert transformer._default_inference_handler == default_inference_handler + assert transformer._initialized is False + assert transformer._environment is None + assert transformer._pre_model_fn is None + assert transformer._model_warmup_fn is None + assert transformer._model is None + assert transformer._model_fn is None + assert transformer._transform_fn is None + assert transformer._input_fn is None + assert transformer._predict_fn is None + assert transformer._output_fn is None + assert transformer._context is None + + +@pytest.mark.parametrize("accept_key", ["Accept", "accept"]) +@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE) +@patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer.validate_and_initialize") +def test_transform(validate, retrieve_content_type_header, accept_key): + data = [{"body": INPUT_DATA}] + context = Mock() + request_processor = Mock() + transform_fn = Mock(return_value=RESULT) + + context.request_processor = [request_processor] + request_property = {accept_key: ACCEPT} + request_processor.get_request_properties.return_value = request_property + + transformer = PyTorchTransformer() + transformer._model = MODEL + transformer._transform_fn = transform_fn + transformer._context = context + + result = transformer.transform(data, context) + + validate.assert_called_once() + retrieve_content_type_header.assert_called_once_with(request_property) + transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, context) + context.set_response_content_type.assert_called_once_with(0, ACCEPT) + assert isinstance(result, list) + assert result[0] == RESULT + + +@patch("sagemaker_inference.utils.retrieve_content_type_header", 
+
+
+@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE)
+@patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer.validate_and_initialize")
+def test_transform_any_accept(validate, retrieve_content_type_header):
+    data = [{"body": INPUT_DATA}]
+    context = Mock()
+    request_processor = Mock()
+    transform_fn = Mock()
+    environment = Mock()
+    environment.default_accept = DEFAULT_ACCEPT
+
+    context.request_processor = [request_processor]
+    request_processor.get_request_properties.return_value = {"accept": content_types.ANY}
+
+    transformer = PyTorchTransformer()
+    transformer._model = MODEL
+    transformer._transform_fn = transform_fn
+    transformer._environment = environment
+    transformer._context = context
+
+    transformer.transform(data, context)
+
+    validate.assert_called_once()
+    transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, DEFAULT_ACCEPT, context)
+
+
+@pytest.mark.parametrize("content_type", content_types.UTF8_TYPES)
+@patch("sagemaker_inference.utils.retrieve_content_type_header")
+@patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer.validate_and_initialize")
+def test_transform_decode(validate, retrieve_content_type_header, content_type):
+    input_data = Mock()
+    context = Mock()
+    request_processor = Mock()
+    transform_fn = Mock()
+    data = [{"body": input_data}]
+
+    input_data.decode.return_value = INPUT_DATA
+    context.request_processor = [request_processor]
+    request_processor.get_request_properties.return_value = {"accept": ACCEPT}
+    retrieve_content_type_header.return_value = content_type
+
+    transformer = PyTorchTransformer()
+    transformer._model = MODEL
+    transformer._transform_fn = transform_fn
+    transformer._context = context
+
+    transformer.transform(data, context)
+
+    input_data.decode.assert_called_once_with("utf-8")
+    transform_fn.assert_called_once_with(MODEL, INPUT_DATA, content_type, ACCEPT, context)
+
+
+@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE)
+@patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer.validate_and_initialize")
+def test_transform_tuple(validate, retrieve_content_type_header):
+    data = [{"body": INPUT_DATA}]
+    context = Mock()
+    request_processor = Mock()
+    transform_fn = Mock(return_value=(RESULT, ACCEPT))
+
+    context.request_processor = [request_processor]
+    request_processor.get_request_properties.return_value = {"accept": ACCEPT}
+
+    transformer = PyTorchTransformer()
+    transformer._model = MODEL
+    transformer._transform_fn = transform_fn
+    transformer._context = context
+
+    result = transformer.transform(data, context)
+
+    transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, context)
+    context.set_response_content_type.assert_called_once_with(0, transform_fn()[1])
+    assert isinstance(result, list)
+    assert result[0] == transform_fn()[0]
+
+
+@patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer._validate_user_module_and_set_functions")
+@patch("sagemaker_inference.environment.Environment")
+def test_validate_and_initialize(env, validate_user_module):
+    transformer = PyTorchTransformer()
+
+    model_fn = Mock()
+    context = Mock()
+    transformer._model_fn = model_fn
+
+    assert transformer._initialized is False
+    assert transformer._context is None
+
+    transformer.validate_and_initialize(context=context)
+
+    assert transformer._initialized is True
+    assert transformer._context == context
+
+    transformer.validate_and_initialize()
+
+    model_fn.assert_called_once_with(environment.model_dir, context)
+    env.assert_called_once_with()
+    validate_user_module.assert_called_once_with()
+
+
+@patch("sagemaker_inference.transformer.Transformer._validate_user_module_and_set_functions")
+@patch("sagemaker_inference.environment.Environment")
+def test_handle_validate_and_initialize_error(env, validate_user_module):
+    data = [{"body": INPUT_DATA}]
+    request_processor = Mock()
+
+    context = Mock()
+    context.request_processor = [request_processor]
+
+    transform_fn = Mock()
+    model_fn = Mock()
+
+    transformer = PyTorchTransformer()
+
+    transformer._model = MODEL
+    transformer._transform_fn = transform_fn
+    transformer._model_fn = model_fn
+    transformer._context = context
+
+    test_error_message = "Foo"
+    validate_user_module.side_effect = ValueError(test_error_message)
+
+    assert transformer._initialized is False
+
+    response = transformer.transform(data, context)
+    assert test_error_message in str(response)
+    assert "Traceback (most recent call last)" in str(response)
+    context.set_response_status.assert_called_with(
+        code=http_client.INTERNAL_SERVER_ERROR, phrase=test_error_message
+    )
+
+
+@patch("sagemaker_inference.transformer.Transformer._validate_user_module_and_set_functions")
+@patch("sagemaker_inference.environment.Environment")
+def test_handle_validate_and_initialize_user_error(env, validate_user_module):
+    test_status_code = http_client.FORBIDDEN
+    test_error_message = "Foo"
+
+    class FooUserError(BaseInferenceToolkitError):
+        def __init__(self, status_code, message):
+            self.status_code = status_code
+            self.message = message
+            self.phrase = "Foo"
+
+    data = [{"body": INPUT_DATA}]
+    context = Mock()
+    transform_fn = Mock()
+    model_fn = Mock()
+
+    transformer = PyTorchTransformer()
+
+    transformer._model = MODEL
+    transformer._transform_fn = transform_fn
+    transformer._model_fn = model_fn
+    transformer._context = context
+
+    validate_user_module.side_effect = FooUserError(test_status_code, test_error_message)
+
+    assert transformer._initialized is False
+
+    response = transformer.transform(data, context)
+    assert test_error_message in str(response)
+    assert "Traceback (most recent call last)" in str(response)
+    context.set_response_status.assert_called_with(
+        code=http_client.FORBIDDEN, phrase=test_error_message
+    )
+
+
+def test_default_transform_fn():
+    transformer = PyTorchTransformer()
+    context = Mock()
+    transformer._context = context
+
+    input_fn = Mock(return_value=PREPROCESSED_DATA)
+    predict_fn = Mock(return_value=PREDICT_RESULT)
+    output_fn = Mock(return_value=PROCESSED_RESULT)
+
+    transformer._input_fn = input_fn
+    transformer._predict_fn = predict_fn
+    transformer._output_fn = output_fn
+
+    result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT)
+
+    input_fn.assert_called_once_with(INPUT_DATA, CONTENT_TYPE, context)
+    predict_fn.assert_called_once_with(PREPROCESSED_DATA, MODEL, context)
+    output_fn.assert_called_once_with(PREDICT_RESULT, ACCEPT, context)
+    assert result == PROCESSED_RESULT
+
+
+def test_run_handle_function():
+    def three_inputs_func(a, b, c):
+        pass
+
+    three_inputs_mock = Mock(spec=three_inputs_func)
+    a = Mock()
+    b = Mock()
+    context = Mock()
+
+    transformer = PyTorchTransformer()
+    transformer._context = context
+    transformer._run_handle_function(three_inputs_mock, a, b)
+    three_inputs_mock.assert_called_with(a, b, context)
diff --git a/test/utils/file_utils.py b/test/utils/file_utils.py
index 8cc3771d..327c8ecc 100644
--- a/test/utils/file_utils.py
+++ b/test/utils/file_utils.py
@@ -19,7 +19,7 @@
 def make_tarfile(script, model, output_path, filename="model.tar.gz", script_path=None):
     output_filename = os.path.join(output_path, filename)
     with tarfile.open(output_filename, "w:gz") as tar:
-        if(script_path):
+        if (script_path):
             tar.add(script, arcname=os.path.join(script_path, os.path.basename(script)))
         else:
             tar.add(script, arcname=os.path.basename(script))
diff --git a/tox.ini b/tox.ini
index 3ee46f4e..ee9acbc0 100644
--- a/tox.ini
+++ b/tox.ini
@@ -70,6 +70,7 @@ deps =
     six
    future
    pyyaml
+    protobuf <= 3.20.1 #https://exerror.com/typeerror-descriptors-cannot-not-be-created-directly/
 
 [testenv:flake8]
 basepython = python3
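
Note: taken together, these changes let a user inference script opt in to receiving the model server context: each handler may accept it as an optional trailing argument, while handlers that keep the old signatures are still called without it. A hypothetical inference.py sketch under that assumption follows; the model filename and the bare-bones predict_fn are illustrative, not prescribed by this change.

import os
import torch

def model_fn(model_dir, context=None):
    # pick the GPU the server assigned to this worker when a context is provided
    if context is not None and torch.cuda.is_available():
        device = torch.device("cuda:" + str(context.system_properties.get("gpu_id")))
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # "model.pt" is an illustrative filename
    return torch.jit.load(os.path.join(model_dir, "model.pt"), map_location=device)

def predict_fn(data, model, context=None):
    # two-argument predict_fn implementations keep working unchanged
    with torch.no_grad():
        return model(data)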