Skip to content

Commit 31365d2

Browse files
SSRraymond and Raymond Liu
authored and committed
fix: Create CSVSerializerWrapper - Galactus (#1383)
Co-authored-by: Raymond Liu <[email protected]>
1 parent c5444e0 commit 31365d2

File tree

2 files changed

+33
-14
lines changed

2 files changed

+33
-14
lines changed

src/sagemaker/serve/builder/schema_builder.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -34,12 +34,20 @@
3434
class JSONSerializerWrapper(JSONSerializer):
3535
"""Wraps the JSONSerializer because it does not convert jsonable to bytes"""
3636

37-
def serialize(self, data):
37+
def serialize(self, data) -> bytes:
3838
"""Placeholder docstring"""
3939

4040
return super().serialize(data).encode("utf-8")
4141

4242

43+
class CSVSerializerWrapper(CSVSerializer):
44+
"""Wraps the CSVSerializer because it does not convert dataframe to bytes"""
45+
46+
def serialize(self, data) -> bytes:
47+
"""Placeholder docstring"""
48+
return super().serialize(data).encode("utf-8")
49+
50+
4351
translation_mapping = {
4452
NumpySerializer: NumpyDeserializer,
4553
NumpyDeserializer: NumpySerializer,
@@ -49,8 +57,8 @@ def serialize(self, data):
4957
TorchTensorDeserializer: TorchTensorSerializer,
5058
DataSerializer: BytesDeserializer,
5159
BytesDeserializer: DataSerializer,
52-
CSVSerializer: PandasDeserializer,
53-
PandasDeserializer: CSVSerializer,
60+
CSVSerializerWrapper: PandasDeserializer,
61+
PandasDeserializer: CSVSerializerWrapper,
5462
StringSerializer: StringDeserializer,
5563
StringDeserializer: StringSerializer,
5664
}
@@ -147,7 +155,7 @@ def _get_serializer(self, obj):
147155
if isinstance(obj, np.ndarray):
148156
return NumpySerializer()
149157
if isinstance(obj, DataFrame):
150-
return CSVSerializer()
158+
return CSVSerializerWrapper()
151159
if isinstance(obj, bytes) or _is_path_to_file(obj):
152160
return DataSerializer()
153161
if _is_torch_tensor(obj):

tests/unit/sagemaker/serve/builder/test_schema_builder.py

Lines changed: 21 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
from pandas import DataFrame
2121

2222
from sagemaker.serve import SchemaBuilder, CustomPayloadTranslator
23-
from sagemaker.serve.builder.schema_builder import JSONSerializerWrapper
23+
from sagemaker.serve.builder.schema_builder import JSONSerializerWrapper, CSVSerializerWrapper
2424
from sagemaker.deserializers import (
2525
BytesDeserializer,
2626
NumpyDeserializer,
@@ -30,7 +30,6 @@
3030
from sagemaker.serializers import (
3131
DataSerializer,
3232
NumpySerializer,
33-
CSVSerializer,
3433
)
3534

3635
NUMPY_CONTENT_TYPE = "application/x-npy"
@@ -94,13 +93,9 @@ def custom_translator():
9493
return MyPayloadTranslator()
9594

9695

97-
# @pytest.fixture
98-
# def torch_tensor():
99-
# return torch.rand(3, 4)
100-
101-
10296
def test_schema_builder_with_numpy(numpy_array):
10397
schema_builder = SchemaBuilder(numpy_array, numpy_array)
98+
_validate_marshalling_function(schema_builder=schema_builder)
10499
assert isinstance(schema_builder.input_serializer, NumpySerializer)
105100
assert isinstance(schema_builder.output_serializer, NumpySerializer)
106101
assert isinstance(schema_builder.input_deserializer._deserializer, NumpyDeserializer)
@@ -111,8 +106,9 @@ def test_schema_builder_with_numpy(numpy_array):
111106

112107
def test_schema_builder_with_pandas_dataframe(pandas_df):
113108
schema_builder = SchemaBuilder(pandas_df, pandas_df)
114-
assert isinstance(schema_builder.input_serializer, CSVSerializer)
115-
assert isinstance(schema_builder.output_serializer, CSVSerializer)
109+
_validate_marshalling_function(schema_builder=schema_builder)
110+
assert isinstance(schema_builder.input_serializer, CSVSerializerWrapper)
111+
assert isinstance(schema_builder.output_serializer, CSVSerializerWrapper)
116112
assert isinstance(schema_builder.input_deserializer._deserializer, PandasDeserializer)
117113
assert schema_builder.input_deserializer.ACCEPT == DATAFRAME_CONTENT_TYPE
118114
assert isinstance(schema_builder.output_deserializer._deserializer, PandasDeserializer)
@@ -121,6 +117,7 @@ def test_schema_builder_with_pandas_dataframe(pandas_df):
121117

122118
def test_schema_builder_with_jsonable(jsonable_obj):
123119
schema_builder = SchemaBuilder(jsonable_obj, jsonable_obj)
120+
_validate_marshalling_function(schema_builder=schema_builder)
124121
assert isinstance(schema_builder.input_serializer, JSONSerializerWrapper)
125122
assert isinstance(schema_builder.output_serializer, JSONSerializerWrapper)
126123
assert isinstance(schema_builder.input_deserializer._deserializer, JSONDeserializer)
@@ -131,13 +128,14 @@ def test_schema_builder_with_jsonable(jsonable_obj):
131128

132129
def test_schema_builder_with_bytes(some_bytes):
133130
schema_builder = SchemaBuilder(some_bytes, some_bytes)
131+
_validate_marshalling_function(schema_builder=schema_builder)
134132
assert isinstance(schema_builder.input_serializer, DataSerializer)
135133
assert isinstance(schema_builder.output_serializer, DataSerializer)
136134
assert isinstance(schema_builder.input_deserializer._deserializer, BytesDeserializer)
137135
assert isinstance(schema_builder.output_deserializer._deserializer, BytesDeserializer)
138136

139137

140-
def test_schema_builder_with_cloudpickle(unsupported_object):
138+
def test_schema_builder_unsupported_type(unsupported_object):
141139
with pytest.raises(ValueError, match="SchemaBuilder cannot determine"):
142140
SchemaBuilder(unsupported_object, unsupported_object)
143141

@@ -149,6 +147,19 @@ def test_json_serializer_wrapper(jsonable):
149147
JSONDeserializer().deserialize(stream, content_type="application/json")
150148

151149

150+
def _validate_marshalling_function(schema_builder: SchemaBuilder):
151+
"""Invoke serializer and deserializer to validate the payload"""
152+
# Validate sample_input
153+
b = schema_builder.input_serializer.serialize(schema_builder.sample_input)
154+
stream = BytesIO(b)
155+
schema_builder.input_deserializer.deserialize(stream=stream)
156+
157+
# Validate sample_output
158+
b = schema_builder.output_serializer.serialize(schema_builder.sample_output)
159+
stream = BytesIO(b)
160+
schema_builder.output_deserializer.deserialize(stream=stream)
161+
162+
152163
def test_schema_builder_with_payload_translator(custom_translator):
153164
payload = "payload"
154165
schema_builder = SchemaBuilder(

0 commit comments

Comments
 (0)