Implement tensordot for xarray with dask support #723

Closed
deanpospisil opened this issue Jan 26, 2016 · 8 comments

Comments

@deanpospisil

I've started using xarray to store responses from convolutional neural nets over different transformations of images (translation (x, y), rotation (radians), etc.). So far it's been very intuitive for storing and transforming results. Unfortunately, much of my analysis requires tensor dot products, where I can choose arbitrary dimensions over which to make a projection or compute a correlation. While dask implements np.tensordot, xarray does not.

One can implement a dot product manually by multiplying data arrays then summing over dimensions.

fitm = (da_response*da_model).sum('imageID').sum('x_translation').max('models')

But this ends up being very slow. I imagine that the dot products implemented by numpy and dask are optimized a fair amount.
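To make the comparison concrete, here is a minimal NumPy-only sketch of the two approaches. The array names and shapes are made up for illustration and are not taken from the actual analysis. Multiply-then-sum builds the full broadcast product before reducing, whereas np.tensordot contracts the shared axis directly.

```python
import numpy as np

rng = np.random.default_rng(0)
# hypothetical shapes: (x_translation, imageID) and (models, imageID)
response = rng.standard_normal((6, 4))
model = rng.standard_normal((20, 4))

# manual approach: broadcast to shape (6, 20, 4), then sum over the imageID axis
manual = (response[:, None, :] * model[None, :, :]).sum(axis=-1)

# tensordot contracts axis 1 of both operands without building the (6, 20, 4) array
direct = np.tensordot(response, model, axes=([1], [1]))

print(direct.shape)                 # (6, 20)
print(np.allclose(manual, direct))  # True
```

Both give the same numbers; the difference is only in how much intermediate data gets materialized along the way.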

I am relatively new to GitHub and to this project. Would you have any advice on the best way to contribute this functionality? What I have in mind is a tensordot that takes a list of dimension names shared by two DataArrays and computes the sum product over them, using dask's implementation.

@shoyer
Member

shoyer commented Jan 26, 2016

Yes, this would be a nice addition!

I spent a little bit of time futzing around with this to see if there is an elegant way to plug this into our existing dispatching system. The short answer appears to be no: we don't have any elegant equivalent to dask.array's generic atop method.

So, for now I would simply write a function specialized to DataArray objects. Something like the following (barely tested) is a starting point:

from xarray import align, DataArray

# note: using private imports (e.g., from xarray.core) is definitely discouraged!
# this is not guaranteed to work in future versions of xarray
from xarray.core.ops import _dask_or_eager_func

def tensordot(a, b, dims):
    if not (isinstance(a, DataArray) and isinstance(b, DataArray)):
        raise ValueError('a and b must both be DataArray objects')

    a, b = align(a, b, join='inner', copy=False)

    axes = (a.get_axis_num(dims), b.get_axis_num(dims))
    f = _dask_or_eager_func('tensordot', n_array_args=2)
    new_data = f(a.data, b.data, axes=axes)

    if isinstance(dims, str):  # `basestring` exists only on Python 2
        dims = [dims]

    new_coords = a.coords.merge(b.coords).drop(dims)

    new_dims = ([d for d in a.dims if d not in dims] +
                [d for d in b.dims if d not in dims])

    return DataArray(new_data, new_coords, new_dims)

This would be worth cleaning up so we could add it to the codebase (mostly documentation & tests).
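As a rough illustration of the name-to-axis bookkeeping in the function above, here is a pure-Python sketch that works on tuples of dimension names only. The helper name `contraction_plan` is hypothetical and not part of xarray; it just mirrors the `axes` and `new_dims` steps so they can be inspected without any arrays.

```python
def contraction_plan(a_dims, b_dims, dims):
    """Return (axes, new_dims) for contracting the named dims of two arrays.

    Hypothetical helper for illustration; not an xarray API.
    """
    if isinstance(dims, str):
        dims = [dims]
    dims = list(dims)

    for d in dims:
        if d not in a_dims or d not in b_dims:
            raise ValueError(f"dimension {d!r} must be present on both arrays")

    # positions of the contracted dimensions in each operand
    axes = ([a_dims.index(d) for d in dims],
            [b_dims.index(d) for d in dims])

    # remaining dimensions of a, followed by remaining dimensions of b
    new_dims = ([d for d in a_dims if d not in dims] +
                [d for d in b_dims if d not in dims])
    return axes, new_dims


axes, new_dims = contraction_plan(("x", "y", "z"), ("m", "y", "z"), ["y", "z"])
print(axes)      # ([1, 2], [1, 2])
print(new_dims)  # ['x', 'm']
```

One thing this makes visible: a dimension that both inputs share but that is not listed in `dims` appears twice in `new_dims` (contracting only "z" above gives ['x', 'y', 'm', 'y']), which a cleaned-up version would probably need to handle explicitly.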

@shoyer
Member

shoyer commented Jan 26, 2016

@MaximilianR I do like einsum, but I'm not sure the API would be a good fit for xarray (we already have dimension names), and it also does not exist yet for dask (dask/dask#732).

That said, I suppose you could make an xarray version of einsum with syntax that looks more like tensordot with *args, e.g., einsum(a, b, c, dims=('x', 'y')).
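A minimal sketch of what such a name-based einsum could look like, assuming plain NumPy arrays with explicit tuples of dimension names standing in for DataArray objects. The function `named_einsum` and its signature are hypothetical; it only shows how dimension names could be translated into an np.einsum subscript string.

```python
import string

import numpy as np


def named_einsum(arrays, array_dims, dims):
    """Sum the product of `arrays` over the named dimensions in `dims`.

    Hypothetical helper: `array_dims` holds one tuple of names per array.
    """
    if isinstance(dims, str):
        dims = [dims]

    # one subscript letter per distinct dimension name, in order of appearance
    # (this simple version supports at most 26 distinct names)
    names = []
    for one_dims in array_dims:
        for d in one_dims:
            if d not in names:
                names.append(d)
    letters = dict(zip(names, string.ascii_lowercase))

    # dimensions not summed over are kept in the output
    out_dims = [d for d in names if d not in dims]
    subscripts = ",".join("".join(letters[d] for d in one_dims)
                          for one_dims in array_dims)
    subscripts += "->" + "".join(letters[d] for d in out_dims)
    return np.einsum(subscripts, *arrays), out_dims


a = np.ones((2, 3))  # dims ('x', 'y')
b = np.ones((3, 4))  # dims ('y', 'z')
c = np.ones((4, 2))  # dims ('z', 'x')
result, out_dims = named_einsum(
    [a, b, c], [("x", "y"), ("y", "z"), ("z", "x")], dims=("x", "y"))
print(out_dims)      # ['z']
print(result.shape)  # (4,)
```

Because the names already identify which axes match, the caller never has to write a subscript string by hand, which is the appeal of the `dims=` form over the usual einsum notation.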

@max-sixty
Collaborator

@shoyer - I thought your answer dominated mine, so I left yours as the only response.
But yup, that form of einsum would be pretty nice...

@shoyer changed the title from "Implementing dask tensordot" to "Implement tensordot for xarray with dask support" on Jan 26, 2016
@deanpospisil
Author

Looks like it can perform a tensor dot for both dask-backed and plain xarray objects! But apparently dask has not implemented tensordot with multiple axes arguments, and it also does not work when performing a tensor dot between a dask-backed DataArray and a plain in-memory one. Neither of these cases worries me too much; hopefully they don't worry you.

from xarray import align, DataArray

# note: using private imports (e.g., from xarray.core) is definitely discouraged!
# this is not guaranteed to work in future versions of xarray

from xarray.core.ops import _dask_or_eager_func

def tensordot(a, b, dims):
    if not (isinstance(a, DataArray) and isinstance(b, DataArray)):
        raise ValueError('a and b must both be DataArray objects')

    a, b = align(a, b, join='inner', copy=False)

    axes = (a.get_axis_num(dims), b.get_axis_num(dims))
    f = _dask_or_eager_func('tensordot', n_array_args=2)
    new_data = f(a.data, b.data, axes=axes)

    if isinstance(dims, str):
        dims = [dims]

    new_coords = a.coords.merge(b.coords).drop(dims)

    #drop the dims you are performing the sum product over
    new_dims = ([d for d in a.dims if d not in dims] +
                [d for d in b.dims if d not in dims])

    return DataArray(new_data, new_coords, new_dims)

import xarray as xr
import numpy as np

x_trans = np.linspace(-3, 3, 6)
y_trans = np.linspace(-3, 3, 5)
imgID = range(4)
da = xr.DataArray(np.ones((6, 5, 4)),
                  coords=[x_trans, y_trans, imgID],
                  dims=['x_trans', 'y_trans', 'imgID'])

models = range(20)
dm = xr.DataArray(np.ones((20, 5, 4)),
                  coords=[models, y_trans, imgID],
                  dims=['models', 'y_trans', 'imgID'])

# xarray tensordot
proj_a = tensordot(da, dm, 'imgID')

# dask-backed xarray tensordot
da = da.chunk()
dm = dm.chunk()
proj_b = tensordot(da, dm, 'imgID')

# the two cases below raise errors
# multiple dims
proj_c = tensordot(da, dm, ['imgID', 'y_trans'])

# mixed types (dask-backed and in-memory)
da = da.chunk()
dm = dm.load()
proj_d = tensordot(da, dm, 'imgID')
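For reference, the multiple-dimension contraction attempted in proj_c is well defined in plain NumPy. The sketch below uses bare arrays with the same shapes as the example (the names `da_values` and `dm_values` are made up here) and only shows what result shape and values to expect; it says nothing about why the dask-backed call fails.

```python
import numpy as np

da_values = np.ones((6, 5, 4))   # dims: x_trans, y_trans, imgID
dm_values = np.ones((20, 5, 4))  # dims: models, y_trans, imgID

# contract imgID (axis 2) and y_trans (axis 1) of both operands
proj = np.tensordot(da_values, dm_values, axes=([2, 1], [2, 1]))

print(proj.shape)  # (6, 20)
print(proj[0, 0])  # 20.0, the sum of 5 * 4 ones
```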

@deanpospisil
Author

I wasn't sure where the best place to put the def would be. Currently I have been running it as a method on the DataArray class:
t = da1.tensordot(da2, 'shapes')
Let me know if that seems all right, and then I'll write some simple tests in test_dataarray for tensordot.
Maybe I'll make my first pull request!

@deanpospisil
Author

Also, that einsum does seem pretty ideal. I'll see if I can get it running in dask, so we can port it over here.

@shoyer
Member

shoyer commented Jan 27, 2016

I'm split on whether a function or method makes more sense (a.tensordot(b, dim='x') vs xr.tensordot(a, b, dim='x')). I would be OK with either, so yes, please do go ahead!

@shoyer
Member

shoyer commented Mar 5, 2016

Fixed by #731.

@shoyer shoyer closed this as completed Mar 5, 2016

3 participants