21
21
import inspect

from sagemaker_inference import content_types, environment, utils
from sagemaker_inference.errors import BaseInferenceToolkitError, GenericInferenceToolkitError
23
23
24
- logger = logging .getLogger ()
25
24
26
25
class PTTransformer (Transformer ):
27
26
"""Represents the execution workflow for handling pytorch inference requests
@@ -46,7 +45,7 @@ def transform(self, data, context):
46
45
try :
47
46
properties = context .system_properties
48
47
model_dir = properties .get ("model_dir" )
49
- self .validate_and_initialize (model_dir = model_dir , cotext = self ._context )
48
+ self .validate_and_initialize (model_dir = model_dir , context = self ._context )
50
49
51
50
input_data = data [0 ].get ("body" )
52
51
@@ -62,14 +61,7 @@ def transform(self, data, context):
62
61
if content_type in content_types .UTF8_TYPES :
63
62
input_data = input_data .decode ("utf-8" )
64
63
65
- try :
66
- # custom/default handler takes context (for multi-gpu setup)
67
- logger .info ('running transform function with context.' )
68
- result = self ._transform_fn (self ._model , input_data , content_type , accept , self ._context )
69
- except TypeError :
70
- # custom handler does not take context
71
- logger .info ('running transform function without context.' )
72
- result = self ._transform_fn (self ._model , input_data , content_type , accept )
64
+ result = self ._run_handle_function (self ._transform_fn , * (self ._model , input_data , content_type , accept ))
73
65
74
66
response = result
75
67
response_content_type = accept
@@ -100,22 +92,15 @@ def validate_and_initialize(self, model_dir=environment.model_dir, context=None)
100
92
self ._context = context
101
93
self ._environment = environment .Environment ()
102
94
self ._validate_user_module_and_set_functions ()
103
- try :
104
- # custom/default model function takes context (for multi-gpu setup)
105
- logger .info ('running model functions with context.' )
106
- if self ._pre_model_fn is not None :
107
- self ._pre_model_fn (model_dir , context )
108
- self ._model = self ._model_fn (model_dir , context )
109
- if self ._model_warmup_fn is not None :
110
- self ._model_warmup_fn (model_dir , self ._model , context )
111
- except TypeError :
112
- # custom model function does not take context
113
- logger .info ('running model functions without context.' )
114
- if self ._pre_model_fn is not None :
115
- self ._pre_model_fn (model_dir )
116
- self ._model = self ._model_fn (model_dir )
117
- if self ._model_warmup_fn is not None :
118
- self ._model_warmup_fn (model_dir , self ._model )
95
+
96
+ if self ._pre_model_fn is not None :
97
+ self ._run_handle_function (self ._pre_model_fn , * (model_dir , ))
98
+
99
+ self ._model = self ._run_handle_function (self ._model_fn , * (model_dir , ))
100
+
101
+ if self ._model_warmup_fn is not None :
102
+ self ._run_handle_function (self ._model_warmup_fn , * (model_dir , self ._model ))
103
+
119
104
self ._initialized = True
120
105
121
106
def _default_transform_fn(self, model, input_data, content_type, accept):
    """Make a prediction against the model and return a serialized response.

    The request payload is deserialized by the input handler, run through
    the predict handler, and serialized by the output handler.  Each handler
    is invoked via ``_run_handle_function`` so that it may optionally
    receive the request context (used e.g. for multi-GPU setups).

    Args:
        model: the model loaded by the model function.
        input_data: the request payload, serialized in ``content_type``.
        content_type: the request content type.
        accept: the content type the client accepts for the response.

    Returns:
        obj: the serialized prediction result or a tuple of the form
            (response_data, content_type)
    """
    deserialized = self._run_handle_function(self._input_fn, *(input_data, content_type))
    prediction = self._run_handle_function(self._predict_fn, *(deserialized, model))
    return self._run_handle_function(self._output_fn, *(prediction, accept))
148
-
124
+
125
def _run_handle_function(self, func, *argv):
    """Invoke a handle function, passing the request context when the
    function accepts it.

    Handle functions may be written either with a trailing ``context``
    parameter (e.g. for multi-GPU setups) or without one.  The callable's
    signature is inspected to choose the call form up front, so that a
    ``TypeError`` raised *inside* the handler is no longer mistaken for an
    arity mismatch and silently retried with different arguments (which
    could run the handler's side effects twice and mask real bugs).

    Args:
        func (callable): the handle function to run.
        *argv: positional arguments for ``func``, excluding the context.

    Returns:
        obj: whatever ``func`` returns.
    """
    try:
        # Raises TypeError iff func cannot accept the extra context arg.
        inspect.signature(func).bind(*argv, self._context)
    except TypeError:
        # func does not take a context argument.
        return func(*argv)
    except ValueError:
        # Signature not introspectable (e.g. some builtins/C callables);
        # fall back to the original trial-call behavior.
        try:
            return func(*argv, self._context)
        except TypeError:
            return func(*argv)
    return func(*argv, self._context)
0 commit comments