Commit 81f683e

decompressor: add read_across_frames and allow_extra_data arguments to decompress()
This is related to #59 and #181. I'm not willing to implement read_across_frames at this time, as it is non-trivial. But we can add a placeholder argument now to provide forward compatibility for when the feature is implemented. allow_extra_data was trivial to implement, so it works as advertised.
1 parent c7a314e commit 81f683e

File tree: 6 files changed, +121 −10

  c-ext/decompressor.c
  docs/news.rst
  rust-ext/src/decompressor.rs
  tests/test_decompressor_decompress.py
  zstandard/__init__.pyi
  zstandard/backend_cffi.py
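
Before the per-file diffs, here is a minimal usage sketch assembled from the tests and docstrings in this commit. The `import zstandard as zstd` spelling and the bare asserts are illustrative assumptions; the keyword arguments, defaults, and quoted error messages come from the diff itself.

import zstandard as zstd

cctx = zstd.ZstdCompressor()
frame = cctx.compress(b"foo")

dctx = zstd.ZstdDecompressor()

# The defaults keep today's behavior: decode the first frame and silently
# ignore anything after it.
assert dctx.decompress(frame + b"junk") == b"foo"

# allow_extra_data=False turns trailing input into an error.
try:
    dctx.decompress(frame + b"junk", allow_extra_data=False)
except zstd.ZstdError:
    pass  # "... 4 bytes of unused data, which is disallowed"

# read_across_frames=True is only a placeholder for now and always raises.
try:
    dctx.decompress(frame, read_across_frames=True)
except zstd.ZstdError:
    pass  # "ZstdDecompressor.read_across_frames=True is not yet implemented"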

c-ext/decompressor.c

Lines changed: 30 additions & 3 deletions
@@ -262,22 +262,39 @@ static PyObject *Decompressor_copy_stream(ZstdDecompressor *self,
 
 PyObject *Decompressor_decompress(ZstdDecompressor *self, PyObject *args,
                                   PyObject *kwargs) {
-    static char *kwlist[] = {"data", "max_output_size", NULL};
+    static char *kwlist[] = {
+        "data",
+        "max_output_size",
+        "read_across_frames",
+        "allow_extra_data",
+        NULL
+    };
 
     Py_buffer source;
     Py_ssize_t maxOutputSize = 0;
+
     unsigned long long decompressedSize;
+    PyObject *readAcrossFrames = NULL;
+    PyObject *allowExtraData = NULL;
     size_t destCapacity;
     PyObject *result = NULL;
     size_t zresult;
     ZSTD_outBuffer outBuffer;
     ZSTD_inBuffer inBuffer;
 
-    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "y*|n:decompress", kwlist,
-                                     &source, &maxOutputSize)) {
+    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "y*|nOO:decompress", kwlist,
+                                     &source, &maxOutputSize, &readAcrossFrames,
+                                     &allowExtraData)) {
         return NULL;
     }
 
+    if (readAcrossFrames ? PyObject_IsTrue(readAcrossFrames) : 0) {
+        PyErr_SetString(ZstdError,
+            "ZstdDecompressor.read_across_frames=True is not yet implemented"
+        );
+        goto finally;
+    }
+
     if (ensure_dctx(self, 1)) {
         goto finally;
     }
@@ -361,6 +378,16 @@ PyObject *Decompressor_decompress(ZstdDecompressor *self, PyObject *args,
             goto finally;
         }
     }
+    else if ((allowExtraData ? PyObject_IsTrue(allowExtraData) : 1) == 0
+             && inBuffer.pos < inBuffer.size) {
+        PyErr_Format(
+            ZstdError,
+            "compressed input contains %zu bytes of unused data, which is disallowed",
+            inBuffer.size - inBuffer.pos
+        );
+        Py_CLEAR(result);
+        goto finally;
+    }
 
 finally:
     PyBuffer_Release(&source);
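
The two ternaries in this hunk are what give the optional PyObject* arguments their asymmetric defaults when they are omitted (NULL): read_across_frames falls back to false, allow_extra_data to true. A rough Python equivalent of just that defaulting logic, for illustration only (the helper name is invented):

def _resolve_decompress_flags(read_across_frames=None, allow_extra_data=None):
    # Mirrors the C defaulting:
    #   readAcrossFrames ? PyObject_IsTrue(readAcrossFrames) : 0  -> False when omitted
    #   allowExtraData   ? PyObject_IsTrue(allowExtraData)   : 1  -> True when omitted
    read = bool(read_across_frames) if read_across_frames is not None else False
    allow = bool(allow_extra_data) if allow_extra_data is not None else True
    return read, allow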

docs/news.rst

Lines changed: 11 additions & 0 deletions
@@ -103,6 +103,17 @@ Changes
   This may have fixed unconfirmed issues where ``unused_data`` was set
   prematurely. The new logic will also avoid an extra call to
   ``ZSTD_decompressStream()`` in some scenarios, possibly improving performance.
+* ``ZstdDecompressor.decompress()`` now has a ``read_across_frames`` keyword
+  argument. It defaults to False. True is not yet implemented and will raise an
+  exception if used. The new argument will default to True in a future release
+  and is provided now so callers can start passing ``read_across_frames=False``
+  to preserve the existing functionality during a future upgrade.
+* ``ZstdDecompressor.decompress()`` now has an ``allow_extra_data`` keyword
+  argument to control whether an exception is raised if the input contains
+  extra data. It defaults to True, preserving the existing behavior of ignoring
+  extra data. It will likely default to False in a future release. Callers
+  desiring the current behavior are encouraged to explicitly pass
+  ``allow_extra_data=True`` so behavior won't change during a future upgrade.
 
 0.18.0 (released 2022-06-20)
 ============================
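
The entries above defer multi-frame support in decompress() to a later release; in the meantime the documented route for multi-frame input is the streaming reader, which already accepts read_across_frames. A small sketch of that workaround (the import spelling and io.BytesIO wrapping are assumptions; the APIs are existing python-zstandard ones):

import io
import zstandard as zstd

cctx = zstd.ZstdCompressor()
multi = cctx.compress(b"foo") + cctx.compress(b"bar")

dctx = zstd.ZstdDecompressor()

# decompress() still stops after the first frame ...
assert dctx.decompress(multi) == b"foo"

# ... so cross-frame reads go through the streaming API for now.
with dctx.stream_reader(io.BytesIO(multi), read_across_frames=True) as reader:
    assert reader.read() == b"foobar"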

rust-ext/src/decompressor.rs

Lines changed: 19 additions & 1 deletion
@@ -159,13 +159,26 @@ impl ZstdDecompressor {
         Ok((total_read, total_write))
     }
 
-    #[args(buffer, max_output_size = "0")]
+    #[args(
+        buffer,
+        max_output_size = "0",
+        read_across_frames = "false",
+        allow_extra_data = "true"
+    )]
     fn decompress<'p>(
         &mut self,
         py: Python<'p>,
         buffer: PyBuffer<u8>,
         max_output_size: usize,
+        read_across_frames: bool,
+        allow_extra_data: bool,
     ) -> PyResult<&'p PyBytes> {
+        if read_across_frames {
+            return Err(ZstdError::new_err(
+                "ZstdDecompressor.read_across_frames=True is not yet implemented",
+            ));
+        }
+
         self.setup_dctx(py, true)?;
 
         let output_size =
@@ -215,6 +228,11 @@ impl ZstdDecompressor {
                 "decompression error: decompressed {} bytes; expected {}",
                 zresult, output_size
             )))
+        } else if !allow_extra_data && in_buffer.pos < in_buffer.size {
+            Err(ZstdError::new_err(format!(
+                "compressed input contains {} bytes of unused data, which is disallowed",
+                in_buffer.size - in_buffer.pos
+            )))
         } else {
             // TODO avoid memory copy
             Ok(PyBytes::new(py, &dest_buffer))

tests/test_decompressor_decompress.py

Lines changed: 21 additions & 0 deletions
@@ -183,10 +183,31 @@ def test_multiple_frames(self):
 
         dctx = zstd.ZstdDecompressor()
         self.assertEqual(dctx.decompress(foo + bar), b"foo")
+        self.assertEqual(
+            dctx.decompress(foo + bar, allow_extra_data=True), b"foo"
+        )
+
+        with self.assertRaisesRegex(
+            zstd.ZstdError,
+            "ZstdDecompressor.read_across_frames=True is not yet implemented",
+        ):
+            dctx.decompress(foo + bar, read_across_frames=True)
+
+        with self.assertRaisesRegex(
+            zstd.ZstdError, "%d bytes of unused data, which is disallowed" % len(bar)
+        ):
+            dctx.decompress(foo + bar, allow_extra_data=False)
 
     def test_junk_after_frame(self):
         cctx = zstd.ZstdCompressor()
         frame = cctx.compress(b"foo")
 
         dctx = zstd.ZstdDecompressor()
         self.assertEqual(dctx.decompress(frame + b"junk"), b"foo")
+
+        self.assertEqual(dctx.decompress(frame + b"junk", allow_extra_data=True), b"foo")
+
+        with self.assertRaisesRegex(
+            zstd.ZstdError, "4 bytes of unused data, which is disallowed"
+        ):
+            dctx.decompress(frame + b"junk", allow_extra_data=False)

zstandard/__init__.pyi

Lines changed: 5 additions & 1 deletion
@@ -389,7 +389,11 @@ class ZstdDecompressor(object):
     ): ...
     def memory_size(self) -> int: ...
     def decompress(
-        self, data: ByteString, max_output_size: int = ...
+        self,
+        data: ByteString,
+        max_output_size: int = ...,
+        read_across_frames: bool = ...,
+        allow_extra_data: bool = ...,
     ) -> bytes: ...
     def stream_reader(
         self,

zstandard/backend_cffi.py

Lines changed: 35 additions & 5 deletions
@@ -3006,7 +3006,10 @@ def decompress(self, data):
             # buffer. So if the output buffer is partially filled and the input
             # is exhausted, there's nothing more to write. So we've done all we
             # can.
-            elif in_buffer.pos == in_buffer.size and out_buffer.pos < out_buffer.size:
+            elif (
+                in_buffer.pos == in_buffer.size
+                and out_buffer.pos < out_buffer.size
+            ):
                 break
             else:
                 out_buffer.pos = 0
@@ -3715,7 +3718,13 @@ def memory_size(self):
         """
         return lib.ZSTD_sizeof_DCtx(self._dctx)
 
-    def decompress(self, data, max_output_size=0):
+    def decompress(
+        self,
+        data,
+        max_output_size=0,
+        read_across_frames=False,
+        allow_extra_data=True,
+    ):
         """
         Decompress data in a single operation.
 
@@ -3727,11 +3736,20 @@ def decompress(self, data, max_output_size=0):
         similar). If the input does not contain a full frame, an exception will
         be raised.
 
-        If the input contains multiple frames, only the first frame will be
-        decompressed. If you need to decompress multiple frames, use an API
-        like :py:meth:`ZstdCompressor.stream_reader` with
+        ``read_across_frames`` controls whether to read multiple zstandard
+        frames in the input. When False, decompression stops after reading the
+        first frame. This feature is not yet implemented but the argument is
+        provided for forward API compatibility when the default is changed to
+        True in a future release. For now, if you need to decompress multiple
+        frames, use an API like :py:meth:`ZstdCompressor.stream_reader` with
         ``read_across_frames=True``.
 
+        ``allow_extra_data`` controls how to handle extra input data after a
+        fully decoded frame. If False, any extra data (which could be a valid
+        zstd frame) will result in ``ZstdError`` being raised. If True, extra
+        data is silently ignored. The default will likely change to False in a
+        future release when ``read_across_frames`` defaults to True.
+
         If the input contains extra data after a full frame, that extra input
         data is silently ignored. This behavior is undesirable in many scenarios
         and will likely be changed or controllable in a future release (see
@@ -3783,6 +3801,11 @@ def decompress(self, data, max_output_size=0):
         ``bytes`` representing decompressed output.
         """
 
+        if read_across_frames:
+            raise ZstdError(
+                "ZstdDecompressor.read_across_frames=True is not yet implemented"
+            )
+
         self._ensure_dctx()
 
         data_buffer = ffi.from_buffer(data)
@@ -3830,6 +3853,13 @@ def decompress(self, data, max_output_size=0):
                 "decompression error: decompressed %d bytes; expected %d"
                 % (zresult, output_size)
             )
+        elif not allow_extra_data and in_buffer.pos < in_buffer.size:
+            count = in_buffer.size - in_buffer.pos
+
+            raise ZstdError(
+                "compressed input contains %d bytes of unused data, which is disallowed"
+                % count
+            )
 
         return ffi.buffer(result_buffer, out_buffer.pos)[:]
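
All three backends (C, Rust, CFFI) now expose the same signature, so a caller who wants behavior to stay fixed across the planned default changes can pin both keywords explicitly, as the news entry suggests. A sketch with an invented wrapper name:

import zstandard as zstd

dctx = zstd.ZstdDecompressor()

def decompress_first_frame(data, max_output_size=0):
    # Spell out today's defaults so nothing changes when read_across_frames
    # flips to True (and allow_extra_data likely to False) in a future release.
    return dctx.decompress(
        data,
        max_output_size=max_output_size,
        read_across_frames=False,
        allow_extra_data=True,
    )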

0 commit comments
