Skip to content

enable xr.ALL_DIMS in xr.dot #3424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ New Features
to reduce over all dimensions. While we have no plans to remove `xr.ALL_DIMS`, we suggest
using `...`.
By `Maximilian Roos <https://github.com/max-sixty>`_
- :py:func:`~xarray.dot`, and :py:func:`~xarray.DataArray.dot` now support the
`dims=...` option to sum over the union of dimensions of all input arrays
(:issue:`3423`) by `Mathias Hauser <https://github.com/mathause>`_.
- Added integration tests against `pint <https://pint.readthedocs.io/>`_.
(:pull:`3238`) by `Justus Magin <https://github.com/keewis>`_.

Expand Down
20 changes: 15 additions & 5 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,9 +1055,9 @@ def dot(*arrays, dims=None, **kwargs):
----------
arrays: DataArray (or Variable) objects
Arrays to compute.
dims: str or tuple of strings, optional
Which dimensions to sum over.
If not speciified, then all the common dimensions are summed over.
dims: '...', str or tuple of strings, optional
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
If not specified, then all the common dimensions are summed over.
**kwargs: dict
Additional keyword arguments passed to numpy.einsum or
dask.array.einsum
Expand All @@ -1070,7 +1070,7 @@ def dot(*arrays, dims=None, **kwargs):
--------

>>> import numpy as np
>>> import xarray as xp
>>> import xarray as xr
>>> da_a = xr.DataArray(np.arange(3 * 2).reshape(3, 2), dims=['a', 'b'])
>>> da_b = xr.DataArray(np.arange(3 * 2 * 2).reshape(3, 2, 2),
... dims=['a', 'b', 'c'])
Expand Down Expand Up @@ -1117,6 +1117,14 @@ def dot(*arrays, dims=None, **kwargs):
[273, 446, 619]])
Dimensions without coordinates: a, d

>>> xr.dot(da_a, da_b)
<xarray.DataArray (c: 2)>
array([110, 125])
Dimensions without coordinates: c

>>> xr.dot(da_a, da_b, dims=...)
<xarray.DataArray ()>
array(235)
"""
from .dataarray import DataArray
from .variable import Variable
Expand All @@ -1141,7 +1149,9 @@ def dot(*arrays, dims=None, **kwargs):
einsum_axes = "abcdefghijklmnopqrstuvwxyz"
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}

if dims is None:
if dims is ...:
dims = all_dims
elif dims is None:
# find dimensions that occur more than one times
dim_counts = Counter()
for arr in arrays:
Expand Down
6 changes: 3 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2742,9 +2742,9 @@ def dot(
----------
other : DataArray
The other array with which the dot product is performed.
dims: hashable or sequence of hashables, optional
Along which dimensions to be summed over. Default all the common
dimensions are summed over.
dims: '...', hashable or sequence of hashables, optional
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
If not specified, then all the common dimensions are summed over.

Returns
-------
Expand Down
17 changes: 17 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,23 @@ def test_dot(use_dask):
assert actual.dims == ("b",)
assert (actual.data == np.zeros(actual.shape)).all()

# Ellipsis (...) sums over all dimensions
actual = xr.dot(da_a, da_b, dims=...)
assert actual.dims == ()
assert (actual.data == np.einsum("ij,ijk->", a, b)).all()

actual = xr.dot(da_a, da_b, da_c, dims=...)
assert actual.dims == ()
assert (actual.data == np.einsum("ij,ijk,kl-> ", a, b, c)).all()

actual = xr.dot(da_a, dims=...)
assert actual.dims == ()
assert (actual.data == np.einsum("ij-> ", a)).all()

actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims=...)
assert actual.dims == ()
assert (actual.data == np.zeros(actual.shape)).all()

# Invalid cases
if not use_dask:
with pytest.raises(TypeError):
Expand Down
10 changes: 10 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3925,6 +3925,16 @@ def test_dot(self):
expected = DataArray(expected_vals, coords=[x, j], dims=["x", "j"])
assert_equal(expected, actual)

# Ellipsis: all dims are shared
actual = da.dot(da, dims=...)
expected = da.dot(da)
assert_equal(expected, actual)

# Ellipsis: not all dims are shared
actual = da.dot(dm, dims=...)
expected = da.dot(dm, dims=("j", "x", "y", "z"))
assert_equal(expected, actual)

with pytest.raises(NotImplementedError):
da.dot(dm.to_dataset(name="dm"))
with pytest.raises(TypeError):
Expand Down