16
16
from mock import Mock , patch
17
17
import pytest
18
18
19
- from sagemaker_inference import environment
19
+ try :
20
+ import http .client as http_client
21
+ except ImportError :
22
+ import httplib as http_client
23
+
24
+ from sagemaker_inference import content_types , environment
20
25
from sagemaker_pytorch_serving_container .default_pytorch_inference_handler import DefaultPytorchInferenceHandler
26
+ from sagemaker_inference .errors import BaseInferenceToolkitError
21
27
from sagemaker_pytorch_serving_container .transformer import PyTorchTransformer
22
28
23
29
24
30
INPUT_DATA = "input_data"
25
31
CONTENT_TYPE = "content_type"
26
32
ACCEPT = "accept"
33
+ DEFAULT_ACCEPT = "default_accept"
27
34
RESULT = "result"
28
35
MODEL = "foo"
29
36
@@ -96,6 +103,81 @@ def test_transform(validate, retrieve_content_type_header, accept_key):
96
103
assert result [0 ] == RESULT
97
104
98
105
106
+ @patch ("sagemaker_inference.utils.retrieve_content_type_header" , return_value = CONTENT_TYPE )
107
+ @patch ("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer.validate_and_initialize" )
108
+ def test_transform_any_accept (validate , retrieve_content_type_header ):
109
+ data = [{"body" : INPUT_DATA }]
110
+ context = Mock ()
111
+ request_processor = Mock ()
112
+ transform_fn = Mock ()
113
+ environment = Mock ()
114
+ environment .default_accept = DEFAULT_ACCEPT
115
+
116
+ context .request_processor = [request_processor ]
117
+ request_processor .get_request_properties .return_value = {"accept" : content_types .ANY }
118
+
119
+ transformer = PyTorchTransformer ()
120
+ transformer ._model = MODEL
121
+ transformer ._transform_fn = transform_fn
122
+ transformer ._environment = environment
123
+ transformer ._context = context
124
+
125
+ transformer .transform (data , context )
126
+
127
+ validate .assert_called_once ()
128
+ transform_fn .assert_called_once_with (MODEL , INPUT_DATA , CONTENT_TYPE , DEFAULT_ACCEPT , context )
129
+
130
+
131
+ @pytest .mark .parametrize ("content_type" , content_types .UTF8_TYPES )
132
+ @patch ("sagemaker_inference.utils.retrieve_content_type_header" )
133
+ @patch ("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer.validate_and_initialize" )
134
+ def test_transform_decode (validate , retrieve_content_type_header , content_type ):
135
+ input_data = Mock ()
136
+ context = Mock ()
137
+ request_processor = Mock ()
138
+ transform_fn = Mock ()
139
+ data = [{"body" : input_data }]
140
+
141
+ input_data .decode .return_value = INPUT_DATA
142
+ context .request_processor = [request_processor ]
143
+ request_processor .get_request_properties .return_value = {"accept" : ACCEPT }
144
+ retrieve_content_type_header .return_value = content_type
145
+
146
+ transformer = PyTorchTransformer ()
147
+ transformer ._model = MODEL
148
+ transformer ._transform_fn = transform_fn
149
+ transformer ._context = context
150
+
151
+ transformer .transform (data , context )
152
+
153
+ input_data .decode .assert_called_once_with ("utf-8" )
154
+ transform_fn .assert_called_once_with (MODEL , INPUT_DATA , content_type , ACCEPT , context )
155
+
156
+
157
+ @patch ("sagemaker_inference.utils.retrieve_content_type_header" , return_value = CONTENT_TYPE )
158
+ @patch ("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer.validate_and_initialize" )
159
+ def test_transform_tuple (validate , retrieve_content_type_header ):
160
+ data = [{"body" : INPUT_DATA }]
161
+ context = Mock ()
162
+ request_processor = Mock ()
163
+ transform_fn = Mock (return_value = (RESULT , ACCEPT ))
164
+
165
+ context .request_processor = [request_processor ]
166
+ request_processor .get_request_properties .return_value = {"accept" : ACCEPT }
167
+
168
+ transformer = PyTorchTransformer ()
169
+ transformer ._model = MODEL
170
+ transformer ._transform_fn = transform_fn
171
+ transformer ._context = context
172
+
173
+ result = transformer .transform (data , context )
174
+
175
+ transform_fn .assert_called_once_with (MODEL , INPUT_DATA , CONTENT_TYPE , ACCEPT , context )
176
+ context .set_response_content_type .assert_called_once_with (0 , transform_fn ()[1 ])
177
+ assert isinstance (result , list )
178
+ assert result [0 ] == transform_fn ()[0 ]
179
+
180
+
99
181
@patch ("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer._validate_user_module_and_set_functions" )
100
182
@patch ("sagemaker_inference.environment.Environment" )
101
183
def test_validate_and_initialize (env , validate_user_module ):
@@ -120,6 +202,74 @@ def test_validate_and_initialize(env, validate_user_module):
120
202
validate_user_module .assert_called_once_with ()
121
203
122
204
205
+ @patch ("sagemaker_inference.transformer.Transformer._validate_user_module_and_set_functions" )
206
+ @patch ("sagemaker_inference.environment.Environment" )
207
+ def test_handle_validate_and_initialize_error (env , validate_user_module ):
208
+ data = [{"body" : INPUT_DATA }]
209
+ request_processor = Mock ()
210
+
211
+ context = Mock ()
212
+ context .request_processor = [request_processor ]
213
+
214
+ transform_fn = Mock ()
215
+ model_fn = Mock ()
216
+
217
+ transformer = PyTorchTransformer ()
218
+
219
+ transformer ._model = MODEL
220
+ transformer ._transform_fn = transform_fn
221
+ transformer ._model_fn = model_fn
222
+ transformer ._context = context
223
+
224
+ test_error_message = "Foo"
225
+ validate_user_module .side_effect = ValueError (test_error_message )
226
+
227
+ assert transformer ._initialized is False
228
+
229
+ response = transformer .transform (data , context )
230
+ assert test_error_message in str (response )
231
+ assert "Traceback (most recent call last)" in str (response )
232
+ context .set_response_status .assert_called_with (
233
+ code = http_client .INTERNAL_SERVER_ERROR , phrase = test_error_message
234
+ )
235
+
236
+
237
+ @patch ("sagemaker_inference.transformer.Transformer._validate_user_module_and_set_functions" )
238
+ @patch ("sagemaker_inference.environment.Environment" )
239
+ def test_handle_validate_and_initialize_user_error (env , validate_user_module ):
240
+ test_status_code = http_client .FORBIDDEN
241
+ test_error_message = "Foo"
242
+
243
+ class FooUserError (BaseInferenceToolkitError ):
244
+ def __init__ (self , status_code , message ):
245
+ self .status_code = status_code
246
+ self .message = message
247
+ self .phrase = "Foo"
248
+
249
+ data = [{"body" : INPUT_DATA }]
250
+ context = Mock ()
251
+ transform_fn = Mock ()
252
+ model_fn = Mock ()
253
+
254
+ transformer = PyTorchTransformer ()
255
+
256
+ transformer ._model = MODEL
257
+ transformer ._transform_fn = transform_fn
258
+ transformer ._model_fn = model_fn
259
+ transformer ._context = context
260
+
261
+ validate_user_module .side_effect = FooUserError (test_status_code , test_error_message )
262
+
263
+ assert transformer ._initialized is False
264
+
265
+ response = transformer .transform (data , context )
266
+ assert test_error_message in str (response )
267
+ assert "Traceback (most recent call last)" in str (response )
268
+ context .set_response_status .assert_called_with (
269
+ code = http_client .FORBIDDEN , phrase = test_error_message
270
+ )
271
+
272
+
123
273
def test_default_transform_fn ():
124
274
transformer = PyTorchTransformer ()
125
275
context = Mock ()
0 commit comments