-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
einsum for xarray #1968
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
einsum for xarray #1968
Changes from all commits
220ebcc
4239ac6
0f472a2
c83d442
1c732a4
b8d93b0
3278bf3
1ec5683
789cb96
a57907c
693b242
88be319
b3d4768
2bd06ef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ Top-level functions | |
full_like | ||
zeros_like | ||
ones_like | ||
dot | ||
|
||
Dataset | ||
======= | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,13 +6,14 @@ | |
import functools | ||
import itertools | ||
import operator | ||
from collections import Counter | ||
|
||
import numpy as np | ||
|
||
from . import duck_array_ops, utils | ||
from . import duck_array_ops, utils, dtypes | ||
from .alignment import deep_align | ||
from .merge import expand_and_merge_variables | ||
from .pycompat import OrderedDict, dask_array_type | ||
from .pycompat import OrderedDict, dask_array_type, basestring | ||
from .utils import is_dict_like | ||
|
||
_DEFAULT_FROZEN_SET = frozenset() | ||
|
@@ -937,6 +938,111 @@ def earth_mover_distance(first_samples, | |
return apply_array_ufunc(func, *args, dask=dask) | ||
|
||
|
||
def dot(*arrays, **kwargs): | ||
""" dot(*arrays, dims=None) | ||
|
||
Generalized dot product for xarray objects. Like np.einsum, but | ||
provides a simpler interface based on array dimensions. | ||
|
||
Parameters | ||
---------- | ||
arrays: DataArray (or Variable) objects | ||
Arrays to compute. | ||
dims: str or tuple of strings, optional | ||
Which dimensions to sum over. | ||
If not speciified, then all the common dimensions are summed over. | ||
|
||
Returns | ||
------- | ||
dot: DataArray | ||
|
||
Examples | ||
-------- | ||
|
||
>>> da_a = xr.DataArray(np.arange(3 * 4).reshape(3, 4), dims=['a', 'b']) | ||
>>> da_b = xr.DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5), | ||
>>> dims=['a', 'b', 'c']) | ||
>>> da_c = xr.DataArray(np.arange(5 * 6).reshape(5, 6), dims=['c', 'd']) | ||
>>> | ||
>>> xr.dot(da_a, da_b, dims=['a', 'b']).dims | ||
('c', ) | ||
>>> xr.dot(da_a, da_b, dims=['a']).dims | ||
('b', 'c') | ||
>>> xr.dot(da_a, da_b, da_c, dims=['b', 'c']).dims | ||
('a', 'd') | ||
""" | ||
from .dataarray import DataArray | ||
from .variable import Variable | ||
|
||
dims = kwargs.pop('dims', None) | ||
if len(kwargs) > 0: | ||
raise TypeError('Invalid keyward arguments {} are given'.format( | ||
list(kwargs.keys()))) | ||
|
||
if any(not isinstance(arr, (Variable, DataArray)) for arr in arrays): | ||
raise TypeError('Only xr.DataArray and xr.Variable are supported.' | ||
'Given {}.'.format([type(arr) for arr in arrays])) | ||
|
||
if len(arrays) == 0: | ||
raise TypeError('At least one array should be given.') | ||
|
||
if isinstance(dims, basestring): | ||
dims = (dims, ) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FWIW you don't need the parentheses There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I personally like parentheses, as I think it is more descriptive. |
||
|
||
common_dims = set.intersection(*[set(arr.dims) for arr in arrays]) | ||
all_dims = [] | ||
for arr in arrays: | ||
all_dims += [d for d in arr.dims if d not in all_dims] | ||
|
||
einsum_axes = 'abcdefghijklmnopqrstuvwxyz' | ||
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} | ||
|
||
if dims is None: | ||
# find dimensions that occur more than one times | ||
dim_counts = Counter() | ||
for arr in arrays: | ||
dim_counts.update(arr.dims) | ||
dims = tuple(d for d, c in dim_counts.items() if c > 1) | ||
|
||
dims = tuple(dims) # make dims a tuple | ||
|
||
# dimensions to be parallelized | ||
broadcast_dims = tuple(d for d in all_dims | ||
if d in common_dims and d not in dims) | ||
input_core_dims = [[d for d in arr.dims if d not in broadcast_dims] | ||
for arr in arrays] | ||
output_core_dims = [tuple(d for d in all_dims if d not in | ||
dims + broadcast_dims)] | ||
|
||
# we use tensordot if possible, because it is more efficient for dask | ||
if len(broadcast_dims) == 0 and len(arrays) == 2: | ||
axes = [[arr.get_axis_num(d) for d in arr.dims if d in dims] | ||
for arr in arrays] | ||
return apply_ufunc(duck_array_ops.tensordot, *arrays, dask='allowed', | ||
input_core_dims=input_core_dims, | ||
output_core_dims=output_core_dims, | ||
kwargs={'axes': axes}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks. I added a path for tensordot, which dask can compute more efficiently. |
||
|
||
# construct einsum subscripts, such as '...abc,...ab->...c' | ||
# Note: input_core_dims are always moved to the last position | ||
subscripts_list = ['...' + ''.join([dim_map[d] for d in ds]) for ds | ||
in input_core_dims] | ||
subscripts = ','.join(subscripts_list) | ||
subscripts += '->...' + ''.join([dim_map[d] for d in output_core_dims[0]]) | ||
|
||
# dtype estimation is necessary for dask='parallelized' | ||
out_dtype = dtypes.result_type(*arrays) | ||
|
||
# subscripts should be passed to np.einsum as arg, not as kwargs. We need | ||
# to construct a partial function for parallelized computation. | ||
func = functools.partial(np.einsum, subscripts) | ||
result = apply_ufunc(func, *arrays, | ||
input_core_dims=input_core_dims, | ||
output_core_dims=output_core_dims, | ||
dask='parallelized', output_dtypes=[out_dtype]) | ||
return result.transpose(*[d for d in all_dims if d in result.dims]) | ||
|
||
|
||
def where(cond, x, y): | ||
"""Return elements from `x` or `y` depending on `cond`. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens if you write
xr.dot()
? I suppose we still need to raise an error for 0 arguments.