Skip to content

Initial hack to get dask distributed working #1083

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@

import numpy as np

try:
import distributed.protocol
HAS_DISTRIBUTED = True
except ImportError:
HAS_DISTRIBUTED = False

from .. import Variable
from ..conventions import pop_to, cf_encoder
from ..core import indexing
Expand Down Expand Up @@ -71,6 +77,29 @@ def __getitem__(self, key):
return data


def serialize_netcdf4_array_wrapper(array):
    """Serialize a NetCDF4ArrayWrapper with dask.distributed's protocol.

    The inner array is serialized first; its 'type' entry (if any) and the
    wrapper's ``is_remote`` flag are stashed in the header under namespaced
    keys so ``deserialize_netcdf4_array_wrapper`` can rebuild the wrapper.
    """
    inner_header, frames = distributed.protocol.serialize(array.array)
    inner_type = inner_header.get('type')
    inner_header['__netcdf4_array_wrapper__array_type'] = inner_type
    inner_header['__netcdf4_array_wrapper__is_remote'] = array.is_remote
    return inner_header, frames


def deserialize_netcdf4_array_wrapper(header, frames):
    """Inverse of ``serialize_netcdf4_array_wrapper``.

    Pops the namespaced keys back out of the header, restores the inner
    array's 'type' entry, deserializes the inner array, and re-wraps it.
    """
    remote_flag = header.pop('__netcdf4_array_wrapper__is_remote')
    stashed_type = header.pop('__netcdf4_array_wrapper__array_type', None)
    if stashed_type is not None:
        header['type'] = stashed_type
    inner = distributed.protocol.deserialize(header, frames)
    return NetCDF4ArrayWrapper(inner, remote_flag)


# Register the custom (de)serializers with dask.distributed so that
# NetCDF4ArrayWrapper instances can travel between scheduler and workers.
# Only runs when the optional `distributed` package imported successfully.
if HAS_DISTRIBUTED:
    distributed.protocol.register_serialization(
        NetCDF4ArrayWrapper,
        serialize_netcdf4_array_wrapper,
        deserialize_netcdf4_array_wrapper)


def _nc4_values_and_dtype(var):
if var.dtype.kind == 'U':
# this entire clause should not be necessary with netCDF4>=1.0.9
Expand Down
27 changes: 27 additions & 0 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from collections import defaultdict
from pandas.tslib import OutOfBoundsDatetime

try:
import distributed.protocol
HAS_DISTRIBUTED = True
except ImportError:
HAS_DISTRIBUTED = False

from .core import indexing, ops, utils
from .core.formatting import format_timestamp, first_n_items, last_item
from .core.variable import as_variable, Variable
Expand Down Expand Up @@ -470,6 +476,27 @@ def __getitem__(self, key):
return values


def serialize_char_to_string_array(array):
    """Serialize a CharToStringArray with dask.distributed's protocol.

    The inner array's 'type' header entry (if present) is stashed under a
    namespaced key so the deserializer can restore it.
    """
    inner_header, frames = distributed.protocol.serialize(array.array)
    inner_type = inner_header.get('type')
    inner_header['__char_to_string_array__array_type'] = inner_type
    return inner_header, frames


def deserialize_char_to_string_array(header, frames):
    """Inverse of ``serialize_char_to_string_array``.

    Restores the inner array's 'type' header entry, deserializes it and
    re-wraps the result in a CharToStringArray.
    """
    stashed_type = header.pop('__char_to_string_array__array_type', None)
    if stashed_type is not None:
        header['type'] = stashed_type
    inner = distributed.protocol.deserialize(header, frames)
    return CharToStringArray(inner)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any objection if I change this (and the other similar functions) to use the sub-types list solution?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be fine, but these classes (other than LazilyIndexedArray) are not going to be layered multiple times. So I wouldn't worry too much about that until we settle on a general solution.



# Register the custom (de)serializers with dask.distributed so that
# CharToStringArray instances can travel between scheduler and workers.
# Only runs when the optional `distributed` package imported successfully.
if HAS_DISTRIBUTED:
    distributed.protocol.register_serialization(
        CharToStringArray,
        serialize_char_to_string_array,
        deserialize_char_to_string_array)


class NativeEndiannessArray(utils.NDArrayMixin):
"""Decode arrays on the fly from non-native to native endianness

Expand Down
31 changes: 31 additions & 0 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
import numpy as np
import pandas as pd

try:
import distributed.protocol
HAS_DISTRIBUTED = True
except ImportError:
HAS_DISTRIBUTED = False

from . import utils
from .pycompat import iteritems, range, dask_array_type, suppress
from .utils import is_full_slice, is_dict_like
Expand Down Expand Up @@ -400,6 +406,31 @@ def __repr__(self):
(type(self).__name__, self.array, self.key))


def serialize_lazily_indexed_array(array):
    """Serialize a LazilyIndexedArray with dask.distributed's protocol.

    Because LazilyIndexedArray can wrap another serializable wrapper, the
    inner 'type' is pushed onto a 'sub-type' list in the header (rather
    than a single stashed key), so nesting is preserved.  The lazy-indexing
    key is pickled and appended as the final frame.
    """
    header, frames = distributed.protocol.serialize(array.array)
    # Move the inner 'type' onto the sub-type stack (copy: never mutate a
    # list that the inner serializer may still hold a reference to).
    header.setdefault('sub-type', [])
    header['sub-type'] = list(header['sub-type']) + [header.pop('type')]
    frames.append(distributed.protocol.pickle.dumps(array.key))
    return header, frames


def deserialize_lazily_indexed_array(header, frames):
    """Inverse of ``serialize_lazily_indexed_array``.

    Unpickles the lazy-indexing key from the last frame, pops the innermost
    entry off the 'sub-type' stack back into 'type', deserializes the inner
    array from the remaining frames, and re-wraps it.
    """
    key = distributed.protocol.pickle.loads(frames[-1])
    sub_types = header.get('sub-type')
    if sub_types:
        header['type'] = sub_types[-1]
        header['sub-type'] = sub_types[:-1]
    inner = distributed.protocol.deserialize(header, frames[:-1])
    return LazilyIndexedArray(inner, key)


# Register the custom (de)serializers with dask.distributed so that
# LazilyIndexedArray instances can travel between scheduler and workers.
# Only runs when the optional `distributed` package imported successfully.
if HAS_DISTRIBUTED:
    distributed.protocol.register_serialization(
        LazilyIndexedArray,
        serialize_lazily_indexed_array,
        deserialize_lazily_indexed_array)


def orthogonally_indexable(array):
if isinstance(array, np.ndarray):
return NumpyIndexingAdapter(array)
Expand Down
74 changes: 74 additions & 0 deletions xarray/test/test_distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import numpy as np
import pytest
import xarray as xr

distributed = pytest.importorskip('distributed')
da = pytest.importorskip('dask.array')
from distributed.protocol import serialize, deserialize
from distributed.utils_test import cluster, loop, gen_cluster

from xarray.core.indexing import LazilyIndexedArray
from xarray.backends.netCDF4_ import NetCDF4ArrayWrapper
from xarray.conventions import CharToStringArray

from xarray.test.test_backends import create_tmp_file


def test_serialize_deserialize_lazily_indexed_array():
    """Round-trip a sliced LazilyIndexedArray through the protocol."""
    source = LazilyIndexedArray(np.arange(10))[:5]
    roundtripped = deserialize(*serialize(source))
    assert type(roundtripped) is type(source)
    assert (roundtripped.array == source.array).all()
    assert roundtripped.key == source.key


def test_serialize_deserialize_netcdf4_array_wrapper():
    """Round-trip a NetCDF4ArrayWrapper, checking the is_remote flag too."""
    source = NetCDF4ArrayWrapper(np.arange(10), is_remote=False)
    roundtripped = deserialize(*serialize(source))
    assert type(roundtripped) is type(source)
    assert (roundtripped.array == source.array).all()
    assert roundtripped.is_remote == source.is_remote


def test_serialize_deserialize_char_to_string_array():
    """Round-trip a CharToStringArray wrapping a bytes ('S1') array."""
    source = CharToStringArray(np.array(['a', 'b', 'c'], dtype='S1'))
    roundtripped = deserialize(*serialize(source))
    assert type(roundtripped) is type(source)
    assert (roundtripped.array == source.array).all()


def test_serialize_deserialize_nested_arrays():
    """Round-trip a LazilyIndexedArray nested around a NetCDF4ArrayWrapper."""
    source = LazilyIndexedArray(NetCDF4ArrayWrapper(np.arange(5)))
    roundtripped = deserialize(*serialize(source))
    assert (roundtripped.array.array == source.array.array).all()


def test_dask_distributed_integration_test(loop):
    """Write a dataset to netCDF, reopen it chunked, and load it through a
    local distributed cluster, checking the round-trip is lossless."""
    with cluster() as (scheduler, _):
        address = ('127.0.0.1', scheduler['port'])
        with distributed.Client(address, loop=loop) as client:
            expected = xr.Dataset({'foo': ('x', [10, 20, 30, 40, 50])})
            with create_tmp_file() as path:
                expected.to_netcdf(path, engine='netcdf4')
                # TODO: should be able to serialize locks?
                # TODO: should be able to serialize array types from
                # xarray.conventions
                actual = xr.open_dataset(path, chunks=3, lock=False)
                assert isinstance(actual.foo.data, da.Array)
                actual.load()
                assert expected.identical(actual)


@gen_cluster(client=True)
def test_dask_distributed_integration_test_fast(c, s, a, b):
    """Faster variant of the integration test using gen_cluster.

    Writes a dataset to netCDF, reopens it chunked (and undecoded, since
    conventions-layer wrappers are not yet serializable), computes the data
    on the cluster, and checks the values round-trip exactly.
    """
    original = xr.Dataset({'foo': ('x', [10, 20, 30, 40, 50])})
    with create_tmp_file() as filename:
        original.to_netcdf(filename, engine='netcdf4')
        # TODO: should be able to serialize locks?
        # TODO: should be able to serialize array types from
        # xarray.conventions
        restored = xr.open_dataset(filename, chunks=3, lock=False,
                                   decode_cf=False)
        # Removed a leftover debug print of the dask graph; tests should
        # not write to stdout on the success path.
        future = c.compute(restored.foo.data)
        values = yield future._result()
        computed = xr.Dataset({'foo': ('x', values)})
        assert computed.identical(original)