
Commit fa53af0

breaking: Move _JsonSerializer to sagemaker.serializers.JSONSerializer (#1698)
1 parent cc2d047 commit fa53af0

9 files changed: 124 additions, 110 deletions

doc/frameworks/tensorflow/upgrade_from_legacy.rst

Lines changed: 3 additions & 3 deletions
@@ -245,10 +245,10 @@ For example, if you want to use JSON serialization and deserialization:
 
 .. code:: python
 
-    from sagemaker.predictor import json_deserializer, json_serializer
+    from sagemaker.predictor import json_deserializer
+    from sagemaker.serializers import JSONSerializer
 
-    predictor.content_type = "application/json"
-    predictor.serializer = json_serializer
+    predictor.serializer = JSONSerializer()
     predictor.accept = "application/json"
     predictor.deserializer = json_deserializer
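
The updated snippet drops the explicit predictor.content_type assignment because JSONSerializer declares CONTENT_TYPE = "application/json" itself. A minimal sketch of the new usage (the endpoint name is a placeholder, not part of this commit):

    from sagemaker.predictor import Predictor, json_deserializer
    from sagemaker.serializers import JSONSerializer

    predictor = Predictor("my-endpoint")
    predictor.serializer = JSONSerializer()
    predictor.accept = "application/json"
    predictor.deserializer = json_deserializer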

src/sagemaker/mxnet/model.py

Lines changed: 3 additions & 2 deletions
@@ -26,7 +26,8 @@
 )
 from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
 from sagemaker.mxnet import defaults
-from sagemaker.predictor import Predictor, json_serializer, json_deserializer
+from sagemaker.predictor import Predictor, json_deserializer
+from sagemaker.serializers import JSONSerializer
 
 logger = logging.getLogger("sagemaker")
 
@@ -50,7 +51,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
             using the default AWS configuration chain.
         """
         super(MXNetPredictor, self).__init__(
-            endpoint_name, sagemaker_session, json_serializer, json_deserializer
+            endpoint_name, sagemaker_session, JSONSerializer(), json_deserializer
         )
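
For MXNetPredictor callers the default stays JSON; the serializer handed to the base class is now a JSONSerializer instance rather than the module-level json_serializer callable. A minimal sketch, assuming an already-deployed endpoint and AWS credentials ("my-mxnet-endpoint" is a placeholder):

    from sagemaker.mxnet.model import MXNetPredictor
    from sagemaker.serializers import JSONSerializer

    predictor = MXNetPredictor("my-mxnet-endpoint")
    # the base Predictor keeps the serializer it was given, so this now holds a
    # JSONSerializer instance (assumption: the attribute is exposed as `serializer`)
    print(isinstance(predictor.serializer, JSONSerializer))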

src/sagemaker/predictor.py

Lines changed: 0 additions & 50 deletions
@@ -16,7 +16,6 @@
 import codecs
 import csv
 import json
-import six
 from six import StringIO, BytesIO
 import numpy as np
 
@@ -597,55 +596,6 @@ def _row_to_csv(obj):
     return ",".join(obj)
 
 
-class _JsonSerializer(object):
-    """Placeholder docstring"""
-
-    def __init__(self):
-        """Placeholder docstring"""
-        self.content_type = CONTENT_TYPE_JSON
-
-    def __call__(self, data):
-        """Take data of various formats and serialize them into the expected
-        request body. This uses information about supported input formats for
-        the deployed model.
-
-        Args:
-            data (object): Data to be serialized.
-
-        Returns:
-            object: Serialized data used for the request.
-        """
-        if isinstance(data, dict):
-            # convert each value in dict from a numpy array to a list if necessary, so they can be
-            # json serialized
-            return json.dumps({k: _ndarray_to_list(v) for k, v in six.iteritems(data)})
-
-        # files and buffers
-        if hasattr(data, "read"):
-            return _json_serialize_from_buffer(data)
-
-        return json.dumps(_ndarray_to_list(data))
-
-
-json_serializer = _JsonSerializer()
-
-
-def _ndarray_to_list(data):
-    """
-    Args:
-        data:
-    """
-    return data.tolist() if isinstance(data, np.ndarray) else data
-
-
-def _json_serialize_from_buffer(buff):
-    """
-    Args:
-        buff:
-    """
-    return buff.read()
-
-
 class _JsonDeserializer(object):
     """Placeholder docstring"""
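
Removing json_serializer here is the breaking part of the commit: the old _JsonSerializer instance was invoked directly as a callable, while the replacement exposes a serialize() method. A rough before/after sketch of the calling convention (the array value is illustrative):

    import numpy as np
    from sagemaker.serializers import JSONSerializer

    # before: json_serializer(np.array([1, 2, 3]))   -> "[1, 2, 3]"
    # after:
    JSONSerializer().serialize(np.array([1, 2, 3]))  # -> "[1, 2, 3]"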

src/sagemaker/serializers.py

Lines changed: 34 additions & 0 deletions
@@ -14,6 +14,9 @@
 from __future__ import absolute_import
 
 import abc
+import json
+
+import numpy as np
 
 
 class BaseSerializer(abc.ABC):
@@ -38,3 +41,34 @@ def serialize(self, data):
     @abc.abstractmethod
     def CONTENT_TYPE(self):
         """The MIME type of the data sent to the inference endpoint."""
+
+
+class JSONSerializer(BaseSerializer):
+    """Serialize data to a JSON formatted string."""
+
+    CONTENT_TYPE = "application/json"
+
+    def serialize(self, data):
+        """Serialize data of various formats to a JSON formatted string.
+
+        Args:
+            data (object): Data to be serialized.
+
+        Returns:
+            str: The data serialized as a JSON string.
+        """
+        if isinstance(data, dict):
+            return json.dumps(
+                {
+                    key: value.tolist() if isinstance(value, np.ndarray) else value
+                    for key, value in data.items()
+                }
+            )
+
+        if hasattr(data, "read"):
+            return data.read()
+
+        if isinstance(data, np.ndarray):
+            return json.dumps(data.tolist())
+
+        return json.dumps(data)
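
A quick sketch of what JSONSerializer.serialize returns for the input types it handles (outputs mirror the unit tests added later in this commit; "payload.json" is a hypothetical file used only for illustration):

    import numpy as np
    from sagemaker.serializers import JSONSerializer

    serializer = JSONSerializer()
    serializer.serialize(np.array([1, 2, 3]))      # '[1, 2, 3]'
    serializer.serialize({"x": np.array([1, 2])})  # '{"x": [1, 2]}'
    serializer.serialize([1, 2, 3])                # '[1, 2, 3]'
    # file-like objects are read and passed through unchanged:
    # with open("payload.json") as f:
    #     serializer.serialize(f)                  # raw file contents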

src/sagemaker/tensorflow/model.py

Lines changed: 3 additions & 2 deletions
@@ -18,7 +18,8 @@
 import sagemaker
 from sagemaker.content_types import CONTENT_TYPE_JSON
 from sagemaker.fw_utils import create_image_uri
-from sagemaker.predictor import json_serializer, json_deserializer, Predictor
+from sagemaker.predictor import json_deserializer, Predictor
+from sagemaker.serializers import JSONSerializer
 
 
 class TensorFlowPredictor(Predictor):
@@ -30,7 +31,7 @@ def __init__(
         self,
         endpoint_name,
         sagemaker_session=None,
-        serializer=json_serializer,
+        serializer=JSONSerializer(),
         deserializer=json_deserializer,
         content_type=None,
         model_name=None,
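
TensorFlowPredictor gets the same treatment. Note that JSONSerializer() in the default argument is built once when the signature is evaluated and shared by all instances, which is harmless here because the serializer keeps no per-request state. Callers can still override it, e.g. (a sketch; csv_serializer is the existing helper in sagemaker.predictor, "my-tf-endpoint" is a placeholder):

    from sagemaker.predictor import csv_serializer
    from sagemaker.tensorflow.model import TensorFlowPredictor

    predictor = TensorFlowPredictor("my-tf-endpoint", serializer=csv_serializer)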

tests/integ/test_inference_pipeline.py

Lines changed: 3 additions & 2 deletions
@@ -26,7 +26,8 @@
 from sagemaker.content_types import CONTENT_TYPE_CSV
 from sagemaker.model import Model
 from sagemaker.pipeline import PipelineModel
-from sagemaker.predictor import Predictor, json_serializer
+from sagemaker.predictor import Predictor
+from sagemaker.serializers import JSONSerializer
 from sagemaker.sparkml.model import SparkMLModel
 from sagemaker.utils import sagemaker_timestamp
 
@@ -128,7 +129,7 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type):
     predictor = Predictor(
         endpoint_name=endpoint_name,
         sagemaker_session=sagemaker_session,
-        serializer=json_serializer,
+        serializer=JSONSerializer,
         content_type=CONTENT_TYPE_CSV,
         accept=CONTENT_TYPE_CSV,
     )

tests/integ/test_multidatamodel.py

Lines changed: 2 additions & 2 deletions
@@ -290,7 +290,7 @@ def test_multi_data_model_deploy_trained_model_from_framework_estimator(
     assert PRETRAINED_MODEL_PATH_2 in endpoint_models
 
     # Define a predictor to set `serializer` parameter with npy_serializer
-    # instead of `json_serializer` in the default predictor returned by `MXNetPredictor`
+    # instead of `JSONSerializer` in the default predictor returned by `MXNetPredictor`
     # Since we are using a placeholder container image the prediction results are not accurate.
     predictor = Predictor(
         endpoint_name=endpoint_name,
@@ -391,7 +391,7 @@ def test_multi_data_model_deploy_train_model_from_amazon_first_party_estimator(
     assert PRETRAINED_MODEL_PATH_2 in endpoint_models
 
     # Define a predictor to set `serializer` parameter with npy_serializer
-    # instead of `json_serializer` in the default predictor returned by `MXNetPredictor`
+    # instead of `JSONSerializer` in the default predictor returned by `MXNetPredictor`
     # Since we are using a placeholder container image the prediction results are not accurate.
     predictor = Predictor(
         endpoint_name=endpoint_name,

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import json
+import os
+
+import numpy as np
+import pytest
+
+from sagemaker.serializers import JSONSerializer
+from tests.unit import DATA_DIR
+
+
+@pytest.fixture
+def json_serializer():
+    return JSONSerializer()
+
+
+def test_json_serializer_numpy_valid(json_serializer):
+    result = json_serializer.serialize(np.array([1, 2, 3]))
+
+    assert result == "[1, 2, 3]"
+
+
+def test_json_serializer_numpy_valid_2dimensional(json_serializer):
+    result = json_serializer.serialize(np.array([[1, 2, 3], [3, 4, 5]]))
+
+    assert result == "[[1, 2, 3], [3, 4, 5]]"
+
+
+def test_json_serializer_empty(json_serializer):
+    assert json_serializer.serialize(np.array([])) == "[]"
+
+
+def test_json_serializer_python_array(json_serializer):
+    result = json_serializer.serialize([1, 2, 3])
+
+    assert result == "[1, 2, 3]"
+
+
+def test_json_serializer_python_dictionary(json_serializer):
+    d = {"gender": "m", "age": 22, "city": "Paris"}
+
+    result = json_serializer.serialize(d)
+
+    assert json.loads(result) == d
+
+
+def test_json_serializer_python_invalid_empty(json_serializer):
+    assert json_serializer.serialize([]) == "[]"
+
+
+def test_json_serializer_python_dictionary_invalid_empty(json_serializer):
+    assert json_serializer.serialize({}) == "{}"
+
+
+def test_json_serializer_csv_buffer(json_serializer):
+    csv_file_path = os.path.join(DATA_DIR, "with_integers.csv")
+    with open(csv_file_path) as csv_file:
+        validation_value = csv_file.read()
+        csv_file.seek(0)
+        result = json_serializer.serialize(csv_file)
+        assert result == validation_value

tests/unit/test_predictor.py

Lines changed: 2 additions & 49 deletions
@@ -22,63 +22,16 @@
 
 from sagemaker.predictor import Predictor
 from sagemaker.predictor import (
-    json_serializer,
     json_deserializer,
     csv_serializer,
     npy_serializer,
 )
+from sagemaker.serializers import JSONSerializer
 from tests.unit import DATA_DIR
 
 # testing serialization functions
 
 
-def test_json_serializer_numpy_valid():
-    result = json_serializer(np.array([1, 2, 3]))
-
-    assert result == "[1, 2, 3]"
-
-
-def test_json_serializer_numpy_valid_2dimensional():
-    result = json_serializer(np.array([[1, 2, 3], [3, 4, 5]]))
-
-    assert result == "[[1, 2, 3], [3, 4, 5]]"
-
-
-def test_json_serializer_empty():
-    assert json_serializer(np.array([])) == "[]"
-
-
-def test_json_serializer_python_array():
-    result = json_serializer([1, 2, 3])
-
-    assert result == "[1, 2, 3]"
-
-
-def test_json_serializer_python_dictionary():
-    d = {"gender": "m", "age": 22, "city": "Paris"}
-
-    result = json_serializer(d)
-
-    assert json.loads(result) == d
-
-
-def test_json_serializer_python_invalid_empty():
-    assert json_serializer([]) == "[]"
-
-
-def test_json_serializer_python_dictionary_invalid_empty():
-    assert json_serializer({}) == "{}"
-
-
-def test_json_serializer_csv_buffer():
-    csv_file_path = os.path.join(DATA_DIR, "with_integers.csv")
-    with open(csv_file_path) as csv_file:
-        validation_value = csv_file.read()
-        csv_file.seek(0)
-        result = json_serializer(csv_file)
-        assert result == validation_value
-
-
 def test_csv_serializer_str():
     original = "1,2,3"
     result = csv_serializer("1,2,3")
@@ -388,7 +341,7 @@ def test_predict_call_with_headers_and_json():
         sagemaker_session,
         content_type="not/json",
         accept="also/not-json",
-        serializer=json_serializer,
+        serializer=JSONSerializer(),
     )
 
     data = [1, 2]
