Skip to content

Commit 9fbd15d

Browse files
committed
Merge pull request #321 from xray/auto-align
Automatic label-based alignment for math and Dataset constructor
2 parents 2eb0d96 + 6ccf629 commit 9fbd15d

File tree

6 files changed

+334
-109
lines changed

6 files changed

+334
-109
lines changed

doc/whats-new.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@ What's New
99
import xray
1010
np.random.seed(123456)
1111
12+
v0.4 (unreleased)
13+
-----------------
14+
15+
Highlights
16+
~~~~~~~~~~
17+
18+
- Automatic alignment of index labels in arithmetic, dataset construction and
19+
merging.
20+
- Aggregation operations skip missing values by default.
21+
- Lots of bug fixes.
22+
1223
v0.3.2 (23 December, 2014)
1324
--------------------------
1425

xray/core/alignment.py

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,47 @@
1212
from .variable import as_variable, Variable, Coordinate, broadcast_variables
1313

1414

15-
def _get_all_indexes(objects):
15+
def _get_joiner(join):
16+
if join == 'outer':
17+
return functools.partial(functools.reduce, operator.or_)
18+
elif join == 'inner':
19+
return functools.partial(functools.reduce, operator.and_)
20+
elif join == 'left':
21+
return operator.itemgetter(0)
22+
elif join == 'right':
23+
return operator.itemgetter(-1)
24+
else:
25+
raise ValueError('invalid value for join: %s' % join)
26+
27+
28+
def _get_all_indexes(objects, exclude=set()):
1629
all_indexes = defaultdict(list)
1730
for obj in objects:
1831
for k, v in iteritems(obj.indexes):
19-
all_indexes[k].append(v)
32+
if k not in exclude:
33+
all_indexes[k].append(v)
2034
return all_indexes
2135

2236

37+
def _join_indexes(join, objects, exclude=set()):
38+
joiner = _get_joiner(join)
39+
indexes = _get_all_indexes(objects, exclude=exclude)
40+
# exclude dimensions with all equal indices (the usual case) to avoid
41+
# unnecessary reindexing work.
42+
# TODO: don't bother to check equals for left or right joins
43+
joined_indexes = dict((k, joiner(v)) for k, v in iteritems(indexes)
44+
if any(not v[0].equals(idx) for idx in v[1:]))
45+
return joined_indexes
46+
47+
2348
def align(*objects, **kwargs):
2449
"""align(*objects, join='inner', copy=True)
2550
2651
Given any number of Dataset and/or DataArray objects, returns new
2752
objects with aligned indexes.
2853
2954
Array from the aligned objects are suitable as input to mathematical
30-
operators, because along each dimension they are indexed by the same
31-
indexes.
55+
operators, because along each dimension they have the same indexes.
3256
3357
Missing values (if ``join != 'inner'``) are filled with NaN.
3458
@@ -39,13 +63,13 @@ def align(*objects, **kwargs):
3963
join : {'outer', 'inner', 'left', 'right'}, optional
4064
Method for joining the indexes of the passed objects along each
4165
dimension:
42-
- 'outer': use the union of object indexes
43-
- 'inner': use the intersection of object indexes
44-
- 'left': use indexes from the first object with each dimension
45-
- 'right': use indexes from the last object with each dimension
66+
- 'outer': use the union of object indexes
67+
- 'inner': use the intersection of object indexes
68+
- 'left': use indexes from the first object with each dimension
69+
- 'right': use indexes from the last object with each dimension
4670
copy : bool, optional
47-
If `copy=True`, the returned objects contain all new variables. If
48-
`copy=False` and no reindexing is required then the aligned objects
71+
If ``copy=True``, the returned objects contain all new variables. If
72+
``copy=False`` and no reindexing is required then the aligned objects
4973
will include original variables.
5074
5175
Returns
@@ -55,23 +79,27 @@ def align(*objects, **kwargs):
5579
"""
5680
join = kwargs.pop('join', 'inner')
5781
copy = kwargs.pop('copy', True)
82+
if kwargs:
83+
raise TypeError('align() got unexpected keyword arguments: %s'
84+
% list(kwargs))
5885

59-
if join == 'outer':
60-
join_indices = functools.partial(functools.reduce, operator.or_)
61-
elif join == 'inner':
62-
join_indices = functools.partial(functools.reduce, operator.and_)
63-
elif join == 'left':
64-
join_indices = operator.itemgetter(0)
65-
elif join == 'right':
66-
join_indices = operator.itemgetter(-1)
86+
joined_indexes = _join_indexes(join, objects)
87+
return tuple(obj.reindex(copy=copy, **joined_indexes) for obj in objects)
6788

68-
all_indexes = _get_all_indexes(objects)
6989

70-
# Exclude dimensions with all equal indices to avoid unnecessary reindexing
71-
# work.
72-
joined_indexes = dict((k, join_indices(v)) for k, v in iteritems(all_indexes)
73-
if any(not v[0].equals(idx) for idx in v[1:]))
90+
def partial_align(*objects, **kwargs):
    """partial_align(*objects, join='inner', copy=True, exclude=set())

    Like align, but don't align along dimensions in exclude. Not public API.

    Parameters
    ----------
    *objects
        Objects to align; each must provide ``indexes`` and ``reindex``.
    join : {'outer', 'inner', 'left', 'right'}, optional
        Method for joining the indexes along each dimension.
    copy : bool, optional
        Passed through to each object's ``reindex`` method.
    exclude : container, optional
        Dimension names that are left untouched.

    Returns
    -------
    aligned : tuple
        One reindexed object per input, in the same order.

    Raises
    ------
    TypeError
        If any unexpected keyword argument is supplied.
    """
    join = kwargs.pop('join', 'inner')
    copy = kwargs.pop('copy', True)
    exclude = kwargs.pop('exclude', set())
    if kwargs:
        # Report this function's own name rather than ``align()`` so the
        # traceback points at the call that was actually made.
        raise TypeError('partial_align() got unexpected keyword arguments: %s'
                        % list(kwargs))

    joined_indexes = _join_indexes(join, objects, exclude=exclude)
    return tuple(obj.reindex(copy=copy, **joined_indexes) for obj in objects)
76104

77105

xray/core/dataarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from . import ops
1010
from . import utils
1111
from . import variable
12+
from .alignment import align
1213
from .common import AbstractArray, AttrAccessMixin
1314
from .coordinates import DataArrayCoordinates, Indexes
1415
from .dataset import Dataset
@@ -950,6 +951,13 @@ def _binary_op(f, reflexive=False):
950951
def func(self, other):
951952
if isinstance(other, (Dataset, groupby.GroupBy)):
952953
return NotImplemented
954+
if hasattr(other, 'indexes'):
955+
self, other = align(self, other, join='inner', copy=False)
956+
empty_indexes = [d for d, s in zip(self.dims, self.shape)
957+
if s == 0]
958+
if empty_indexes:
959+
raise ValueError('no overlapping labels for some '
960+
'dimensions: %s' % empty_indexes)
953961
other_coords = getattr(other, 'coords', None)
954962
other_variable = getattr(other, 'variable', other)
955963
ds = self.coords.merge(other_coords)

0 commit comments

Comments
 (0)