
Commit 4cd65a5

pravali96 authored and pintaoz-aws committed
feat: add pre-processing and post-processing logic to inference_spec (#1560)
* add pre-processing and post-processing logic to inference_spec
* fix format
* make accept_type and content_type optional
* remove accept_type and content_type from pre/post processing
* correct typo
1 parent c1bc587 commit 4cd65a5
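To show what the new hooks are for, here is a minimal sketch (not part of this commit) of a user-defined InferenceSpec that overrides preprocess and postprocess. The class name and its toy logic are hypothetical; only load and invoke come from the existing InferenceSpec interface in src/sagemaker/serve/spec/inference_spec.py.

from sagemaker.serve.spec.inference_spec import InferenceSpec


class EchoSpec(InferenceSpec):
    """Hypothetical spec used only to illustrate the new optional hooks."""

    def load(self, model_dir: str):
        # Return any callable "model"; a trivial echo keeps the sketch self-contained.
        return lambda payload: payload

    def invoke(self, input_object: object, model: object):
        # Run the actual prediction with the loaded model.
        return model(input_object)

    def preprocess(self, input_data: object):
        # Runs in input_fn after deserialization, before predict_fn (toy normalization).
        return str(input_data).strip().lower()

    def postprocess(self, predictions: object):
        # Runs in output_fn before the response is serialized.
        return {"predictions": predictions}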

File tree

3 files changed: +26 −4 lines changed


src/sagemaker/serve/model_server/multi_model_server/inference.py

Lines changed: 10 additions & 2 deletions
@@ -44,13 +44,19 @@ def input_fn(input_data, content_type):
     """Deserializes the bytes that were received from the model server"""
     try:
         if hasattr(schema_builder, "custom_input_translator"):
-            return schema_builder.custom_input_translator.deserialize(
+            deserialized_data = schema_builder.custom_input_translator.deserialize(
                 io.BytesIO(input_data), content_type
             )
         else:
-            return schema_builder.input_deserializer.deserialize(
+            deserialized_data = schema_builder.input_deserializer.deserialize(
                 io.BytesIO(input_data), content_type[0]
             )
+
+        # Check if preprocess method is defined and call it
+        if hasattr(inference_spec, "preprocess"):
+            return inference_spec.preprocess(deserialized_data)
+
+        return deserialized_data
     except Exception as e:
         logger.error("Encountered error: %s in deserialize_response." % e)
         raise Exception("Encountered error in deserialize_request.") from e
@@ -64,6 +70,8 @@ def predict_fn(input_data, predict_callable):
 def output_fn(predictions, accept_type):
     """Prediction is serialized to bytes and sent back to the customer"""
     try:
+        if hasattr(inference_spec, "postprocess"):
+            predictions = inference_spec.postprocess(predictions)
         if hasattr(schema_builder, "custom_output_translator"):
             return schema_builder.custom_output_translator.serialize(predictions, accept_type)
         else:

src/sagemaker/serve/model_server/torchserve/inference.py

Lines changed: 10 additions & 2 deletions
@@ -66,13 +66,19 @@ def input_fn(input_data, content_type):
     """Placeholder docstring"""
     try:
         if hasattr(schema_builder, "custom_input_translator"):
-            return schema_builder.custom_input_translator.deserialize(
+            deserialized_data = schema_builder.custom_input_translator.deserialize(
                 io.BytesIO(input_data), content_type
             )
         else:
-            return schema_builder.input_deserializer.deserialize(
+            deserialized_data = schema_builder.input_deserializer.deserialize(
                 io.BytesIO(input_data), content_type[0]
             )
+
+        # Check if preprocess method is defined and call it
+        if hasattr(inference_spec, "preprocess"):
+            return inference_spec.preprocess(deserialized_data)
+
+        return deserialized_data
     except Exception as e:
         raise Exception("Encountered error in deserialize_request.") from e
 
@@ -85,6 +91,8 @@ def predict_fn(input_data, predict_callable):
 def output_fn(predictions, accept_type):
     """Placeholder docstring"""
     try:
+        if hasattr(inference_spec, "postprocess"):
+            predictions = inference_spec.postprocess(predictions)
         if hasattr(schema_builder, "custom_output_translator"):
             return schema_builder.custom_output_translator.serialize(predictions, accept_type)
         else:

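Both handler changes follow the same pattern: the deserialized payload is routed through inference_spec.preprocess when that hook exists, and predictions are routed through inference_spec.postprocess before serialization. The standalone sketch below restates that flow with hypothetical names (apply_optional_hooks, spec, predict); it is illustrative, not code from the SDK.

def apply_optional_hooks(spec, data, predict):
    # Mirror of the hasattr-guarded dispatch added to input_fn/output_fn.
    if hasattr(spec, "preprocess"):
        data = spec.preprocess(data)                  # new pre-processing hook
    predictions = predict(data)                       # existing predict step
    if hasattr(spec, "postprocess"):
        predictions = spec.postprocess(predictions)   # new post-processing hook
    return predictions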
src/sagemaker/serve/spec/inference_spec.py

Lines changed: 6 additions & 0 deletions
@@ -28,6 +28,12 @@ def invoke(self, input_object: object, model: object):
             model (object): The model object
         """
 
+    def preprocess(self, input_data: object):
+        """Custom pre-processing function"""
+
+    def postprocess(self, predictions: object):
+        """Custom post-processing function"""
+
     def prepare(self, *args, **kwargs):
         """Custom prepare function"""

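For context on how a spec with these hooks would typically reach a model server, the sketch below wires the hypothetical EchoSpec from the earlier sketch into the SDK's ModelBuilder workflow. It assumes the usual ModelBuilder and SchemaBuilder entry points from sagemaker.serve; the sample payloads and deployment settings are placeholders, not values from this commit.

from sagemaker.serve import ModelBuilder, SchemaBuilder

# Sample input/output are placeholders; they only seed the schema used by the handlers.
schema_builder = SchemaBuilder(sample_input="Hello", sample_output={"predictions": "hello"})

model_builder = ModelBuilder(
    inference_spec=EchoSpec(),   # hypothetical spec defined in the earlier sketch
    schema_builder=schema_builder,
)
model = model_builder.build()    # may additionally need a role/session depending on environment
# model.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge")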
0 commit comments