Skip to content

Commit 8b6a33a

Browse files
authored
Merge pull request #5 from jreback/pickle
ENH: add Pickle/MsgPack codec with support for object ndarrays
2 parents 31fd4e0 + 214ef94 commit 8b6a33a

File tree

8 files changed

+261
-1
lines changed

8 files changed

+261
-1
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ __pycache__/
66
# C extensions
77
*.so
88

9+
# editor
10+
*~
11+
912
# Distribution / packaging
1013
.Python
1114
env/
@@ -44,6 +47,7 @@ nosetests.xml
4447
coverage.xml
4548
*,cover
4649
.hypothesis/
50+
cover/
4751

4852
# Translations
4953
*.mo

numcodecs/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@
6161
from numcodecs.categorize import Categorize
6262
register_codec(Categorize)
6363

64+
from numcodecs.pickles import Pickle
65+
register_codec(Pickle)
66+
67+
try:
68+
from numcodecs.msgpacks import MsgPack
69+
register_codec(MsgPack)
70+
except ImportError: # pragma: no cover
71+
pass
6472

6573
from numcodecs.checksum32 import CRC32, Adler32
6674
register_codec(CRC32)

numcodecs/msgpacks.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import absolute_import, print_function, division
3+
4+
5+
import numpy as np
6+
7+
8+
from numcodecs.abc import Codec
9+
from numcodecs.compat import ndarray_from_buffer, buffer_copy
10+
import msgpack
11+
12+
13+
class MsgPack(Codec):
14+
"""Codec to encode data as msgpacked bytes. Useful for encoding python
15+
strings
16+
17+
Raises
18+
------
19+
encoding a non-object dtyped ndarray will raise ValueError
20+
21+
Examples
22+
--------
23+
>>> import numcodecs as codecs
24+
>>> import numpy as np
25+
>>> x = np.array(['foo', 'bar', 'baz'], dtype='object')
26+
>>> f = codecs.MsgPack()
27+
>>> f.decode(f.encode(x))
28+
array(['foo', 'bar', 'baz'], dtype=object)
29+
30+
""" # flake8: noqa
31+
32+
codec_id = 'msgpack'
33+
34+
def encode(self, buf):
35+
if hasattr(buf, 'dtype') and buf.dtype != 'object':
36+
raise ValueError("cannot encode non-object ndarrays, %s "
37+
"dtype was passed" % buf.dtype)
38+
return msgpack.packb(buf.tolist(), encoding='utf-8')
39+
40+
def decode(self, buf, out=None):
41+
dec = np.array(msgpack.unpackb(buf, encoding='utf-8'), dtype='object')
42+
if out is not None:
43+
np.copyto(out, dec)
44+
return out
45+
else:
46+
return dec
47+
48+
def get_config(self):
49+
return dict(id=self.codec_id)
50+
51+
def __repr__(self):
52+
return 'MsgPack()'

numcodecs/pickles.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import absolute_import, print_function, division
3+
4+
5+
import numpy as np
6+
7+
8+
from numcodecs.abc import Codec
9+
from numcodecs.compat import ndarray_from_buffer, buffer_copy
10+
try:
11+
import cPickle as pickle
12+
except ImportError:
13+
import pickle
14+
15+
16+
class Pickle(Codec):
17+
"""Codec to encode data as as pickled bytes. Useful for encoding python
18+
strings.
19+
20+
Parameters
21+
----------
22+
protocol : int, defaults to pickle.HIGHEST_PROTOCOL
23+
the protocol used to pickle data
24+
25+
Raises
26+
------
27+
encoding a non-object dtyped ndarray will raise ValueError
28+
29+
Examples
30+
--------
31+
>>> import numcodecs as codecs
32+
>>> import numpy as np
33+
>>> x = np.array(['foo', 'bar', 'baz'], dtype='object')
34+
>>> f = codecs.Pickle()
35+
>>> f.decode(f.encode(x))
36+
array(['foo', 'bar', 'baz'], dtype=object)
37+
38+
""" # flake8: noqa
39+
40+
codec_id = 'pickle'
41+
42+
def __init__(self, protocol=pickle.HIGHEST_PROTOCOL):
43+
self.protocol = protocol
44+
45+
def encode(self, buf):
46+
if hasattr(buf, 'dtype') and buf.dtype != 'object':
47+
raise ValueError("cannot encode non-object ndarrays, %s "
48+
"dtype was passed" % buf.dtype)
49+
return pickle.dumps(buf, protocol=self.protocol)
50+
51+
def decode(self, buf, out=None):
52+
dec = pickle.loads(buf)
53+
if out is not None:
54+
np.copyto(out, dec)
55+
return out
56+
else:
57+
return dec
58+
59+
def get_config(self):
60+
return dict(id=self.codec_id,
61+
protocol=self.protocol)
62+
63+
def __repr__(self):
64+
return 'Pickle(protocol=%s)' % self.protocol

numcodecs/tests/common.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
import numpy as np
8-
from nose.tools import eq_ as eq
8+
from nose.tools import eq_ as eq, assert_true
99
from numpy.testing import assert_array_almost_equal
1010

1111

@@ -91,6 +91,37 @@ def compare(res):
9191
compare(out)
9292

9393

94+
def check_encode_decode_objects(arr, codec):
95+
96+
# this is a more specific test that check_encode_decode
97+
# as these require actual objects (and not bytes only)
98+
99+
def compare(res, arr=arr):
100+
101+
assert_true(isinstance(res, np.ndarray))
102+
assert_true(res.shape == arr.shape)
103+
assert_true(res.dtype == 'object')
104+
105+
# numpy asserts don't compare object arrays
106+
# properly; assert that we have the same nans
107+
# and values
108+
arr = arr.ravel().tolist()
109+
res = res.ravel().tolist()
110+
for a, r in zip(arr, res):
111+
if a != a:
112+
assert_true(r != r)
113+
else:
114+
assert_true(a == r)
115+
116+
enc = codec.encode(arr)
117+
dec = codec.decode(enc)
118+
compare(dec)
119+
120+
out = np.empty_like(arr)
121+
codec.decode(enc, out=out)
122+
compare(out)
123+
124+
94125
def check_config(codec):
95126
config = codec.get_config()
96127
# round-trip through JSON to check serialization

numcodecs/tests/test_msgpacks.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import absolute_import, print_function, division
3+
4+
5+
import nose
6+
import numpy as np
7+
from numpy.testing import assert_raises
8+
9+
try:
10+
from numcodecs.msgpacks import MsgPack
11+
except ImportError:
12+
raise nose.SkipTest("no msgpack installed")
13+
14+
from numcodecs.tests.common import (check_config, check_repr,
15+
check_encode_decode_objects)
16+
17+
18+
# object array with strings
19+
# object array with mix strings / nans
20+
# object array with mix of string, int, float
21+
arrays = [
22+
np.array(['foo', 'bar', 'baz'] * 300, dtype=object),
23+
np.array([['foo', 'bar', np.nan]] * 300, dtype=object),
24+
np.array(['foo', 1.0, 2] * 300, dtype=object),
25+
]
26+
27+
# non-object ndarrays
28+
arrays_incompat = [
29+
np.arange(1000, dtype='i4'),
30+
np.array(['foo', 'bar', 'baz'] * 300),
31+
]
32+
33+
34+
def test_encode_errors():
35+
for arr in arrays_incompat:
36+
codec = MsgPack()
37+
assert_raises(ValueError, codec.encode, arr)
38+
39+
40+
def test_encode_decode():
41+
for arr in arrays:
42+
codec = MsgPack()
43+
check_encode_decode_objects(arr, codec)
44+
45+
46+
def test_config():
47+
codec = MsgPack()
48+
check_config(codec)
49+
50+
51+
def test_repr():
52+
check_repr("MsgPack()")

numcodecs/tests/test_pickle.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import absolute_import, print_function, division
3+
4+
5+
import numpy as np
6+
from numpy.testing import assert_raises
7+
8+
9+
from numcodecs.pickles import Pickle
10+
from numcodecs.tests.common import (check_config, check_repr,
11+
check_encode_decode_objects)
12+
13+
14+
# object array with strings
15+
# object array with mix strings / nans
16+
# object array with mix of string, int, float
17+
arrays = [
18+
np.array(['foo', 'bar', 'baz'] * 300, dtype=object),
19+
np.array([['foo', 'bar', np.nan]] * 300, dtype=object),
20+
np.array(['foo', 1.0, 2] * 300, dtype=object),
21+
]
22+
23+
# non-object ndarrays
24+
arrays_incompat = [
25+
np.arange(1000, dtype='i4'),
26+
np.array(['foo', 'bar', 'baz'] * 300),
27+
]
28+
29+
30+
def test_encode_errors():
31+
for arr in arrays_incompat:
32+
codec = Pickle()
33+
assert_raises(ValueError, codec.encode, arr)
34+
35+
36+
def test_encode_decode():
37+
for arr in arrays:
38+
codec = Pickle()
39+
check_encode_decode_objects(arr, codec)
40+
41+
42+
def test_config():
43+
codec = Pickle(protocol=-1)
44+
check_config(codec)
45+
46+
47+
def test_repr():
48+
check_repr("Pickle(protocol=-1)")

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
numpy
2+
msgpack-python

0 commit comments

Comments
 (0)