Skip to content

Commit 06064f8

Browse files
author
Ashwin Krishna
committed
breaking: fix unit tests and linting
NumpyDeserializer will not allow deserialization unless allow_pickle flag is set to True explicitly
1 parent cee7f30 commit 06064f8

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

src/sagemaker/base_deserializers.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -229,15 +229,15 @@ def deserialize(self, stream, content_type):
229229
if content_type == "application/x-npy":
230230
try:
231231
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
232-
except ValueError as ve:
233-
error_message = "Please set the param allow_pickle=True to deserialize pickle objects" + str(ve)
234-
raise ValueError(error_message)
232+
except ValueError as ve:
233+
raise ValueError("Please set the param allow_pickle=True \
234+
to deserialize pickle objects in NumpyDeserializer").with_traceback(ve.__traceback__)
235235
if content_type == "application/x-npz":
236236
try:
237237
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
238-
except ValueError as ve:
239-
error_message = "Please set the param allow_pickle=True to deserialize pickle objects" + str(ve)
240-
raise ValueError(error_message)
238+
except ValueError as ve:
239+
raise ValueError("Please set the param allow_pickle=True \
240+
to deserialize pickle objectsin NumpyDeserializer").with_traceback(ve.__traceback__)
241241
finally:
242242
stream.close()
243243
finally:

tests/unit/sagemaker/deserializers/test_deserializers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ def test_numpy_deserializer_from_npy(numpy_deserializer):
142142
assert np.array_equal(array, result)
143143

144144

145-
def test_numpy_deserializer_from_npy_object_array(numpy_deserializer):
145+
def test_numpy_deserializer_from_npy_object_array():
146+
numpy_deserializer = NumpyDeserializer(allow_pickle=True)
146147
array = np.array([{"a": "", "b": ""}, {"c": "", "d": ""}])
147148
stream = io.BytesIO()
148149
np.save(stream, array)

0 commit comments

Comments
 (0)