Skip to content

Commit f0c6f93

Browse files
Rodrigo Zhouindygreg
authored andcommitted
cext: accept explicit None value in ZstdCompressor/ZstdDecompressor constructors
Closes #153.
1 parent 9ac5b4d commit f0c6f93

File tree

6 files changed

+154
-20
lines changed

6 files changed

+154
-20
lines changed

c-ext/compressor.c

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,16 @@ static int ZstdCompressor_init(ZstdCompressor *self, PyObject *args,
9797
NULL};
9898

9999
int level = 3;
100-
ZstdCompressionDict *dict = NULL;
101-
ZstdCompressionParametersObject *params = NULL;
100+
PyObject *dict = NULL;
101+
PyObject *params = NULL;
102102
PyObject *writeChecksum = NULL;
103103
PyObject *writeContentSize = NULL;
104104
PyObject *writeDictID = NULL;
105105
int threads = 0;
106106

107-
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|iO!O!OOOi:ZstdCompressor",
108-
kwlist, &level, &ZstdCompressionDictType,
109-
&dict, &ZstdCompressionParametersType,
110-
&params, &writeChecksum, &writeContentSize,
107+
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|iOOOOOi:ZstdCompressor",
108+
kwlist, &level, &dict, &params,
109+
&writeChecksum, &writeContentSize,
111110
&writeDictID, &threads)) {
112111
return -1;
113112
}
@@ -122,6 +121,41 @@ static int ZstdCompressor_init(ZstdCompressor *self, PyObject *args,
122121
threads = cpu_count();
123122
}
124123

124+
if (dict) {
125+
if (dict == Py_None) {
126+
dict = NULL;
127+
}
128+
else if (!PyObject_IsInstance(dict,
129+
(PyObject *)&ZstdCompressionDictType)) {
130+
PyErr_Format(PyExc_TypeError,
131+
"dict_data must be zstd.ZstdCompressionDict");
132+
return -1;
133+
}
134+
}
135+
136+
if (params) {
137+
if (params == Py_None) {
138+
params = NULL;
139+
}
140+
else if (!PyObject_IsInstance(
141+
params, (PyObject *)&ZstdCompressionParametersType)) {
142+
PyErr_Format(
143+
PyExc_TypeError,
144+
"compression_params must be zstd.ZstdCompressionParameters");
145+
return -1;
146+
}
147+
}
148+
149+
if (writeChecksum == Py_None) {
150+
writeChecksum = NULL;
151+
}
152+
if (writeContentSize == Py_None) {
153+
writeContentSize = NULL;
154+
}
155+
if (writeDictID == Py_None) {
156+
writeDictID = NULL;
157+
}
158+
125159
/* We create a ZSTD_CCtx for reuse among multiple operations to reduce the
126160
overhead of each compression operation. */
127161
self->cctx = ZSTD_createCCtx();
@@ -166,7 +200,8 @@ static int ZstdCompressor_init(ZstdCompressor *self, PyObject *args,
166200
}
167201

168202
if (params) {
169-
if (set_parameters(self->params, params)) {
203+
if (set_parameters(self->params,
204+
(ZstdCompressionParametersObject *)params)) {
170205
return -1;
171206
}
172207
}
@@ -199,7 +234,7 @@ static int ZstdCompressor_init(ZstdCompressor *self, PyObject *args,
199234
}
200235

201236
if (dict) {
202-
self->dict = dict;
237+
self->dict = (ZstdCompressionDict *)dict;
203238
Py_INCREF(dict);
204239
}
205240

c-ext/decompressor.c

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ int ensure_dctx(ZstdDecompressor *decompressor, int loadDict) {
2929
}
3030
}
3131

32-
zresult = ZSTD_DCtx_setParameter(decompressor->dctx, ZSTD_d_format, decompressor->format);
32+
zresult = ZSTD_DCtx_setParameter(decompressor->dctx, ZSTD_d_format,
33+
decompressor->format);
3334
if (ZSTD_isError(zresult)) {
3435
PyErr_Format(ZstdError, "unable to set decoding format: %s",
3536
ZSTD_getErrorName(zresult));
@@ -58,19 +59,30 @@ static int Decompressor_init(ZstdDecompressor *self, PyObject *args,
5859
PyObject *kwargs) {
5960
static char *kwlist[] = {"dict_data", "max_window_size", "format", NULL};
6061

61-
ZstdCompressionDict *dict = NULL;
62+
PyObject *dict = NULL;
6263
Py_ssize_t maxWindowSize = 0;
6364
ZSTD_format_e format = ZSTD_f_zstd1;
6465

6566
self->dctx = NULL;
6667
self->dict = NULL;
6768

68-
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!nI:ZstdDecompressor",
69-
kwlist, &ZstdCompressionDictType, &dict,
70-
&maxWindowSize, &format)) {
69+
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OnI:ZstdDecompressor",
70+
kwlist, &dict, &maxWindowSize, &format)) {
7171
return -1;
7272
}
7373

74+
if (dict) {
75+
if (dict == Py_None) {
76+
dict = NULL;
77+
}
78+
else if (!PyObject_IsInstance(dict,
79+
(PyObject *)&ZstdCompressionDictType)) {
80+
PyErr_Format(PyExc_TypeError,
81+
"dict_data must be zstd.ZstdCompressionDict");
82+
return -1;
83+
}
84+
}
85+
7486
self->dctx = ZSTD_createDCtx();
7587
if (!self->dctx) {
7688
PyErr_NoMemory();
@@ -81,7 +93,7 @@ static int Decompressor_init(ZstdDecompressor *self, PyObject *args,
8193
self->format = format;
8294

8395
if (dict) {
84-
self->dict = dict;
96+
self->dict = (ZstdCompressionDict *)dict;
8597
Py_INCREF(dict);
8698
}
8799

docs/news.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ Changes
9595
* ``manylinux2014_aarch64`` wheels are now being produced for CPython 3.6+.
9696
(#145).
9797
* Wheels are now being produced for CPython 3.10.
98+
* Arguments to ``ZstdCompressor()`` and ``ZstdDecompressor()`` are now all
99+
optional in the C backend and an explicit ``None`` value is accepted. Before,
100+
the C backend wouldn't accept an explicit ``None`` value (but the CFFI
101+
backend would). The new behavior should be consistent between the backends.
102+
(#153)
98103

99104
0.15.2 (released 2021-02-27)
100105
============================

tests/test_compressor_compress.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,74 @@ def test_multithreaded_compression_params(self):
213213
self.assertEqual(
214214
result, b"\x28\xb5\x2f\xfd\x20\x03\x19\x00\x00\x66\x6f\x6f"
215215
)
216+
217+
def test_explicit_default_params(self):
218+
cctx = zstd.ZstdCompressor(
219+
level=3,
220+
dict_data=None,
221+
compression_params=None,
222+
write_checksum=None,
223+
write_content_size=None,
224+
write_dict_id=None,
225+
threads=0,
226+
)
227+
result = cctx.compress(b"")
228+
self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x00\x01\x00\x00")
229+
230+
def test_compression_params_with_other_params(self):
231+
params = zstd.ZstdCompressionParameters.from_level(3)
232+
cctx = zstd.ZstdCompressor(
233+
level=3,
234+
dict_data=None,
235+
compression_params=params,
236+
write_checksum=None,
237+
write_content_size=None,
238+
write_dict_id=None,
239+
threads=0,
240+
)
241+
result = cctx.compress(b"")
242+
self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x00\x01\x00\x00")
243+
244+
with self.assertRaises(ValueError):
245+
cctx = zstd.ZstdCompressor(
246+
level=3,
247+
dict_data=None,
248+
compression_params=params,
249+
write_checksum=False,
250+
write_content_size=None,
251+
write_dict_id=None,
252+
threads=0,
253+
)
254+
255+
with self.assertRaises(ValueError):
256+
cctx = zstd.ZstdCompressor(
257+
level=3,
258+
dict_data=None,
259+
compression_params=params,
260+
write_checksum=None,
261+
write_content_size=True,
262+
write_dict_id=None,
263+
threads=0,
264+
)
265+
266+
with self.assertRaises(ValueError):
267+
cctx = zstd.ZstdCompressor(
268+
level=3,
269+
dict_data=None,
270+
compression_params=params,
271+
write_checksum=None,
272+
write_content_size=None,
273+
write_dict_id=True,
274+
threads=0,
275+
)
276+
277+
with self.assertRaises(ValueError):
278+
cctx = zstd.ZstdCompressor(
279+
level=3,
280+
dict_data=None,
281+
compression_params=params,
282+
write_checksum=None,
283+
write_content_size=None,
284+
write_dict_id=True,
285+
threads=2,
286+
)

tests/test_decompressor_decompress.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,14 @@ def test_max_window_size(self):
164164
"decompression error: Frame requires too much memory",
165165
):
166166
dctx.decompress(frame, max_output_size=len(source))
167+
168+
def test_explicit_default_params(self):
169+
cctx = zstd.ZstdCompressor(level=1)
170+
compressed = cctx.compress(b"foo")
171+
172+
dctx = zstd.ZstdDecompressor(
173+
dict_data=None,
174+
max_window_size=0,
175+
format=zstd.FORMAT_ZSTD1,
176+
)
177+
self.assertEqual(dctx.decompress(compressed), b"foo")

zstandard/__init__.pyi

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -265,11 +265,11 @@ class ZstdCompressor(object):
265265
def __init__(
266266
self,
267267
level: int = ...,
268-
dict_data: ZstdCompressionDict = ...,
269-
compression_params: ZstdCompressionParameters = ...,
270-
write_checksum: bool = ...,
271-
write_content_size: bool = ...,
272-
write_dict_id: bool = ...,
268+
dict_data: Optional[ZstdCompressionDict] = ...,
269+
compression_params: Optional[ZstdCompressionParameters] = ...,
270+
write_checksum: Optional[bool] = ...,
271+
write_content_size: Optional[bool] = ...,
272+
write_dict_id: Optional[bool] = ...,
273273
threads: int = ...,
274274
): ...
275275
def memory_size(self) -> int: ...
@@ -376,7 +376,7 @@ class ZstdDecompressionWriter(BinaryIO):
376376
class ZstdDecompressor(object):
377377
def __init__(
378378
self,
379-
dict_data: ZstdCompressionDict = ...,
379+
dict_data: Optional[ZstdCompressionDict] = ...,
380380
max_window_size: int = ...,
381381
format: int = ...,
382382
): ...

0 commit comments

Comments
 (0)