Add tensordot to dataarray class also add its test to test_dataarray #731
It seems that in both unsuccessful cases the build was trying to import dask and failed. I can't quite figure out what I could have changed to prevent dask from importing:
@@ -1369,6 +1369,27 @@ def real(self):
    @property
    def imag(self):
        return self._replace(self.variable.imag)

    def tensordot( self, b, dims):
I would call the second argument "other" rather than b
This also needs docs, minimally including a note in what's new (for v0.7.1), a docstring and a reference in the API docs. |
    def tensordot( self, b, dims):
        a = self
        if not (isinstance(a, DataArray) and isinstance(b, DataArray)):
            raise ValueError
This error should be more descriptive -- and probably a TypeError instead
I'll just note that one other logical way to do this would be to implement |
The test is failing because we don't have dask installed in all builds -- dask is an optional dependency, and we want things to still work when you don't have it installed. We do have tests set up to skip if dask is not installed, by using the
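For readers following along, here is a minimal sketch of the skip-if-missing pattern described above. The decorator name `requires_optional` and the test class are hypothetical and only illustrate the idea with the standard library; the helper the project actually uses is not shown in this thread.

```python
import importlib.util
import unittest

# True only when the optional dependency can be found on this machine.
HAS_DASK = importlib.util.find_spec("dask") is not None

# Hypothetical helper name: skips a test when dask is not installed.
requires_optional = unittest.skipUnless(HAS_DASK, "requires dask")


class TestOptionalBackend(unittest.TestCase):
    @requires_optional
    def test_needs_dask(self):
        # Imported inside the test so collecting the module never fails.
        import dask
        self.assertTrue(hasattr(dask, "compute"))

    def test_always_runs(self):
        # Tests without the decorator run in every build.
        self.assertEqual(sum([1, 1, 1]), 3)
```

The key point is that the import happens inside the guarded test, so builds without dask report a skip instead of an import error.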
I've implemented all your comments, except:
2 Your comment: 3 Need to figure out how to edit the docs. And I was thinking: Thanks for all your help getting me going! |
Nope, it can handle any eager (numpy) and lazy (dask) xarray objects. Something like this should work:
This is a good point! Arrays with redundant dimensions are not very useful. The sane thing to do is to broadcast over dimensions that aren't being summed. That said, this is difficult to implement with numpy's dot/tensordot, so perhaps it's better to simply error or omit the I also wonder if perhaps we should rename this from
Yeah, I'm going to go for omitting the dims entirely for now. I think it makes the function easy to call, and it covers what I imagine most people would use the function for: getting the relationship between two DataArrays along their common dimensions. That's all I'm going to use it for.
Yeah that makes sense. |
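To make the agreed behavior concrete, here is a rough sketch of a sum product over all shared dimensions using plain numpy arrays plus lists of dimension names. The function `dot_shared` and its signature are assumptions made for illustration; this is not the code from the PR.

```python
import numpy as np


def dot_shared(a, a_dims, b, b_dims):
    """Sum product of two arrays over every dimension name they share.

    Hypothetical helper: `a_dims` and `b_dims` are lists naming each axis.
    """
    # Preserve the order from `a` rather than relying on set ordering.
    shared = [d for d in a_dims if d in b_dims]
    a_axes = [a_dims.index(d) for d in shared]
    b_axes = [b_dims.index(d) for d in shared]
    result = np.tensordot(a, b, axes=(a_axes, b_axes))
    # tensordot keeps the leftover axes of `a` first, then those of `b`.
    out_dims = [d for d in a_dims if d not in shared]
    out_dims += [d for d in b_dims if d not in shared]
    return result, out_dims


a = np.arange(24).reshape(2, 3, 4)  # dims: x, y, z
b = np.arange(12).reshape(3, 4)     # dims: y, z
result, out_dims = dot_shared(a, ["x", "y", "z"], b, ["y", "z"])
```

Here `y` and `z` are summed away and only `x` survives, which matches `np.einsum("xyz,yz->x", a, b)`.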
    def dot( self, other):
        """Perform sum product of two DataArrays along their shared dims.
        Equivalent to taking taking tensor dot over all shared dims'''
Should end in """, not '''
Also, add a line break to separate it from the first line.
Take a look at PEP8 guidelines on spaces -- you are inserting a few extras around your parentheses :) |
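As an illustration of the style points raised here, the sketch below shows one way the docstring could be laid out: matching triple double quotes, a blank line after the summary, and no extra spaces inside the parentheses. The class `Example` is a hypothetical placeholder, and the body is deliberately left unimplemented.

```python
class Example:
    """Hypothetical container used only to show docstring layout."""

    def dot(self, other):
        """Perform the sum product of two arrays along their shared dims.

        Equivalent to taking a tensor dot product over all shared dims.
        """
        # Placeholder body: only the signature and docstring matter here.
        raise NotImplementedError
```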
a few more comments forthcoming, generally looks very good, though |
Thanks! Good comments, should help with my next pull request. : ) |
            raise TypeError('dot only operates on DataArrays.')

        #sum over the common dims
        dims = list(set(s.dims) & set(other.dims) )
no need to turn these into a list -- a set is a perfectly valid sequence for use below, as well
I recommend running a tool like PEP8 to check the style here: https://pypi.python.org/pypi/pep8 |
Enhancements
~~~~~~~~~~~~
-xarray version of np.dot :py:meth:`~DataArray.dot`. Takes the sum product over the shared dimensions of two DataArrays. Can be useful for measuring correlation over common dimensions of two DataArrays.
I don't think you need "Can be useful for..."
Dot products are pretty broadly useful in science :)
You'll need to squash these into one commit ( |
685a050 to cc0c24a
Alright, I think I got the rebase to work. Apologies for the delay.
@@ -1369,6 +1369,83 @@ def real(self):
    @property
    def imag(self):
        return self._replace(self.variable.imag)

You still have the tensordot method defined -- you should delete it now that it's replaced by dot :)
I'm pleased to see how this is coming along. No need to apologize for the delay -- this stuff takes time and practice to figure out :). I'd like to see more test cases so we can be confident this works properly. At the very least, we should test a
In general, it's hard to have too many tests. One indication that you may not have enough tests comes from the "coveralls" status check, which you can find if you click on "Show all checks" at the bottom of the PR. Ideally, each PR should only increase, not decrease, code coverage -- the idea is that unit tests should run over every possible code pathway.
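As a sketch of what broader coverage might look like, the tests below exercise several code pathways (vector with vector, matrix with vector, and no shared dims) and compare each result against `np.einsum`. The function `shared_dot` is a hypothetical numpy-only stand-in for the method under review, not the PR's implementation.

```python
import unittest

import numpy as np


def shared_dot(a, a_dims, b, b_dims):
    """Hypothetical stand-in: sum product over shared dimension names."""
    shared = [d for d in a_dims if d in b_dims]
    axes = ([a_dims.index(d) for d in shared],
            [b_dims.index(d) for d in shared])
    return np.tensordot(a, b, axes=axes)


class TestSharedDot(unittest.TestCase):
    def setUp(self):
        rng = np.random.RandomState(0)
        self.m = rng.rand(2, 3)  # dims: x, y
        self.v = rng.rand(3)     # dims: y
        self.w = rng.rand(4)     # dims: z

    def test_vector_vector(self):
        actual = shared_dot(self.v, ["y"], self.v, ["y"])
        np.testing.assert_allclose(actual, np.einsum("y,y->", self.v, self.v))

    def test_matrix_vector(self):
        actual = shared_dot(self.m, ["x", "y"], self.v, ["y"])
        np.testing.assert_allclose(actual, np.einsum("xy,y->x", self.m, self.v))

    def test_no_shared_dims(self):
        # With nothing to sum over, the result is an outer product.
        actual = shared_dot(self.v, ["y"], self.w, ["z"])
        np.testing.assert_allclose(actual, np.einsum("y,z->yz", self.v, self.w))
```

Each test targets a different branch of the dimension-matching logic, which is the kind of coverage the status check measures.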
68b711b to 6539fcd
Ready for review. |
            raise NotImplementedError('dot products are not yet supported '
                                      'with Dataset objects.')
        if not isinstance(other, DataArray):
            raise TypeError('dot only operates on DataArrays.')
I think it would be fine to remove lines 1413:1415, and change 1417 to something like:
    if not isinstance(other, DataArray):
        raise TypeError('dot only operates on DataArrays, got {}'.format(type(other)))
Yeah, could go either way here. I like raising NotImplementedError because it makes it more obvious to users that this behavior might change in the future. Besides, if we get lucky a user will notice the error and then implement the missing functionality themselves ;).
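The two options discussed here can be combined, as in the sketch below. The `Dataset` and `DataArray` classes are empty hypothetical stand-ins so the example runs on its own, and `error_name` is a helper invented for the demonstration.

```python
class Dataset:
    """Hypothetical stand-in for a collection-of-arrays type."""


class DataArray:
    """Hypothetical stand-in for a labeled array type."""

    def dot(self, other):
        if isinstance(other, Dataset):
            # Signals "not supported yet" rather than "never valid".
            raise NotImplementedError(
                'dot products are not yet supported with Dataset objects.')
        if not isinstance(other, DataArray):
            # Include the offending type so the message is actionable.
            raise TypeError(
                'dot only operates on DataArrays, got {}'.format(type(other)))
        return 'ok'  # placeholder for the real computation


def error_name(arg):
    """Return the name of the exception raised by dot, or None."""
    try:
        DataArray().dot(arg)
    except (NotImplementedError, TypeError) as err:
        return type(err).__name__
    return None
```

Keeping the `Dataset` check first means those callers see the "not yet supported" message instead of the generic type error.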
I made two very minor suggestions for improvement inline, but I think the code here looks ready! The only thing we need now is to hook up the documentation! Please add
…add tensordot to ops
6539fcd to ebee516
Alright, I think it's ready! |
Add tensordot to dataarray class also add its test to test_dataarray
OK, let's get this in. @deanpospisil thanks for your contribution! @jhamman I'm thinking it's probably worth issuing a 0.7.2 release shortly so we can get |
+1 on releasing 0.7.2. If we can get #782 in there too, that would be great. |
Resolving issue #723