From fc8e793960f16bd3449b42e5cec44f1211ddacc3 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Wed, 5 Aug 2020 14:53:57 -0500 Subject: [PATCH] Update default_inference_handler.py --- .../default_pytorch_inference_handler.py | 21 ++++++++++++++----- test/unit/test_default_inference_handler.py | 7 +++++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py b/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py index f2533709..92857434 100644 --- a/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py +++ b/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py @@ -16,7 +16,14 @@ import textwrap import torch -from sagemaker_inference import content_types, decoder, default_inference_handler, encoder +from sagemaker_inference import ( + content_types, + decoder, + default_inference_handler, + encoder, + errors, + utils, +) INFERENCE_ACCELERATOR_PRESENT_ENV = "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT" DEFAULT_MODEL_FILENAME = "model.pt" @@ -101,8 +108,12 @@ def default_output_fn(self, prediction, accept): """ if type(prediction) == torch.Tensor: prediction = prediction.detach().cpu().numpy().tolist() - encoded_prediction = encoder.encode(prediction, accept) - if accept == content_types.CSV: - encoded_prediction = encoded_prediction.encode("utf-8") - return encoded_prediction + for content_type in utils.parse_accept(accept): + if content_type in encoder.SUPPORTED_CONTENT_TYPES: + encoded_prediction = encoder.encode(prediction, content_type) + if content_type == content_types.CSV: + encoded_prediction = encoded_prediction.encode("utf-8") + return encoded_prediction + + raise errors.UnsupportedFormatError(accept) diff --git a/test/unit/test_default_inference_handler.py b/test/unit/test_default_inference_handler.py index b8a37449..f45fadde 100644 --- a/test/unit/test_default_inference_handler.py +++ b/test/unit/test_default_inference_handler.py @@ -166,6 +166,13 @@ def test_default_output_fn_csv_float(inference_handler): assert '1.0,2.0,3.0\n4.0,5.0,6.0\n'.encode("utf-8") == output +def test_default_output_fn_multiple_content_types(inference_handler, tensor): + accept = ", ".join(["application/unsupported", content_types.JSON, content_types.CSV]) + output = inference_handler.default_output_fn(tensor, accept) + + assert json.dumps(tensor.cpu().numpy().tolist()) == output + + def test_default_output_fn_bad_accept(inference_handler): with pytest.raises(errors.UnsupportedFormatError): inference_handler.default_output_fn("", "application/not_supported")