Commit 0d293f4

Author: Wei Chu
Commit message: add test
1 parent 2d6ce94 commit 0d293f4

File tree

1 file changed: 151 additions, 1 deletion

test/unit/test_transformer.py

Lines changed: 151 additions & 1 deletion
@@ -16,14 +16,21 @@
 from mock import Mock, patch
 import pytest
 
-from sagemaker_inference import environment
+try:
+    import http.client as http_client
+except ImportError:
+    import httplib as http_client
+
+from sagemaker_inference import content_types, environment
 from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler
+from sagemaker_inference.errors import BaseInferenceToolkitError
 from sagemaker_pytorch_serving_container.transformer import PyTorchTransformer
 
 
 INPUT_DATA = "input_data"
 CONTENT_TYPE = "content_type"
 ACCEPT = "accept"
+DEFAULT_ACCEPT = "default_accept"
 RESULT = "result"
 MODEL = "foo"
 
@@ -96,6 +103,81 @@ def test_transform(validate, retrieve_content_type_header, accept_key):
     assert result[0] == RESULT
 
 
+@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE)
+@patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer.validate_and_initialize")
+def test_transform_any_accept(validate, retrieve_content_type_header):
+    data = [{"body": INPUT_DATA}]
+    context = Mock()
+    request_processor = Mock()
+    transform_fn = Mock()
+    environment = Mock()
+    environment.default_accept = DEFAULT_ACCEPT
+
+    context.request_processor = [request_processor]
+    request_processor.get_request_properties.return_value = {"accept": content_types.ANY}
+
+    transformer = PyTorchTransformer()
+    transformer._model = MODEL
+    transformer._transform_fn = transform_fn
+    transformer._environment = environment
+    transformer._context = context
+
+    transformer.transform(data, context)
+
+    validate.assert_called_once()
+    transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, DEFAULT_ACCEPT, context)
+
+
+@pytest.mark.parametrize("content_type", content_types.UTF8_TYPES)
+@patch("sagemaker_inference.utils.retrieve_content_type_header")
+@patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer.validate_and_initialize")
+def test_transform_decode(validate, retrieve_content_type_header, content_type):
+    input_data = Mock()
+    context = Mock()
+    request_processor = Mock()
+    transform_fn = Mock()
+    data = [{"body": input_data}]
+
+    input_data.decode.return_value = INPUT_DATA
+    context.request_processor = [request_processor]
+    request_processor.get_request_properties.return_value = {"accept": ACCEPT}
+    retrieve_content_type_header.return_value = content_type
+
+    transformer = PyTorchTransformer()
+    transformer._model = MODEL
+    transformer._transform_fn = transform_fn
+    transformer._context = context
+
+    transformer.transform(data, context)
+
+    input_data.decode.assert_called_once_with("utf-8")
+    transform_fn.assert_called_once_with(MODEL, INPUT_DATA, content_type, ACCEPT, context)
+
+
+@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE)
+@patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer.validate_and_initialize")
+def test_transform_tuple(validate, retrieve_content_type_header):
+    data = [{"body": INPUT_DATA}]
+    context = Mock()
+    request_processor = Mock()
+    transform_fn = Mock(return_value=(RESULT, ACCEPT))
+
+    context.request_processor = [request_processor]
+    request_processor.get_request_properties.return_value = {"accept": ACCEPT}
+
+    transformer = PyTorchTransformer()
+    transformer._model = MODEL
+    transformer._transform_fn = transform_fn
+    transformer._context = context
+
+    result = transformer.transform(data, context)
+
+    transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, context)
+    context.set_response_content_type.assert_called_once_with(0, transform_fn()[1])
+    assert isinstance(result, list)
+    assert result[0] == transform_fn()[0]
+
+
 @patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer._validate_user_module_and_set_functions")
 @patch("sagemaker_inference.environment.Environment")
 def test_validate_and_initialize(env, validate_user_module):
@@ -120,6 +202,74 @@ def test_validate_and_initialize(env, validate_user_module):
     validate_user_module.assert_called_once_with()
 
 
+@patch("sagemaker_inference.transformer.Transformer._validate_user_module_and_set_functions")
+@patch("sagemaker_inference.environment.Environment")
+def test_handle_validate_and_initialize_error(env, validate_user_module):
+    data = [{"body": INPUT_DATA}]
+    request_processor = Mock()
+
+    context = Mock()
+    context.request_processor = [request_processor]
+
+    transform_fn = Mock()
+    model_fn = Mock()
+
+    transformer = PyTorchTransformer()
+
+    transformer._model = MODEL
+    transformer._transform_fn = transform_fn
+    transformer._model_fn = model_fn
+    transformer._context = context
+
+    test_error_message = "Foo"
+    validate_user_module.side_effect = ValueError(test_error_message)
+
+    assert transformer._initialized is False
+
+    response = transformer.transform(data, context)
+    assert test_error_message in str(response)
+    assert "Traceback (most recent call last)" in str(response)
+    context.set_response_status.assert_called_with(
+        code=http_client.INTERNAL_SERVER_ERROR, phrase=test_error_message
+    )
+
+
+@patch("sagemaker_inference.transformer.Transformer._validate_user_module_and_set_functions")
+@patch("sagemaker_inference.environment.Environment")
+def test_handle_validate_and_initialize_user_error(env, validate_user_module):
+    test_status_code = http_client.FORBIDDEN
+    test_error_message = "Foo"
+
+    class FooUserError(BaseInferenceToolkitError):
+        def __init__(self, status_code, message):
+            self.status_code = status_code
+            self.message = message
+            self.phrase = "Foo"
+
+    data = [{"body": INPUT_DATA}]
+    context = Mock()
+    transform_fn = Mock()
+    model_fn = Mock()
+
+    transformer = PyTorchTransformer()
+
+    transformer._model = MODEL
+    transformer._transform_fn = transform_fn
+    transformer._model_fn = model_fn
+    transformer._context = context
+
+    validate_user_module.side_effect = FooUserError(test_status_code, test_error_message)
+
+    assert transformer._initialized is False
+
+    response = transformer.transform(data, context)
+    assert test_error_message in str(response)
+    assert "Traceback (most recent call last)" in str(response)
+    context.set_response_status.assert_called_with(
+        code=http_client.FORBIDDEN, phrase=test_error_message
+    )
+
+
 def test_default_transform_fn():
     transformer = PyTorchTransformer()
     context = Mock()
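
Note on what the three new transform tests assert: an Accept header of content_types.ANY ("*/*") should fall back to the environment's default_accept, UTF-8 content types should be decoded before reaching transform_fn, and a (result, accept) tuple returned by transform_fn should supply the response content type. The snippet below is a minimal, self-contained sketch of that contract only; the helper names resolve_accept and unpack_result are hypothetical and are not part of sagemaker_inference or this container.

# Minimal sketch of the semantics the new transform tests exercise.
# resolve_accept and unpack_result are hypothetical illustration-only names,
# not the toolkit's actual implementation.

ANY = "*/*"  # mirrors sagemaker_inference.content_types.ANY


def resolve_accept(request_properties, default_accept):
    """Fall back to the environment default when the client accepts anything."""
    accept = request_properties.get("accept")
    if not accept or accept == ANY:
        return default_accept
    return accept


def unpack_result(result):
    """A (result, accept) tuple carries an explicit response content type."""
    if isinstance(result, tuple):
        return result
    return result, None


if __name__ == "__main__":
    assert resolve_accept({"accept": ANY}, "application/json") == "application/json"
    assert resolve_accept({"accept": "text/csv"}, "application/json") == "text/csv"
    assert unpack_result(("ok", "text/csv")) == ("ok", "text/csv")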

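The two error-handling tests cover the other path through transform: a failure inside validate_and_initialize is reported through the context rather than propagated, with a BaseInferenceToolkitError keeping its own status code, any other exception mapping to HTTP 500, and the response body carrying the error message plus a traceback. The sketch below only illustrates that mapping under those assumptions; report_initialization_error is a hypothetical helper name, and the BaseInferenceToolkitError class here is a local stand-in for the imported one.

# Illustrative sketch of the error-to-status mapping the two tests assert.
# report_initialization_error is a hypothetical helper, not toolkit code.
import http.client as http_client
import traceback


class BaseInferenceToolkitError(Exception):
    """Local stand-in for sagemaker_inference.errors.BaseInferenceToolkitError."""

    def __init__(self, status_code, message, phrase):
        super().__init__(message)
        self.status_code = status_code
        self.message = message
        self.phrase = phrase


def report_initialization_error(context, error):
    """Set the response status from the error and return a text response body."""
    trace = "".join(
        traceback.format_exception(type(error), error, error.__traceback__)
    )
    if isinstance(error, BaseInferenceToolkitError):
        # Toolkit errors keep their own HTTP status code.
        context.set_response_status(code=error.status_code, phrase=error.message)
    else:
        # Anything else is surfaced as an internal server error.
        context.set_response_status(
            code=http_client.INTERNAL_SERVER_ERROR, phrase=str(error)
        )
    # The tests check that both the message and the traceback appear in the body.
    return ["{}\n{}".format(error, trace)]


if __name__ == "__main__":
    from unittest.mock import Mock

    ctx = Mock()
    body = report_initialization_error(ctx, ValueError("Foo"))
    ctx.set_response_status.assert_called_with(
        code=http_client.INTERNAL_SERVER_ERROR, phrase="Foo"
    )
    assert "Foo" in body[0]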