Implement tensordot for xarray with dask support #723
Comments
Yes, this would be a nice addition!
I spent a little bit of time futzing around with this to see if there is an elegant way to plug this into our existing dispatching system. The short of it is that the answer appears to be no -- we don't have any elegant equivalent to dask.array's generic So, for now I would simply write a function specialized to DataArray objects. Something like the following (barely tested) is a starting point:
from xarray import align, DataArray
# note: using private imports (e.g., from xarray.core) is definitely discouraged!
# this is not guaranteed to work in future versions of xarray
from xarray.core.ops import _dask_or_eager_func

def tensordot(a, b, dims):
    if not (isinstance(a, DataArray) and isinstance(b, DataArray)):
        raise ValueError
    a, b = align(a, b, join='inner', copy=False)
    axes = (a.get_axis_num(dims), b.get_axis_num(dims))
    f = _dask_or_eager_func('tensordot', n_array_args=2)
    new_data = f(a.data, b.data, axes=axes)
    if isinstance(dims, basestring):
        dims = [dims]
    new_coords = a.coords.merge(b.coords).drop(dims)
    new_dims = ([d for d in a.dims if d not in dims] +
                [d for d in b.dims if d not in dims])
    return DataArray(new_data, new_coords, new_dims)

This would be worth cleaning up so we could add it to the codebase (mostly documentation & tests).
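The dimension bookkeeping in the function above can be looked at in isolation. The following is only a sketch in plain Python: `contraction_plan` is a hypothetical helper name (not part of xarray), and it stands in for the `get_axis_num` lookups and the `new_dims` list comprehension by working on tuples of dimension names.

```python
def contraction_plan(a_dims, b_dims, dims):
    """Map dimension names to axis numbers and list the output dims.

    Hypothetical helper for illustration only; it mirrors the logic of
    the tensordot sketch without touching any array data.
    """
    if isinstance(dims, str):
        dims = [dims]
    missing = [d for d in dims if d not in a_dims or d not in b_dims]
    if missing:
        raise ValueError("dims not present in both inputs: %r" % (missing,))
    # Axis numbers to contract, in the order the names were given
    axes = ([a_dims.index(d) for d in dims],
            [b_dims.index(d) for d in dims])
    # Remaining dims of a, followed by remaining dims of b
    new_dims = ([d for d in a_dims if d not in dims] +
                [d for d in b_dims if d not in dims])
    return axes, new_dims


axes, new_dims = contraction_plan(
    ('x_trans', 'y_trans', 'imgID'),
    ('models', 'y_trans', 'imgID'),
    'imgID',
)
print(axes, new_dims)
```

One thing this makes visible: a dimension that both inputs share but that is not contracted (here `y_trans`) ends up in the output list twice, so a cleaned-up version would probably need to decide how to handle that case.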
@MaximilianR I do like That said, I suppose you could make an xarray version of |
@shoyer - I thought your answer dominated mine, so I left yours as the only response. |
Looks like it can perform a tensor dot for both dask-backed and plain xarray objects! But apparently dask has not implemented tensordot with multiple axes arguments, and it also does not work when performing a tensor dot between a dask-backed DataArray and a plain DataArray. Neither of these cases worries me too much; hopefully they don't worry you.
from xarray import align, DataArray
# note: using private imports (e.g., from xarray.core) is definitely discouraged!
# this is not guaranteed to work in future versions of xarray
from xarray.core.ops import _dask_or_eager_func

def tensordot(a, b, dims):
    if not (isinstance(a, DataArray) and isinstance(b, DataArray)):
        raise ValueError
    a, b = align(a, b, join='inner', copy=False)
    axes = (a.get_axis_num(dims), b.get_axis_num(dims))
    f = _dask_or_eager_func('tensordot', n_array_args=2)
    new_data = f(a.data, b.data, axes=axes)
    if isinstance(dims, str):
        dims = [dims]
    new_coords = a.coords.merge(b.coords).drop(dims)
    # drop the dims you are performing the sum product over
    new_dims = ([d for d in a.dims if d not in dims] +
                [d for d in b.dims if d not in dims])
    return DataArray(new_data, new_coords, new_dims)

import xarray as xr
import numpy as np

x_trans = np.linspace(-3, 3, 6)
y_trans = np.linspace(-3, 3, 5)
imgID = range(4)
da = xr.DataArray(np.ones((6, 5, 4)),
                  coords=[x_trans, y_trans, imgID],
                  dims=['x_trans', 'y_trans', 'imgID'])
models = range(20)
dm = xr.DataArray(np.ones((20, 5, 4)),
                  coords=[models, y_trans, imgID],
                  dims=['models', 'y_trans', 'imgID'])

# xarray tensordot
proj_a = tensordot(da, dm, 'imgID')

# dask xarray tensor dot
da = da.chunk()
dm = dm.chunk()
proj_b = tensordot(da, dm, 'imgID')

# errors
# multiple dims
proj_c = tensordot(da, dm, ['imgID', 'y_trans'])

# mixed types
da = da.chunk()
dm = dm.load()
proj_d = tensordot(da, dm, 'imgID')
I wasn't sure where the best place to put the def would be. Currently I have been running it from the xarray class: |
Also that einsum does seem pretty ideal. I'll see if I can get it running in dask, so we can port it over here. |
I'm split on whether a function or method makes more sense ( |
Fixed by #731. |
I've started using xray to store responses from convolutional neural nets over different transformations of images (translation (x, y), rotation (radians), etc.). So far it's been very intuitive for storing and transforming results. Unfortunately, much of my analysis requires tensor dot products, where I can choose arbitrary dimensions over which to make a projection or perform a correlation. While dask implements np.tensordot, xray does not.
One can implement a dot product manually by multiplying DataArrays and then summing over dimensions:
fitm = (da_response*da_model).sum('imageID').sum('x_translation').max('models')
but this ends up being very slow; I imagine the dot products implemented by numpy or dask are considerably more optimized.
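To make the comparison concrete, here is a small NumPy-only sketch. The array names and shapes are made up for illustration (they are not the real response data). It checks that multiply-then-sum, `np.tensordot`, and `np.einsum` all give the same contraction; the manual version has to build the full broadcast product before summing, which is one plausible reason it is slower.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical data: response is (x, image), model is (model, image)
response = rng.normal(size=(6, 4))
model = rng.normal(size=(20, 4))

# Manual version: broadcast to (x, model, image), then sum over image.
# This materializes the full 3-D intermediate product.
manual = (response[:, None, :] * model[None, :, :]).sum(axis=-1)

# tensordot contracts axis 1 of response with axis 1 of model directly
contracted = np.tensordot(response, model, axes=([1], [1]))

# The same contraction written with einsum index notation
via_einsum = np.einsum('xi,mi->xm', response, model)

print(manual.shape, np.allclose(manual, contracted), np.allclose(manual, via_einsum))
```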
I am relatively new to GitHub and to this project. Would you have any advice on the best way to contribute this functionality? What I have in mind is a tensordot where you pass a list of dimension names, shared by two DataArrays, over which to compute a sum product, using dask's implementation.