20
20
from pandas import DataFrame
21
21
22
22
from sagemaker .serve import SchemaBuilder , CustomPayloadTranslator
23
- from sagemaker .serve .builder .schema_builder import JSONSerializerWrapper
23
+ from sagemaker .serve .builder .schema_builder import JSONSerializerWrapper , CSVSerializerWrapper
24
24
from sagemaker .deserializers import (
25
25
BytesDeserializer ,
26
26
NumpyDeserializer ,
30
30
from sagemaker .serializers import (
31
31
DataSerializer ,
32
32
NumpySerializer ,
33
- CSVSerializer ,
34
33
)
35
34
36
35
NUMPY_CONTENT_TYPE = "application/x-npy"
@@ -94,13 +93,9 @@ def custom_translator():
94
93
return MyPayloadTranslator ()
95
94
96
95
97
- # @pytest.fixture
98
- # def torch_tensor():
99
- # return torch.rand(3, 4)
100
-
101
-
102
96
def test_schema_builder_with_numpy (numpy_array ):
103
97
schema_builder = SchemaBuilder (numpy_array , numpy_array )
98
+ _validate_marshalling_function (schema_builder = schema_builder )
104
99
assert isinstance (schema_builder .input_serializer , NumpySerializer )
105
100
assert isinstance (schema_builder .output_serializer , NumpySerializer )
106
101
assert isinstance (schema_builder .input_deserializer ._deserializer , NumpyDeserializer )
@@ -111,8 +106,9 @@ def test_schema_builder_with_numpy(numpy_array):
111
106
112
107
def test_schema_builder_with_pandas_dataframe (pandas_df ):
113
108
schema_builder = SchemaBuilder (pandas_df , pandas_df )
114
- assert isinstance (schema_builder .input_serializer , CSVSerializer )
115
- assert isinstance (schema_builder .output_serializer , CSVSerializer )
109
+ _validate_marshalling_function (schema_builder = schema_builder )
110
+ assert isinstance (schema_builder .input_serializer , CSVSerializerWrapper )
111
+ assert isinstance (schema_builder .output_serializer , CSVSerializerWrapper )
116
112
assert isinstance (schema_builder .input_deserializer ._deserializer , PandasDeserializer )
117
113
assert schema_builder .input_deserializer .ACCEPT == DATAFRAME_CONTENT_TYPE
118
114
assert isinstance (schema_builder .output_deserializer ._deserializer , PandasDeserializer )
@@ -121,6 +117,7 @@ def test_schema_builder_with_pandas_dataframe(pandas_df):
121
117
122
118
def test_schema_builder_with_jsonable (jsonable_obj ):
123
119
schema_builder = SchemaBuilder (jsonable_obj , jsonable_obj )
120
+ _validate_marshalling_function (schema_builder = schema_builder )
124
121
assert isinstance (schema_builder .input_serializer , JSONSerializerWrapper )
125
122
assert isinstance (schema_builder .output_serializer , JSONSerializerWrapper )
126
123
assert isinstance (schema_builder .input_deserializer ._deserializer , JSONDeserializer )
@@ -131,13 +128,14 @@ def test_schema_builder_with_jsonable(jsonable_obj):
131
128
132
129
def test_schema_builder_with_bytes (some_bytes ):
133
130
schema_builder = SchemaBuilder (some_bytes , some_bytes )
131
+ _validate_marshalling_function (schema_builder = schema_builder )
134
132
assert isinstance (schema_builder .input_serializer , DataSerializer )
135
133
assert isinstance (schema_builder .output_serializer , DataSerializer )
136
134
assert isinstance (schema_builder .input_deserializer ._deserializer , BytesDeserializer )
137
135
assert isinstance (schema_builder .output_deserializer ._deserializer , BytesDeserializer )
138
136
139
137
140
- def test_schema_builder_with_cloudpickle (unsupported_object ):
138
+ def test_schema_builder_unsupported_type (unsupported_object ):
141
139
with pytest .raises (ValueError , match = "SchemaBuilder cannot determine" ):
142
140
SchemaBuilder (unsupported_object , unsupported_object )
143
141
@@ -149,6 +147,19 @@ def test_json_serializer_wrapper(jsonable):
149
147
JSONDeserializer ().deserialize (stream , content_type = "application/json" )
150
148
151
149
150
+ def _validate_marshalling_function (schema_builder : SchemaBuilder ):
151
+ """Invoke serializer and deserializer to validate the payload"""
152
+ # Validate sample_input
153
+ b = schema_builder .input_serializer .serialize (schema_builder .sample_input )
154
+ stream = BytesIO (b )
155
+ schema_builder .input_deserializer .deserialize (stream = stream )
156
+
157
+ # Validate sample_output
158
+ b = schema_builder .output_serializer .serialize (schema_builder .sample_output )
159
+ stream = BytesIO (b )
160
+ schema_builder .output_deserializer .deserialize (stream = stream )
161
+
162
+
152
163
def test_schema_builder_with_payload_translator (custom_translator ):
153
164
payload = "payload"
154
165
schema_builder = SchemaBuilder (
0 commit comments