Skip to content

Add __array_namespace__ to ArrayBox#647

Merged
agriyakhetarpal merged 3 commits intoHIPS:masterfrom
yaugenst-flex:yaugenst-flex/arraybox-namespace
Oct 15, 2024
Merged

Add __array_namespace__ to ArrayBox#647
agriyakhetarpal merged 3 commits intoHIPS:masterfrom
yaugenst-flex:yaugenst-flex/arraybox-namespace

Conversation

@yaugenst-flex
Copy link
Contributor

@yaugenst-flex yaugenst-flex commented Oct 10, 2024

Small change that adds the __array_namespace__ method from the Array API standard to autograd's ArrayBox.

This makes autograd (partially) compatible with libraries that check for array API compatibility, such as xarray, because xarray will coerce types that do not implement the array API into regular numpy arrays, see relevant discussion here.

The following code snippet:

import xarray as xr

import autograd.numpy as np
from autograd import grad


def f(x):
    data = xr.DataArray(x, dims=range(x.ndim)).data
    print(data)
    return np.sum(data)


x = np.random.uniform(-1, 1, 10)
grad(f)(x)

prints

[<autograd.numpy.numpy_boxes.ArrayBox object at 0x7f9465347640>
 <autograd.numpy.numpy_boxes.ArrayBox object at 0x7f94652f7000>
 <autograd.numpy.numpy_boxes.ArrayBox object at 0x7f94652f7c40>]

currently, because the traced ArrayBox is being coerced into a regular numpy array of dtype=object with ArrayBox elements. Differentiating through this works (with some limitations), but is very inefficient.

With the addition suggested in this PR, the snippet instead prints:

Autograd ArrayBox with value [1. 1. 1.]

i.e. data is now a "proper" ArrayBox.

I might also be interested in working on a full xarray wrapper in the future, but for now just this change would already help a lot.

@fjosw
Copy link
Collaborator

fjosw commented Oct 13, 2024

Hi @yaugenst-flex, thanks for contributing to autograd! I made the type hints compatible with python<3.10 and now all tests seem to pass 👍 Any objections against merging this @agriyakhetarpal @j-towns ?

@fjosw fjosw requested a review from agriyakhetarpal October 13, 2024 21:00
@j-towns
Copy link
Collaborator

j-towns commented Oct 14, 2024

No objections from me.

Copy link
Collaborator

@agriyakhetarpal agriyakhetarpal left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No objections from me either! Thanks, @yaugenst-flex!

@agriyakhetarpal agriyakhetarpal merged commit 317be57 into HIPS:master Oct 15, 2024
@agriyakhetarpal
Copy link
Collaborator

I've created https://github.com/HIPS/autograd/milestone/1 to track things we can add to a v1.8.0 release (or to see which things we can split up and add to a v1.7.1 release). This sounds like a useful improvement.

@agriyakhetarpal agriyakhetarpal added this to the v1.8.0 milestone Oct 15, 2024
@yaugenst-flex
Copy link
Contributor Author

This is great, thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants