Skip to content

Commit 0811141

Browse files
phausamannshoyer
authored andcommitted
Add transpose_coords option to DataArray.transpose (#2556)
* Add transpose_coords option to DataArray.transpose Fixes #1856 * Fix typo * Fix bug in transpose Fix python 2 compatibility * Set default for transpose_coords to None Update documentation * Fix bug in coordinate tranpose Update documentation * Suppress FutureWarning in tests * Add restore_coord_dims parameter to DataArrayGroupBy.apply * Move restore_coord_dims parameter to GroupBy class * Remove restore_coord_dims parameter from DataArrayResample.apply * Update whats-new * Update whats-new
1 parent 6658108 commit 0811141

File tree

9 files changed

+144
-33
lines changed

9 files changed

+144
-33
lines changed

doc/whats-new.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ Enhancements
2727
- Character arrays' character dimension name decoding and encoding handled by
2828
``var.encoding['char_dim_name']`` (:issue:`2895`)
2929
By `James McCreight <https://github.com/jmccreight>`_.
30+
- :py:meth:`DataArray.transpose` now accepts a keyword argument
31+
``transpose_coords`` which enables transposition of coordinates in the
32+
same way as :py:meth:`Dataset.transpose`. :py:meth:`DataArray.groupby`
33+
:py:meth:`DataArray.groupby_bins`, and :py:meth:`DataArray.resample` now
34+
accept a keyword argument ``restore_coord_dims`` which keeps the order
35+
of the dimensions of multi-dimensional coordinates intact (:issue:`1856`).
36+
By `Peter Hausamann <http://github.com/phausamann>`_.
3037
- Clean up Python 2 compatibility in code (:issue:`2950`)
3138
By `Guido Imperiale <https://github.com/crusaderky>`_.
3239
- Implement ``load_dataset()`` and ``load_dataarray()`` as alternatives to

xarray/core/common.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,8 @@ def pipe(self, func: Union[Callable[..., T], Tuple[Callable[..., T], str]],
441441
else:
442442
return func(self, *args, **kwargs)
443443

444-
def groupby(self, group, squeeze: bool = True):
444+
def groupby(self, group, squeeze: bool = True,
445+
restore_coord_dims: Optional[bool] = None):
445446
"""Returns a GroupBy object for performing grouped operations.
446447
447448
Parameters
@@ -453,6 +454,9 @@ def groupby(self, group, squeeze: bool = True):
453454
If "group" is a dimension of any arrays in this dataset, `squeeze`
454455
controls whether the subarrays have a dimension of length 1 along
455456
that dimension or if the dimension is squeezed out.
457+
restore_coord_dims : bool, optional
458+
If True, also restore the dimension order of multi-dimensional
459+
coordinates.
456460
457461
Returns
458462
-------
@@ -485,11 +489,13 @@ def groupby(self, group, squeeze: bool = True):
485489
core.groupby.DataArrayGroupBy
486490
core.groupby.DatasetGroupBy
487491
""" # noqa
488-
return self._groupby_cls(self, group, squeeze=squeeze)
492+
return self._groupby_cls(self, group, squeeze=squeeze,
493+
restore_coord_dims=restore_coord_dims)
489494

490495
def groupby_bins(self, group, bins, right: bool = True, labels=None,
491496
precision: int = 3, include_lowest: bool = False,
492-
squeeze: bool = True):
497+
squeeze: bool = True,
498+
restore_coord_dims: Optional[bool] = None):
493499
"""Returns a GroupBy object for performing grouped operations.
494500
495501
Rather than using all unique values of `group`, the values are discretized
@@ -522,6 +528,9 @@ def groupby_bins(self, group, bins, right: bool = True, labels=None,
522528
If "group" is a dimension of any arrays in this dataset, `squeeze`
523529
controls whether the subarrays have a dimension of length 1 along
524530
that dimension or if the dimension is squeezed out.
531+
restore_coord_dims : bool, optional
532+
If True, also restore the dimension order of multi-dimensional
533+
coordinates.
525534
526535
Returns
527536
-------
@@ -536,9 +545,11 @@ def groupby_bins(self, group, bins, right: bool = True, labels=None,
536545
.. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html
537546
""" # noqa
538547
return self._groupby_cls(self, group, squeeze=squeeze, bins=bins,
548+
restore_coord_dims=restore_coord_dims,
539549
cut_kwargs={'right': right, 'labels': labels,
540550
'precision': precision,
541-
'include_lowest': include_lowest})
551+
'include_lowest':
552+
include_lowest})
542553

543554
def rolling(self, dim: Optional[Mapping[Hashable, int]] = None,
544555
min_periods: Optional[int] = None, center: bool = False,
@@ -669,7 +680,7 @@ def resample(self, indexer: Optional[Mapping[Hashable, str]] = None,
669680
skipna=None, closed: Optional[str] = None,
670681
label: Optional[str] = None,
671682
base: int = 0, keep_attrs: Optional[bool] = None,
672-
loffset=None,
683+
loffset=None, restore_coord_dims: Optional[bool] = None,
673684
**indexer_kwargs: str):
674685
"""Returns a Resample object for performing resampling operations.
675686
@@ -697,6 +708,9 @@ def resample(self, indexer: Optional[Mapping[Hashable, str]] = None,
697708
If True, the object's attributes (`attrs`) will be copied from
698709
the original object to the new one. If False (default), the new
699710
object will be returned without attributes.
711+
restore_coord_dims : bool, optional
712+
If True, also restore the dimension order of multi-dimensional
713+
coordinates.
700714
**indexer_kwargs : {dim: freq}
701715
The keyword arguments form of ``indexer``.
702716
One of indexer or indexer_kwargs must be provided.
@@ -786,7 +800,8 @@ def resample(self, indexer: Optional[Mapping[Hashable, str]] = None,
786800
dims=dim_coord.dims, name=RESAMPLE_DIM)
787801
resampler = self._resample_cls(self, group=group, dim=dim_name,
788802
grouper=grouper,
789-
resample_dim=RESAMPLE_DIM)
803+
resample_dim=RESAMPLE_DIM,
804+
restore_coord_dims=restore_coord_dims)
790805

791806
return resampler
792807

xarray/core/dataarray.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,14 +1405,16 @@ def unstack(self, dim=None):
14051405
ds = self._to_temp_dataset().unstack(dim)
14061406
return self._from_temp_dataset(ds)
14071407

1408-
def transpose(self, *dims) -> 'DataArray':
1408+
def transpose(self, *dims, transpose_coords=None) -> 'DataArray':
14091409
"""Return a new DataArray object with transposed dimensions.
14101410
14111411
Parameters
14121412
----------
14131413
*dims : str, optional
14141414
By default, reverse the dimensions. Otherwise, reorder the
14151415
dimensions to this order.
1416+
transpose_coords : boolean, optional
1417+
If True, also transpose the coordinates of this DataArray.
14161418
14171419
Returns
14181420
-------
@@ -1430,8 +1432,28 @@ def transpose(self, *dims) -> 'DataArray':
14301432
numpy.transpose
14311433
Dataset.transpose
14321434
"""
1435+
if dims:
1436+
if set(dims) ^ set(self.dims):
1437+
raise ValueError('arguments to transpose (%s) must be '
1438+
'permuted array dimensions (%s)'
1439+
% (dims, tuple(self.dims)))
1440+
14331441
variable = self.variable.transpose(*dims)
1434-
return self._replace(variable)
1442+
if transpose_coords:
1443+
coords = {}
1444+
for name, coord in self.coords.items():
1445+
coord_dims = tuple(dim for dim in dims if dim in coord.dims)
1446+
coords[name] = coord.variable.transpose(*coord_dims)
1447+
return self._replace(variable, coords)
1448+
else:
1449+
if transpose_coords is None \
1450+
and any(self[c].ndim > 1 for c in self.coords):
1451+
warnings.warn('This DataArray contains multi-dimensional '
1452+
'coordinates. In the future, these coordinates '
1453+
'will be transposed as well unless you specify '
1454+
'transpose_coords=False.',
1455+
FutureWarning, stacklevel=2)
1456+
return self._replace(variable)
14351457

14361458
@property
14371459
def T(self) -> 'DataArray':

xarray/core/groupby.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ class GroupBy(SupportsArithmetic):
197197
"""
198198

199199
def __init__(self, obj, group, squeeze=False, grouper=None, bins=None,
200-
cut_kwargs={}):
200+
restore_coord_dims=None, cut_kwargs={}):
201201
"""Create a GroupBy object
202202
203203
Parameters
@@ -215,6 +215,9 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None,
215215
bins : array-like, optional
216216
If `bins` is specified, the groups will be discretized into the
217217
specified bins by `pandas.cut`.
218+
restore_coord_dims : bool, optional
219+
If True, also restore the dimension order of multi-dimensional
220+
coordinates.
218221
cut_kwargs : dict, optional
219222
Extra keyword arguments to pass to `pandas.cut`
220223
@@ -279,6 +282,16 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None,
279282
safe_cast_to_index(group), sort=(bins is None))
280283
unique_coord = IndexVariable(group.name, unique_values)
281284

285+
if isinstance(obj, DataArray) \
286+
and restore_coord_dims is None \
287+
and any(obj[c].ndim > 1 for c in obj.coords):
288+
warnings.warn('This DataArray contains multi-dimensional '
289+
'coordinates. In the future, the dimension order '
290+
'of these coordinates will be restored as well '
291+
'unless you specify restore_coord_dims=False.',
292+
FutureWarning, stacklevel=2)
293+
restore_coord_dims = False
294+
282295
# specification for the groupby operation
283296
self._obj = obj
284297
self._group = group
@@ -288,6 +301,7 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None,
288301
self._stacked_dim = stacked_dim
289302
self._inserted_dims = inserted_dims
290303
self._full_index = full_index
304+
self._restore_coord_dims = restore_coord_dims
291305

292306
# cached attributes
293307
self._groups = None
@@ -508,7 +522,8 @@ def lookup_order(dimension):
508522
return axis
509523

510524
new_order = sorted(stacked.dims, key=lookup_order)
511-
return stacked.transpose(*new_order)
525+
return stacked.transpose(
526+
*new_order, transpose_coords=self._restore_coord_dims)
512527

513528
def apply(self, func, shortcut=False, args=(), **kwargs):
514529
"""Apply a function over each array in the group and concatenate them
@@ -558,7 +573,7 @@ def apply(self, func, shortcut=False, args=(), **kwargs):
558573
for arr in grouped)
559574
return self._combine(applied, shortcut=shortcut)
560575

561-
def _combine(self, applied, shortcut=False):
576+
def _combine(self, applied, restore_coord_dims=False, shortcut=False):
562577
"""Recombine the applied objects like the original."""
563578
applied_example, applied = peek_at(applied)
564579
coord, dim, positions = self._infer_concat_args(applied_example)
@@ -580,8 +595,8 @@ def _combine(self, applied, shortcut=False):
580595
combined = self._maybe_unstack(combined)
581596
return combined
582597

583-
def reduce(self, func, dim=None, axis=None,
584-
keep_attrs=None, shortcut=True, **kwargs):
598+
def reduce(self, func, dim=None, axis=None, keep_attrs=None,
599+
shortcut=True, **kwargs):
585600
"""Reduce the items in this group by applying `func` along some
586601
dimension(s).
587602

xarray/plot/plot.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,10 @@ def _infer_line_data(darray, x, y, hue):
6464
if huename in darray.dims:
6565
otherindex = 1 if darray.dims.index(huename) == 0 else 0
6666
otherdim = darray.dims[otherindex]
67-
yplt = darray.transpose(otherdim, huename)
68-
xplt = xplt.transpose(otherdim, huename)
67+
yplt = darray.transpose(
68+
otherdim, huename, transpose_coords=False)
69+
xplt = xplt.transpose(
70+
otherdim, huename, transpose_coords=False)
6971
else:
7072
raise ValueError('For 2D inputs, hue must be a dimension'
7173
+ ' i.e. one of ' + repr(darray.dims))
@@ -79,7 +81,9 @@ def _infer_line_data(darray, x, y, hue):
7981
if yplt.ndim > 1:
8082
if huename in darray.dims:
8183
otherindex = 1 if darray.dims.index(huename) == 0 else 0
82-
xplt = darray.transpose(otherdim, huename)
84+
otherdim = darray.dims[otherindex]
85+
xplt = darray.transpose(
86+
otherdim, huename, transpose_coords=False)
8387
else:
8488
raise ValueError('For 2D inputs, hue must be a dimension'
8589
+ ' i.e. one of ' + repr(darray.dims))
@@ -614,9 +618,9 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
614618
yx_dims = (ylab, xlab)
615619
dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims)
616620
if dims != darray.dims:
617-
darray = darray.transpose(*dims)
621+
darray = darray.transpose(*dims, transpose_coords=True)
618622
elif darray[xlab].dims[-1] == darray.dims[0]:
619-
darray = darray.transpose()
623+
darray = darray.transpose(transpose_coords=True)
620624

621625
# Pass the data as a masked ndarray too
622626
zval = darray.to_masked_array(copy=False)

xarray/tests/test_dataarray.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,14 +1681,14 @@ def test_math_with_coords(self):
16811681
assert_identical(expected, actual)
16821682

16831683
actual = orig[0, :] + orig[:, 0]
1684-
assert_identical(expected.T, actual)
1684+
assert_identical(expected.transpose(transpose_coords=True), actual)
16851685

1686-
actual = orig - orig.T
1686+
actual = orig - orig.transpose(transpose_coords=True)
16871687
expected = DataArray(np.zeros((2, 3)), orig.coords)
16881688
assert_identical(expected, actual)
16891689

1690-
actual = orig.T - orig
1691-
assert_identical(expected.T, actual)
1690+
actual = orig.transpose(transpose_coords=True) - orig
1691+
assert_identical(expected.transpose(transpose_coords=True), actual)
16921692

16931693
alt = DataArray([1, 1], {'x': [-1, -2], 'c': 'foo', 'd': 555}, 'x')
16941694
actual = orig + alt
@@ -1801,8 +1801,27 @@ def test_stack_nonunique_consistency(self):
18011801
assert_identical(expected, actual)
18021802

18031803
def test_transpose(self):
1804-
assert_equal(self.dv.variable.transpose(),
1805-
self.dv.transpose().variable)
1804+
da = DataArray(np.random.randn(3, 4, 5), dims=('x', 'y', 'z'),
1805+
coords={'x': range(3), 'y': range(4), 'z': range(5),
1806+
'xy': (('x', 'y'), np.random.randn(3, 4))})
1807+
1808+
actual = da.transpose(transpose_coords=False)
1809+
expected = DataArray(da.values.T, dims=('z', 'y', 'x'),
1810+
coords=da.coords)
1811+
assert_equal(expected, actual)
1812+
1813+
actual = da.transpose('z', 'y', 'x', transpose_coords=True)
1814+
expected = DataArray(da.values.T, dims=('z', 'y', 'x'),
1815+
coords={'x': da.x.values, 'y': da.y.values,
1816+
'z': da.z.values,
1817+
'xy': (('y', 'x'), da.xy.values.T)})
1818+
assert_equal(expected, actual)
1819+
1820+
with pytest.raises(ValueError):
1821+
da.transpose('x', 'y')
1822+
1823+
with pytest.warns(FutureWarning):
1824+
da.transpose()
18061825

18071826
def test_squeeze(self):
18081827
assert_equal(self.dv.variable.squeeze(), self.dv.squeeze().variable)
@@ -2258,6 +2277,23 @@ def test_groupby_restore_dim_order(self):
22582277
result = array.groupby(by).apply(lambda x: x.squeeze())
22592278
assert result.dims == expected_dims
22602279

2280+
def test_groupby_restore_coord_dims(self):
2281+
array = DataArray(np.random.randn(5, 3),
2282+
coords={'a': ('x', range(5)), 'b': ('y', range(3)),
2283+
'c': (('x', 'y'), np.random.randn(5, 3))},
2284+
dims=['x', 'y'])
2285+
2286+
for by, expected_dims in [('x', ('x', 'y')),
2287+
('y', ('x', 'y')),
2288+
('a', ('a', 'y')),
2289+
('b', ('x', 'b'))]:
2290+
result = array.groupby(by, restore_coord_dims=True).apply(
2291+
lambda x: x.squeeze())['c']
2292+
assert result.dims == expected_dims
2293+
2294+
with pytest.warns(FutureWarning):
2295+
array.groupby('x').apply(lambda x: x.squeeze())
2296+
22612297
def test_groupby_first_and_last(self):
22622298
array = DataArray([1, 2, 3, 4, 5], dims='x')
22632299
by = DataArray(['a'] * 2 + ['b'] * 3, dims='x', name='ab')
@@ -2445,15 +2481,18 @@ def test_resample_drop_nondim_coords(self):
24452481
array = ds['data']
24462482

24472483
# Re-sample
2448-
actual = array.resample(time="12H").mean('time')
2484+
actual = array.resample(
2485+
time="12H", restore_coord_dims=True).mean('time')
24492486
assert 'tc' not in actual.coords
24502487

24512488
# Up-sample - filling
2452-
actual = array.resample(time="1H").ffill()
2489+
actual = array.resample(
2490+
time="1H", restore_coord_dims=True).ffill()
24532491
assert 'tc' not in actual.coords
24542492

24552493
# Up-sample - interpolation
2456-
actual = array.resample(time="1H").interpolate('linear')
2494+
actual = array.resample(
2495+
time="1H", restore_coord_dims=True).interpolate('linear')
24572496
assert 'tc' not in actual.coords
24582497

24592498
def test_resample_keep_attrs(self):

xarray/tests/test_dataset.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4062,14 +4062,20 @@ def test_dataset_math_errors(self):
40624062

40634063
def test_dataset_transpose(self):
40644064
ds = Dataset({'a': (('x', 'y'), np.random.randn(3, 4)),
4065-
'b': (('y', 'x'), np.random.randn(4, 3))})
4065+
'b': (('y', 'x'), np.random.randn(4, 3))},
4066+
coords={'x': range(3), 'y': range(4),
4067+
'xy': (('x', 'y'), np.random.randn(3, 4))})
40664068

40674069
actual = ds.transpose()
4068-
expected = ds.apply(lambda x: x.transpose())
4070+
expected = Dataset({'a': (('y', 'x'), ds.a.values.T),
4071+
'b': (('x', 'y'), ds.b.values.T)},
4072+
coords={'x': ds.x.values, 'y': ds.y.values,
4073+
'xy': (('y', 'x'), ds.xy.values.T)})
40694074
assert_identical(expected, actual)
40704075

40714076
actual = ds.transpose('x', 'y')
4072-
expected = ds.apply(lambda x: x.transpose('x', 'y'))
4077+
expected = ds.apply(
4078+
lambda x: x.transpose('x', 'y', transpose_coords=True))
40734079
assert_identical(expected, actual)
40744080

40754081
ds = create_test_data()

xarray/tests/test_interp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ def func(obj, dim, new_x):
143143
'y': da['y'],
144144
'x': ('z', xdest.values),
145145
'x2': ('z', func(da['x2'], 'x', xdest))})
146-
assert_allclose(actual, expected.transpose('z', 'y'))
146+
assert_allclose(actual,
147+
expected.transpose('z', 'y', transpose_coords=True))
147148

148149
# xdest is 2d
149150
xdest = xr.DataArray(np.linspace(0.1, 0.9, 30).reshape(6, 5),
@@ -160,7 +161,8 @@ def func(obj, dim, new_x):
160161
coords={'z': xdest['z'], 'w': xdest['w'], 'z2': xdest['z2'],
161162
'y': da['y'], 'x': (('z', 'w'), xdest),
162163
'x2': (('z', 'w'), func(da['x2'], 'x', xdest))})
163-
assert_allclose(actual, expected.transpose('z', 'w', 'y'))
164+
assert_allclose(actual,
165+
expected.transpose('z', 'w', 'y', transpose_coords=True))
164166

165167

166168
@pytest.mark.parametrize('case', [3, 4])

0 commit comments

Comments
 (0)