From fa8de2c5ded8318ba4c18bc99265a67e123119d2 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 6 Nov 2016 12:57:51 -0800 Subject: [PATCH 1/2] Initial hack to get dask distributed working --- xarray/backends/netCDF4_.py | 29 +++++++++++++++ xarray/core/indexing.py | 31 ++++++++++++++++ xarray/test/test_distributed.py | 66 +++++++++++++++++++++++++++++++++ 3 files changed, 126 insertions(+) create mode 100644 xarray/test/test_distributed.py diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 1bf38e4325f..b6b85251a66 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -3,6 +3,12 @@ import numpy as np +try: + import distributed.protocol + HAS_DISTRIBUTED = True +except ImportError: + HAS_DISTRIBUTED = False + from .. import Variable from ..conventions import pop_to, cf_encoder from ..core import indexing @@ -71,6 +77,29 @@ def __getitem__(self, key): return data +def serialize_netcdf4_array_wrapper(array): + header, frames = distributed.protocol.serialize(array.array) + header['__netcdf4_array_wrapper__array_type'] = header.get('type') + header['__netcdf4_array_wrapper__is_remote'] = array.is_remote + return header, frames + + +def deserialize_netcdf4_array_wrapper(header, frames): + is_remote = header.pop('__netcdf4_array_wrapper__is_remote') + type_ = header.pop('__netcdf4_array_wrapper__array_type', None) + if type_ is not None: + header['type'] = type_ + array = distributed.protocol.deserialize(header, frames) + return NetCDF4ArrayWrapper(array, is_remote) + + +if HAS_DISTRIBUTED: + distributed.protocol.register_serialization( + NetCDF4ArrayWrapper, + serialize_netcdf4_array_wrapper, + deserialize_netcdf4_array_wrapper) + + def _nc4_values_and_dtype(var): if var.dtype.kind == 'U': # this entire clause should not be necessary with netCDF4>=1.0.9 diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index bfdc6d305ad..2129c2ff3bb 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -3,6 +3,12 @@ import numpy as np import pandas as pd +try: + import distributed.protocol + HAS_DISTRIBUTED = True +except ImportError: + HAS_DISTRIBUTED = False + from . import utils from .pycompat import iteritems, range, dask_array_type, suppress from .utils import is_full_slice, is_dict_like @@ -400,6 +406,31 @@ def __repr__(self): (type(self).__name__, self.array, self.key)) +def serialize_lazily_indexed_array(array): + header, frames = distributed.protocol.serialize(array.array) + if 'sub-type' not in header: + header['sub-type'] = [] + header['sub-type'] = list(header['sub-type']) + [header.pop('type')] + frames.append(distributed.protocol.pickle.dumps(array.key)) + return header, frames + + +def deserialize_lazily_indexed_array(header, frames): + key = distributed.protocol.pickle.loads(frames[-1]) + if header.get('sub-type'): + header['type'] = header['sub-type'][-1] + header['sub-type'] = header['sub-type'][:-1] + array = distributed.protocol.deserialize(header, frames[:-1]) + return LazilyIndexedArray(array, key) + + +if HAS_DISTRIBUTED: + distributed.protocol.register_serialization( + LazilyIndexedArray, + serialize_lazily_indexed_array, + deserialize_lazily_indexed_array) + + def orthogonally_indexable(array): if isinstance(array, np.ndarray): return NumpyIndexingAdapter(array) diff --git a/xarray/test/test_distributed.py b/xarray/test/test_distributed.py new file mode 100644 index 00000000000..2311ba77f0b --- /dev/null +++ b/xarray/test/test_distributed.py @@ -0,0 +1,66 @@ +import numpy as np +import pytest +import xarray as xr + +distributed = pytest.importorskip('distributed') +da = pytest.importorskip('dask.array') +from distributed.protocol import serialize, deserialize +from distributed.utils_test import cluster, loop, gen_cluster + +from xarray.core.indexing import LazilyIndexedArray +from xarray.backends.netCDF4_ import NetCDF4ArrayWrapper + +from xarray.test.test_backends import create_tmp_file + + +def test_serialize_deserialize_lazily_indexed_array(): + original = LazilyIndexedArray(np.arange(10))[:5] + restored = deserialize(*serialize(original)) + assert type(restored) is type(original) + assert (restored.array == original.array).all() + assert restored.key == original.key + + +def test_serialize_deserialize_netcdf4_array_wrapper(): + original = NetCDF4ArrayWrapper(np.arange(10), is_remote=False) + restored = deserialize(*serialize(original)) + assert type(restored) is type(original) + assert (restored.array == original.array).all() + assert restored.is_remote == original.is_remote + + +def test_serialize_deserialize_nested_arrays(): + original = LazilyIndexedArray(NetCDF4ArrayWrapper(np.arange(5))) + restored = deserialize(*serialize(original)) + assert (restored.array.array == original.array.array).all() + + +def test_dask_distributed_integration_test(loop): + with cluster() as (s, _): + with distributed.Client(('127.0.0.1', s['port']), loop=loop) as client: + original = xr.Dataset({'foo': ('x', [10, 20, 30, 40, 50])}) + with create_tmp_file() as filename: + original.to_netcdf(filename, engine='netcdf4') + # TODO: should be able to serialize locks? + # TODO: should be able to serialize array types from + # xarray.conventions + restored = xr.open_dataset(filename, chunks=3, lock=False) + assert isinstance(restored.foo.data, da.Array) + restored.load() + assert original.identical(restored) + + +@gen_cluster(client=True) +def test_dask_distributed_integration_test_fast(c, s, a, b): + original = xr.Dataset({'foo': ('x', [10, 20, 30, 40, 50])}) + with create_tmp_file() as filename: + original.to_netcdf(filename, engine='netcdf4') + # TODO: should be able to serialize locks? + # TODO: should be able to serialize array types from + # xarray.conventions + restored = xr.open_dataset(filename, chunks=3, lock=False, decode_cf=False) + print(restored.foo.data.dask) + y = c.compute(restored.foo.data) + y = yield y._result() + computed = xr.Dataset({'foo': ('x', y)}) + assert computed.identical(original) From 75385ace9e0625cf7b20f1d171fa3cb38e858d43 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 6 Nov 2016 21:28:25 -0800 Subject: [PATCH 2/2] Add char_to_string_array serialization --- xarray/conventions.py | 27 +++++++++++++++++++++++++++ xarray/test/test_distributed.py | 8 ++++++++ 2 files changed, 35 insertions(+) diff --git a/xarray/conventions.py b/xarray/conventions.py index b26e42c9fd8..f7be6bd6409 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -8,6 +8,12 @@ from collections import defaultdict from pandas.tslib import OutOfBoundsDatetime +try: + import distributed.protocol + HAS_DISTRIBUTED = True +except ImportError: + HAS_DISTRIBUTED = False + from .core import indexing, ops, utils from .core.formatting import format_timestamp, first_n_items, last_item from .core.variable import as_variable, Variable @@ -470,6 +476,27 @@ def __getitem__(self, key): return values +def serialize_char_to_string_array(array): + header, frames = distributed.protocol.serialize(array.array) + header['__char_to_string_array__array_type'] = header.get('type') + return header, frames + + +def deserialize_char_to_string_array(header, frames): + type_ = header.pop('__char_to_string_array__array_type', None) + if type_ is not None: + header['type'] = type_ + array = distributed.protocol.deserialize(header, frames) + return CharToStringArray(array) + + +if HAS_DISTRIBUTED: + distributed.protocol.register_serialization( + CharToStringArray, + serialize_char_to_string_array, + deserialize_char_to_string_array) + + class NativeEndiannessArray(utils.NDArrayMixin): """Decode arrays on the fly from non-native to native endianness diff --git a/xarray/test/test_distributed.py b/xarray/test/test_distributed.py index 2311ba77f0b..fd86acae142 100644 --- a/xarray/test/test_distributed.py +++ b/xarray/test/test_distributed.py @@ -9,6 +9,7 @@ from xarray.core.indexing import LazilyIndexedArray from xarray.backends.netCDF4_ import NetCDF4ArrayWrapper +from xarray.conventions import CharToStringArray from xarray.test.test_backends import create_tmp_file @@ -29,6 +30,13 @@ def test_serialize_deserialize_netcdf4_array_wrapper(): assert restored.is_remote == original.is_remote +def test_serialize_deserialize_char_to_string_array(): + original = CharToStringArray(np.array(['a', 'b', 'c'], dtype='S1')) + restored = deserialize(*serialize(original)) + assert type(restored) is type(original) + assert (restored.array == original.array).all() + + def test_serialize_deserialize_nested_arrays(): original = LazilyIndexedArray(NetCDF4ArrayWrapper(np.arange(5))) restored = deserialize(*serialize(original))