Skip to content

Commit cee7f30

Browse files
author
Ashwin Krishna
committed
breaking: set default allow_pickle param to False
1 parent 58717d2 commit cee7f30

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

src/sagemaker/base_deserializers.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,14 @@ class NumpyDeserializer(SimpleBaseDeserializer):
196196
single array.
197197
"""
198198

199-
def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True):
199+
def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=False):
200200
"""Initialize a ``NumpyDeserializer`` instance.
201201
202202
Args:
203203
dtype (str): The dtype of the data (default: None).
204204
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
205205
is expected from the inference endpoint (default: "application/x-npy").
206-
allow_pickle (bool): Allow loading pickled object arrays (default: True).
206+
allow_pickle (bool): Allow loading pickled object arrays (default: False).
207207
"""
208208
super(NumpyDeserializer, self).__init__(accept=accept)
209209
self.dtype = dtype
@@ -227,10 +227,17 @@ def deserialize(self, stream, content_type):
227227
if content_type == "application/json":
228228
return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
229229
if content_type == "application/x-npy":
230-
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
230+
try:
231+
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
232+
except ValueError as ve:
233+
error_message = "Please set the param allow_pickle=True to deserialize pickle objects" + str(ve)
234+
raise ValueError(error_message)
231235
if content_type == "application/x-npz":
232236
try:
233237
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
238+
except ValueError as ve:
239+
error_message = "Please set the param allow_pickle=True to deserialize pickle objects" + str(ve)
240+
raise ValueError(error_message)
234241
finally:
235242
stream.close()
236243
finally:

0 commit comments

Comments
 (0)