Skip to content

Commit fc8e793

Browse files
author
Balaji Veeramani
committed
Update default_inference_handler.py
1 parent 791985d commit fc8e793

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@
1616
import textwrap
1717

1818
import torch
19-
from sagemaker_inference import content_types, decoder, default_inference_handler, encoder
19+
from sagemaker_inference import (
20+
content_types,
21+
decoder,
22+
default_inference_handler,
23+
encoder,
24+
errors,
25+
utils,
26+
)
2027

2128
INFERENCE_ACCELERATOR_PRESENT_ENV = "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"
2229
DEFAULT_MODEL_FILENAME = "model.pt"
@@ -101,8 +108,12 @@ def default_output_fn(self, prediction, accept):
101108
"""
102109
if type(prediction) == torch.Tensor:
103110
prediction = prediction.detach().cpu().numpy().tolist()
104-
encoded_prediction = encoder.encode(prediction, accept)
105-
if accept == content_types.CSV:
106-
encoded_prediction = encoded_prediction.encode("utf-8")
107111

108-
return encoded_prediction
112+
for content_type in utils.parse_accept(accept):
113+
if content_type in encoder.SUPPORTED_CONTENT_TYPES:
114+
encoded_prediction = encoder.encode(prediction, content_type)
115+
if content_type == content_types.CSV:
116+
encoded_prediction = encoded_prediction.encode("utf-8")
117+
return encoded_prediction
118+
119+
raise errors.UnsupportedFormatError(accept)

test/unit/test_default_inference_handler.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,13 @@ def test_default_output_fn_csv_float(inference_handler):
166166
assert '1.0,2.0,3.0\n4.0,5.0,6.0\n'.encode("utf-8") == output
167167

168168

169+
def test_default_output_fn_multiple_content_types(inference_handler, tensor):
170+
accept = ", ".join(["application/unsupported", content_types.JSON, content_types.CSV])
171+
output = inference_handler.default_output_fn(tensor, accept)
172+
173+
assert json.dumps(tensor.cpu().numpy().tolist()) == output
174+
175+
169176
def test_default_output_fn_bad_accept(inference_handler):
170177
with pytest.raises(errors.UnsupportedFormatError):
171178
inference_handler.default_output_fn("", "application/not_supported")

0 commit comments

Comments (0)