Skip to content

Commit 53c440d

Browse files
authored
Merge pull request PaddlePaddle#77 from jacobbieker/master
Add autencoding of numpy arrays
2 parents e5c895a + 9033c06 commit 53c440d

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

webdataset/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
ignore_and_stop,
3030
warn_and_stop,
3131
)
32-
from .writer import ShardWriter, TarWriter, torch_dumps
32+
from .writer import ShardWriter, TarWriter, torch_dumps, numpy_dumps
3333
from .autodecode import (
3434
Continue,
3535
handle_extension,

webdataset/writer.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,19 @@ def torch_dumps(data):
7878
return stream.getvalue()
7979

8080

81+
def numpy_dumps(data):
82+
"""Dump data into a bytestring using numpy npy format
83+
84+
:param data: data to be dumped
85+
"""
86+
import io
87+
import numpy.lib.format
88+
89+
stream = io.BytesIO()
90+
numpy.lib.format.write_array(stream, data)
91+
return stream.getvalue()
92+
93+
8194
def make_handlers():
8295
"""Create a list of handlers for encoding data."""
8396
handlers = {}
@@ -99,6 +112,8 @@ def f(extension_):
99112
handlers[extension] = pickle.dumps
100113
for extension in ["pth"]:
101114
handlers[extension] = torch_dumps
115+
for extension in ["npy"]:
116+
handlers[extension] = numpy_dumps
102117
for extension in ["json", "jsn"]:
103118
handlers[extension] = lambda x: json.dumps(x).encode("utf-8")
104119
for extension in ["ten", "tb"]:

0 commit comments

Comments
 (0)