Skip to content

Commit f42ddfd

Browse files
authored
Merge pull request #2 from shoyer/indexing_broadcasting
Minor cleanup for Dataset.__getitem__
2 parents 7dd171d + 31401d4 commit f42ddfd

File tree

7 files changed

+428
-141
lines changed

7 files changed

+428
-141
lines changed

xarray/core/computation.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,23 @@ def build_output_coords(
136136
signature, # type: _UFuncSignature
137137
exclude_dims=frozenset(), # type: set
138138
):
139+
"""Build output coordinates for an operation.
140+
141+
Parameters
142+
----------
143+
args : list
144+
List of raw operation arguments. Any valid types for xarray operations
145+
are OK, e.g., scalars, Variable, DataArray, Dataset.
146+
signature : _UfuncSignature
147+
Core dimensions signature for the operation.
148+
exclude_dims : optional set
149+
Dimensions excluded from the operation. Coordinates along these
150+
dimensions are dropped.
151+
152+
Returns
153+
-------
154+
OrderedDict of Variable objects with merged coordinates.
155+
"""
139156
# type: (...) -> List[OrderedDict[Any, Variable]]
140157
input_coords = _get_coord_variables(args)
141158

xarray/core/dataset.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,23 +1135,29 @@ def isel(self, drop=False, **indexers):
11351135
raise ValueError("dimensions %r do not exist" % invalid)
11361136

11371137
# extract new coordinates from indexers
1138-
variables = merge_variables([v._coords for _, v in
1139-
iteritems(indexers)
1140-
if isinstance(v, DataArray)],
1141-
compat='identical')
1142-
1138+
# we don't need to call align() explicitly, because merge_variables
1139+
# already checks for exact alignment between dimension coordinates
1140+
variables = merge_variables([v._coords for v in indexers.values()
1141+
if isinstance(v, DataArray)])
11431142
coord_names = set(self._coord_names) | set(variables)
11441143

1145-
# a tuple, e.g. (('x', ), [0, 1]), is converted to Variable
11461144
# all indexers should be int, slice, np.ndarrays, or Variable
1147-
indexers = [(k, Variable(dims=v[0], data=v[1]) if isinstance(v, tuple)
1148-
else v if isinstance(v, integer_types + (slice, Variable))
1149-
else v.variable if isinstance(v, DataArray)
1150-
else np.asarray(v))
1151-
for k, v in iteritems(indexers)]
1145+
indexers_list = []
1146+
for k, v in iteritems(indexers):
1147+
if isinstance(v, integer_types + (slice, Variable)):
1148+
pass
1149+
elif isinstance(v, DataArray):
1150+
v = v.variable
1151+
elif isinstance(v, tuple):
1152+
v = as_variable(v)
1153+
elif isinstance(v, Dataset):
1154+
raise TypeError('cannot use a Dataset as an indexer')
1155+
else:
1156+
v = np.asarray(v)
1157+
indexers_list.append((k, v))
11521158

11531159
for name, var in iteritems(self._variables):
1154-
var_indexers = dict((k, v) for k, v in indexers if k in var.dims)
1160+
var_indexers = {k: v for k, v in indexers_list if k in var.dims}
11551161
new_var = var.isel(**var_indexers)
11561162
if not (drop and name in var_indexers):
11571163
variables[name] = new_var

xarray/core/indexing.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pandas as pd
88

99
from . import utils
10+
from .npcompat import moveaxis
1011
from .pycompat import (iteritems, range, integer_types, dask_array_type,
1112
suppress)
1213
from .utils import is_dict_like
@@ -314,6 +315,10 @@ class VectorizedIndexer(IndexerTuple):
314315
""" Tuple for vectorized indexing """
315316

316317

318+
class PointwiseIndexer(IndexerTuple):
    """Indexing key for pointwise selection through dask.array's vindex.

    Used in place of VectorizedIndexer when the indexed data is a dask
    array, so that the dask indexing adapter can recognise the key by its
    type and dispatch to its pointwise (vindex-based) code path.
    """
320+
321+
317322
class LazilyIndexedArray(utils.NDArrayMixin):
318323
"""Wrap an array that handles orthogonal indexing to make indexing lazy
319324
"""
@@ -477,10 +482,11 @@ def __init__(self, array):
477482
self.array = array
478483

479484
def __getitem__(self, key):
480-
if isinstance(key, VectorizedIndexer):
481-
# TODO should support vindex
482-
raise IndexError(
483-
'dask does not support vectorized indexing : {}'.format(key))
485+
# should always get PointwiseIndexer instead
486+
assert not isinstance(key, VectorizedIndexer)
487+
488+
if isinstance(key, PointwiseIndexer):
489+
return self._getitem_pointwise(key)
484490

485491
try:
486492
key = to_tuple(key)
@@ -492,6 +498,30 @@ def __getitem__(self, key):
492498
value = value[(slice(None),) * axis + (subkey,)]
493499
return value
494500

501+
def _getitem_pointwise(self, key):
    """Apply a pointwise indexing key to the wrapped dask array.

    Parameters
    ----------
    key : PointwiseIndexer
        Sequence whose elements are either slices or array indexers
        (objects exposing ``shape`` and ``ravel()``). At least one
        element must be an array indexer.
        NOTE(review): all array indexers are presumably required to share
        one shape (only the first one's shape is used) -- confirm against
        the caller that builds the key.

    Returns
    -------
    The indexed array, with the (possibly multi-dimensional) shape of the
    array indexers restored at the position of the first array indexer.
    """
    # shape and position of the first non-slice indexer
    first_position, first_shape = next(
        (position, indexer.shape) for position, indexer in enumerate(key)
        if not isinstance(indexer, slice))

    # dask's indexing only handles 1d arrays, so flatten every array indexer
    raveled_key = tuple(
        indexer if isinstance(indexer, slice) else indexer.ravel()
        for indexer in key)

    num_array_indexers = sum(
        1 for indexer in key if not isinstance(indexer, slice))

    if num_array_indexers != 1:
        selected = self.array.vindex[raveled_key]
        # vindex always moves slices to the end, so the flattened pointwise
        # axis comes first: expand it back to the indexers' shape
        expanded = selected.reshape(first_shape + selected.shape[1:])
        # reorder dimensions to match order of appearance
        leading_axes = np.arange(0, len(first_shape))
        return moveaxis(expanded, leading_axes, leading_axes + first_position)

    # vindex requires more than one non-slice :(
    # but we can use normal indexing instead
    selected = self.array[raveled_key]
    leading = selected.shape[:first_position]
    trailing = selected.shape[first_position + 1:]
    return selected.reshape(leading + first_shape + trailing)
524+
495525
def __setitem__(self, key, value):
496526
raise TypeError("this variable's data is stored in a dask array, "
497527
'which does not support item assignment. To '

xarray/core/npcompat.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
from __future__ import absolute_import
22
from __future__ import division
33
from __future__ import print_function
4+
5+
import operator
6+
47
import numpy as np
58

69
try:
7-
from numpy import broadcast_to, stack, nanprod, nancumsum, nancumprod
10+
from numpy import (broadcast_to, stack, nanprod, nancumsum, nancumprod,
11+
moveaxis)
812
except ImportError: # pragma: no cover
913
# Code copied from newer versions of NumPy (v1.10 to v1.12).
1014
# Used under the terms of NumPy's license, see licenses/NUMPY_LICENSE.
@@ -371,3 +375,130 @@ def nancumprod(a, axis=None, dtype=None, out=None):
371375
"""
372376
a, mask = _replace_nan(a, 1)
373377
return np.cumprod(a, axis=axis, dtype=dtype, out=out)
378+
379+
380+
def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
    """
    Normalizes an axis argument into a tuple of non-negative integer axes.

    This handles shorthands such as ``1`` and converts them to ``(1,)``,
    as well as wrapping negative indices so that every returned axis lies
    in ``range(ndim)``.

    By default, this forbids axes from being specified multiple times.

    Used internally by multi-axis-checking logic.

    .. versionadded:: 1.13.0

    Parameters
    ----------
    axis : int, iterable of int
        The un-normalized index or indices of the axis.
    ndim : int
        The number of dimensions of the array that `axis` should be normalized
        against.
    argname : str, optional
        A prefix to put before the error message, typically the name of the
        argument.
    allow_duplicate : bool, optional
        If False, the default, disallow an axis from being specified twice.

    Returns
    -------
    normalized_axes : tuple of int
        The normalized axis index, such that `0 <= normalized_axis < ndim`

    Raises
    ------
    ValueError
        If any axis provided is out of range, or if an axis is repeated
        while `allow_duplicate` is False.
    TypeError
        If an axis is not an integer.
    """
    try:
        axis = [operator.index(axis)]
    except TypeError:
        axis = tuple(axis)

    # The per-axis normalization is done inline: NumPy's own implementation
    # delegates to ``normalize_axis_index``, a private helper that is not
    # available in this compatibility module.
    normalized = []
    for ax in axis:
        ax = operator.index(ax)
        if not -ndim <= ax < ndim:
            message = ('axis {} is out of bounds for array of dimension {}'
                       .format(ax, ndim))
            if argname:
                message = '{}: {}'.format(argname, message)
            raise ValueError(message)
        normalized.append(ax + ndim if ax < 0 else ax)
    axis = tuple(normalized)

    if not allow_duplicate and len(set(axis)) != len(axis):
        if argname:
            raise ValueError('repeated axis in `{}` argument'.format(argname))
        else:
            raise ValueError('repeated axis')
    return axis
434+
435+
436+
def moveaxis(a, source, destination):
    """
    Relocate the given axes of an array, keeping the remaining axes in
    their original relative order.

    .. versionadded:: 1.11.0

    Parameters
    ----------
    a : np.ndarray
        Array whose axes are to be rearranged. Any object providing a
        ``transpose`` method is used as-is; anything else is first converted
        with ``np.asarray``.
    source : int or sequence of int
        Current positions of the axes to move. Must not contain duplicates.
    destination : int or sequence of int
        Target position for each axis listed in `source`. Must not contain
        duplicates.

    Returns
    -------
    result : np.ndarray
        The result of transposing `a` so that every source axis sits at its
        destination.

    Raises
    ------
    ValueError
        If `source` and `destination` do not have the same number of
        elements.

    See Also
    --------
    transpose: Permute the dimensions of an array.
    swapaxes: Interchange two axes of an array.

    Examples
    --------

    >>> x = np.zeros((3, 4, 5))
    >>> np.moveaxis(x, 0, -1).shape
    (4, 5, 3)
    >>> np.moveaxis(x, -1, 0).shape
    (5, 3, 4)

    These all achieve the same result:

    >>> np.transpose(x).shape
    (5, 4, 3)
    >>> np.swapaxes(x, 0, -1).shape
    (5, 4, 3)
    >>> np.moveaxis(x, [0, 1], [-1, -2]).shape
    (5, 4, 3)
    >>> np.moveaxis(x, [0, 1, 2], [-1, -2, -3]).shape
    (5, 4, 3)

    """
    try:
        # duck arrays are accepted as long as they know how to transpose
        permute = a.transpose
    except AttributeError:
        a = np.asarray(a)
        permute = a.transpose

    source_axes = normalize_axis_tuple(source, a.ndim, 'source')
    destination_axes = normalize_axis_tuple(destination, a.ndim,
                                            'destination')
    if len(source_axes) != len(destination_axes):
        raise ValueError('`source` and `destination` arguments must have '
                         'the same number of elements')

    # start from the axes that stay put, in their original order ...
    permutation = [axis for axis in range(a.ndim) if axis not in source_axes]

    # ... then slot each moved axis in, lowest destination first, so that
    # earlier insertions never shift a later destination
    for target, origin in sorted(zip(destination_axes, source_axes)):
        permutation.insert(target, origin)

    return permute(permutation)

xarray/core/variable.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .pycompat import (basestring, OrderedDict, zip, integer_types,
2020
dask_array_type)
2121
from .indexing import (PandasIndexAdapter, xarray_indexable, BasicIndexer,
22-
OuterIndexer, VectorizedIndexer)
22+
OuterIndexer, PointwiseIndexer, VectorizedIndexer)
2323

2424
import xarray as xr # only for Dataset and DataArray
2525

@@ -464,8 +464,11 @@ def _nonzero(self):
464464
in zip(nonzeros, self.dims))
465465

466466
def _broadcast_indexes_advanced(self, key):
467-
variables = []
467+
if isinstance(self._data, dask_array_type):
468+
# dask only supports a very restricted form of advanced indexing
469+
return self._broadcast_indexes_dask_pointwise(key)
468470

471+
variables = []
469472
for dim, value in zip(self.dims, key):
470473
if isinstance(value, slice):
471474
value = np.arange(*value.indices(self.sizes[dim]))
@@ -492,6 +495,46 @@ def _broadcast_indexes_advanced(self, key):
492495
key = VectorizedIndexer(variable.data for variable in variables)
493496
return dims, key
494497

498+
def _broadcast_indexes_dask_pointwise(self, key):
    """Validate an advanced indexing key for dask-backed data and convert
    it into a PointwiseIndexer.

    Parameters
    ----------
    key : sequence
        One indexer per dimension of this variable; each must be a
        Variable or a slice.

    Returns
    -------
    dims : tuple
        Dimension names of the indexing result: sliced dimensions keep
        their own name, and the dimensions of the labeled arrays are
        inserted once, at the position of the first labeled array.
    key : PointwiseIndexer
        The indexers, with every object that has a ``data`` attribute
        replaced by that attribute.

    Raises
    ------
    IndexError
        If an indexer is neither a Variable nor a slice, if a Variable
        indexer is boolean, if the Variable indexers do not all share the
        same dimension names and the same shape, or if a sliced dimension
        is also one of the Variable indexers' dimensions.
    """
    for indexer in key:
        if not isinstance(indexer, (Variable, slice)):
            raise IndexError(
                'Vectorized indexing with dask requires that all indexers are '
                'labeled arrays or full slice objects: {}'.format(key))

    labeled = [indexer for indexer in key if isinstance(indexer, Variable)]

    for indexer in labeled:
        if indexer.dtype.kind == 'b':
            raise IndexError(
                'Vectorized indexing with dask does not support booleans: {}'
                .format(key))

    distinct_dims = set(indexer.dims for indexer in labeled)
    if len(distinct_dims) != 1:
        raise IndexError(
            'Vectorized indexing with dask requires that all labeled '
            'arrays in the indexer have the same dimension names, but '
            'arrays have different dimensions: {}'.format(key))
    (unique_dims,) = distinct_dims

    distinct_shapes = set(indexer.shape for indexer in labeled)
    if len(distinct_shapes) != 1:
        # matches message in _broadcast_indexes_advanced
        raise IndexError("Dimensions of indexers mismatch: {}".format(key))

    result_dims = []
    array_dims_pending = True
    for indexer, dim in zip(key, self.dims):
        if not isinstance(indexer, slice):
            # only the first labeled array contributes its dimensions
            if array_dims_pending:
                result_dims.extend(indexer.dims)
                array_dims_pending = False
            continue
        if dim in unique_dims:
            raise IndexError(
                'Labeled arrays used in vectorized indexing with dask '
                'cannot reuse a sliced dimension: {}'.format(dim))
        result_dims.append(dim)

    pointwise_key = PointwiseIndexer(
        getattr(indexer, 'data', indexer) for indexer in key)
    return tuple(result_dims), pointwise_key
537+
495538
def __getitem__(self, key):
496539
"""Return a new Array object whose contents are consistent with
497540
getting the provided key from the underlying data.
@@ -507,12 +550,12 @@ def __getitem__(self, key):
507550
array `x.values` directly.
508551
"""
509552
dims, index_tuple = self._broadcast_indexes(key)
510-
values = self._indexable_data[index_tuple]
511-
if hasattr(values, 'ndim'):
512-
assert values.ndim == len(dims), (values.ndim, len(dims))
553+
data = self._indexable_data[index_tuple]
554+
if hasattr(data, 'ndim'):
555+
assert data.ndim == len(dims), (data.ndim, len(dims))
513556
else:
514557
assert len(dims) == 0, len(dims)
515-
return type(self)(dims, values, self._attrs, self._encoding,
558+
return type(self)(dims, data, self._attrs, self._encoding,
516559
fastpath=True)
517560

518561
def __setitem__(self, key, value):

xarray/tests/test_dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,8 @@ def test_isel_fancy(self):
881881
with self.assertRaisesRegexp(IndexError,
882882
'Dimensions of indexers mismatch'):
883883
data.isel(dim1=(('points'), [1, 2]), dim2=(('points'), [1, 2, 3]))
884+
with self.assertRaisesRegexp(TypeError, 'cannot use a Dataset'):
885+
data.isel(dim1=Dataset({'points': [1, 2]}))
884886

885887
# test to be sure we keep around variables that were not indexed
886888
ds = Dataset({'x': [1, 2, 3, 4], 'y': 0})

0 commit comments

Comments
 (0)