Skip to content

Commit 47d6f21

Browse files
author
Wei Chu
committed
add function wrapper
1 parent 1f1ea24 commit 47d6f21

File tree

1 file changed

+27
-39
lines changed

1 file changed

+27
-39
lines changed

src/sagemaker_pytorch_serving_container/transformer.py

Lines changed: 27 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from sagemaker_inference import content_types, environment, utils
2222
from sagemaker_inference.errors import BaseInferenceToolkitError, GenericInferenceToolkitError
2323

24-
logger = logging.getLogger()
2524

2625
class PTTransformer(Transformer):
2726
"""Represents the execution workflow for handling pytorch inference requests
@@ -46,7 +45,7 @@ def transform(self, data, context):
4645
try:
4746
properties = context.system_properties
4847
model_dir = properties.get("model_dir")
49-
self.validate_and_initialize(model_dir=model_dir, cotext=self._context)
48+
self.validate_and_initialize(model_dir=model_dir, context=self._context)
5049

5150
input_data = data[0].get("body")
5251

@@ -62,14 +61,7 @@ def transform(self, data, context):
6261
if content_type in content_types.UTF8_TYPES:
6362
input_data = input_data.decode("utf-8")
6463

65-
try:
66-
# custom/default handler takes context (for multi-gpu setup)
67-
logger.info('running transform function with context.')
68-
result = self._transform_fn(self._model, input_data, content_type, accept, self._context)
69-
except TypeError:
70-
# custom handler does not take context
71-
logger.info('running transform function without context.')
72-
result = self._transform_fn(self._model, input_data, content_type, accept)
64+
result = self._run_handle_function(self._transform_fn, *(self._model, input_data, content_type, accept))
7365

7466
response = result
7567
response_content_type = accept
@@ -100,22 +92,15 @@ def validate_and_initialize(self, model_dir=environment.model_dir, context=None)
10092
self._context = context
10193
self._environment = environment.Environment()
10294
self._validate_user_module_and_set_functions()
103-
try:
104-
# custom/default model function takes context (for multi-gpu setup)
105-
logger.info('running model functions with context.')
106-
if self._pre_model_fn is not None:
107-
self._pre_model_fn(model_dir, context)
108-
self._model = self._model_fn(model_dir, context)
109-
if self._model_warmup_fn is not None:
110-
self._model_warmup_fn(model_dir, self._model, context)
111-
except TypeError:
112-
# custom model function does not take context
113-
logger.info('running model functions without context.')
114-
if self._pre_model_fn is not None:
115-
self._pre_model_fn(model_dir)
116-
self._model = self._model_fn(model_dir)
117-
if self._model_warmup_fn is not None:
118-
self._model_warmup_fn(model_dir, self._model)
95+
96+
if self._pre_model_fn is not None:
97+
self._run_handle_function(self._pre_model_fn, *(model_dir, ))
98+
99+
self._model = self._run_handle_function(self._model_fn, *(model_dir, ))
100+
101+
if self._model_warmup_fn is not None:
102+
self._run_handle_function(self._model_warmup_fn, *(model_dir, self._model))
103+
119104
self._initialized = True
120105

121106
def _default_transform_fn(self, model, input_data, content_type, accept):
@@ -131,18 +116,21 @@ def _default_transform_fn(self, model, input_data, content_type, accept):
131116
obj: the serialized prediction result or a tuple of the form
132117
(response_data, content_type)
133118
"""
134-
try:
135-
# custom/default handler takes context (for multi-gpu setup)
136-
logger.info('running handler functions with context.')
137-
data = self._input_fn(input_data, content_type, self._context)
138-
prediction = self._predict_fn(data, model, self._context)
139-
result = self._output_fn(prediction, accept, self._context)
140-
except TypeError:
141-
# custom handler does not take context
142-
logger.info('running handler functions without context.')
143-
data = self._input_fn(input_data, content_type)
144-
prediction = self._predict_fn(data, model)
145-
result = self._output_fn(prediction, accept)
119+
data = self._run_handle_function(self._input_fn, *(input_data, content_type))
120+
prediction = self._run_handle_function(self._predict_fn, *(data, model))
121+
result = self._run_handle_function(self._output_fn, *(prediction, accept))
146122

147123
return result
148-
124+
125+
def _run_handle_function(self, func, *argv):
126+
"""Wrapper to call the handle function which covers 2 cases:
127+
1. context passed to the handle function
128+
2. context not passed to the handle function
129+
"""
130+
try:
131+
argv_context = argv + (self._context, )
132+
result = func(*argv_context)
133+
except TypeError:
134+
result = func(*argv)
135+
136+
return result

0 commit comments

Comments
 (0)