
change: Add support for Accept headers with multiple MIME types #80

Merged · merged 3 commits on Aug 11, 2020
@@ -16,7 +16,14 @@
 import textwrap
 
 import torch
-from sagemaker_inference import content_types, decoder, default_inference_handler, encoder
+from sagemaker_inference import (
+    content_types,
+    decoder,
+    default_inference_handler,
+    encoder,
+    errors,
+    utils,
+)
 
 INFERENCE_ACCELERATOR_PRESENT_ENV = "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"
 DEFAULT_MODEL_FILENAME = "model.pt"

@@ -101,8 +108,12 @@ def default_output_fn(self, prediction, accept):
         """
         if type(prediction) == torch.Tensor:
             prediction = prediction.detach().cpu().numpy().tolist()
-        encoded_prediction = encoder.encode(prediction, accept)
-        if accept == content_types.CSV:
-            encoded_prediction = encoded_prediction.encode("utf-8")
 
-        return encoded_prediction
+        for content_type in utils.parse_accept(accept):
+            if content_type in encoder.SUPPORTED_CONTENT_TYPES:
+                encoded_prediction = encoder.encode(prediction, content_type)
+                if content_type == content_types.CSV:
+                    encoded_prediction = encoded_prediction.encode("utf-8")
+                return encoded_prediction
+
+        raise errors.UnsupportedFormatError(accept)

@bveeramani (Contributor, Author) commented on lines +115 to +116, Aug 5, 2020:

I don't know if the encode("utf-8") call is actually necessary. If the prediction is encoded as JSON, then a string is returned, so I figure the server must automatically encode string responses as UTF-8.

In any case, the encode("utf-8") call was there before, so I've decided to keep it just to be safe.
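
For reference, the negotiation the new loop performs can be illustrated standalone. The sketch below is a simplification, not the toolkit's implementation: parse_accept and SUPPORTED_CONTENT_TYPES here are assumed stand-ins for sagemaker_inference.utils.parse_accept and encoder.SUPPORTED_CONTENT_TYPES, and the encoding bodies are reduced to their essentials. It shows the first supported MIME type in the Accept header winning, and why the JSON branch yields a str while the CSV branch yields UTF-8 bytes (the point raised in the comment above).

    import json

    # Assumed stand-ins for encoder.SUPPORTED_CONTENT_TYPES and
    # sagemaker_inference.utils.parse_accept; real implementations may differ.
    SUPPORTED_CONTENT_TYPES = {"application/json", "text/csv"}

    def parse_accept(accept):
        # Assumed behavior: split "a/b, c/d" into its individual MIME types.
        return accept.replace(" ", "").split(",")

    def encode(prediction, accept):
        # The first supported type listed in the Accept header wins;
        # unsupported types such as "application/unsupported" are skipped.
        for content_type in parse_accept(accept):
            if content_type in SUPPORTED_CONTENT_TYPES:
                if content_type == "application/json":
                    return json.dumps(prediction)        # str
                rows = (",".join(str(x) for x in row) for row in prediction)
                return "\n".join(rows).encode("utf-8")   # bytes
        raise ValueError("Unsupported accept header: " + accept)

    print(encode([[1, 2], [3, 4]], "application/unsupported, application/json"))
    # -> the JSON string '[[1, 2], [3, 4]]', because JSON is the first
    #    supported type in the header.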
test/unit/test_default_inference_handler.py (7 additions, 0 deletions)

@@ -166,6 +166,13 @@ def test_default_output_fn_csv_float(inference_handler):
     assert '1.0,2.0,3.0\n4.0,5.0,6.0\n'.encode("utf-8") == output
 
 
+def test_default_output_fn_multiple_content_types(inference_handler, tensor):
+    accept = ", ".join(["application/unsupported", content_types.JSON, content_types.CSV])
+    output = inference_handler.default_output_fn(tensor, accept)
+
+    assert json.dumps(tensor.cpu().numpy().tolist()) == output
+
+
 def test_default_output_fn_bad_accept(inference_handler):
     with pytest.raises(errors.UnsupportedFormatError):
         inference_handler.default_output_fn("", "application/not_supported")
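
Note that the new test lists the unsupported type first, which exercises the fall-through path: application/unsupported is skipped, and the JSON encoding of the tensor is returned because content_types.JSON is the first supported entry in the header. Assuming the repository's standard pytest setup, the test can be run in isolation with something like:

    pytest test/unit/test_default_inference_handler.py -k multiple_content_types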