RFC: add support for a tuple of axes in expand_dims #760

Open
izaid opened this issue Mar 10, 2024 · 7 comments
Labels
Needs Discussion Needs further discussion. RFC Request for comments. Feature requests and proposed changes. topic: Manipulation Array manipulation and transformation.
Milestone

Comments

@izaid

izaid commented Mar 10, 2024

Hello all! I raised this issue on array-api-compat earlier (data-apis/array-api-compat#105), but I think it might be more properly directed here.

In the array API, expand_dims supports only a single axis (https://data-apis.org/array-api/latest/API_specification/generated/array_api.expand_dims.html) rather than a tuple of axes. This differs from NumPy, CuPy, and JAX, which all accept a tuple of axes; PyTorch, however, supports only a single axis. I don't know why the array API supports only a single axis, but in practice it means that existing expand_dims calls stop working in many places when adopting the array API.

In practice, expand_dims is just a light wrapper for reshape, see https://github.com/numpy/numpy/blob/3b246c6488cf246d488bbe5726ca58dc26b6ea74/numpy/lib/_shape_base_impl.py#L594. But it's not great to force users to write their own version of expand_dims in every library now. Is the array API willing to update expand_dims to support a tuple of axes? If not, and if expand_dims will only support a single axis going forward, that effectively makes all users of expand_dims copy and paste the NumPy implementation.
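To make the "light wrapper for reshape" point concrete, here is a minimal sketch of a tuple-aware expand_dims built on reshape. It is modeled loosely on the NumPy approach linked above, not copied from it; the name expand_dims_via_reshape and the error messages are assumptions for illustration.

```python
import numpy as np

def expand_dims_via_reshape(x, axes):
    """Insert size-1 dimensions at `axes`, read as positions in the OUTPUT.

    Hypothetical helper, not part of NumPy or the array API.
    """
    if isinstance(axes, int):
        axes = (axes,)
    out_ndim = x.ndim + len(axes)

    # Normalize negative axes against the output rank, not the input rank.
    normalized = []
    for ax in axes:
        if not -out_ndim <= ax < out_ndim:
            raise ValueError(f"axis {ax} is out of bounds for {out_ndim} dimensions")
        normalized.append(ax % out_ndim)
    if len(set(normalized)) != len(normalized):
        raise ValueError("repeated axis")

    # Walk the output positions: emit 1 at requested axes,
    # otherwise consume the next input dimension in order.
    remaining = iter(x.shape)
    shape = [1 if i in normalized else next(remaining) for i in range(out_ndim)]
    return np.reshape(x, shape)

print(expand_dims_via_reshape(np.empty((2, 3)), (0, -1)).shape)
```

The point is only that the whole operation reduces to computing a target shape; any library with reshape could host the same few lines.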

@lucascolley pointed out to me that when expand_dims was added to the array API, only NumPy supported a tuple of axes. See #42. That was 4 years ago and the situation has changed, as above.

@asmeurer
Member

asmeurer commented Mar 11, 2024

It seems tuple support was omitted because torch doesn't support it (see #42). I found a few feature requests for it for torch.unsqueeze (the PyTorch equivalent to expand_dims): pytorch/pytorch#30702 and pytorch/pytorch#4692 (comment). It seems it was intentionally omitted there due to the ambiguity that arises from mixing negative and positive indices.

I agree this ambiguity is a potential concern. If we standardize this, we should somehow only require a subset of behavior that omits this ambiguity, e.g., by leaving the mixing of negative and nonnegative indices unspecified.

Consider for example:

>>> np.expand_dims(np.empty((2,)), (1, -1)).shape
(2, 1, 1)

The resulting shape has 1 in positions 1 and -1, but a result shape of (2, 1) would also satisfy this. I suppose one could argue that exactly len(axes) dimensions should be added.

But also consider

>>> np.expand_dims(np.empty((2, 3, 4, 5)), (3, -3)).shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/aaronmeurer/miniconda3/envs/array-apis/lib/python3.11/site-packages/numpy/lib/shape_base.py", line 597, in expand_dims
    axis = normalize_axis_tuple(axis, out_ndim)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmeurer/miniconda3/envs/array-apis/lib/python3.11/site-packages/numpy/core/numeric.py", line 1385, in normalize_axis_tuple
    raise ValueError('repeated axis')
ValueError: repeated axis

There's no way to insert 1 dimensions into (2, 3, 4, 5) so that they appear at indices 3 and -3.

Here's a small proof. There's no list length n for which removing the items at indices 3 and -3 leaves a list of length 4:
>>> def remove_indices(n, idxes):
...     """Return range(n) with `idxes` indices removed"""
...     x = list(range(n))
...     vals = [x[i] for i in idxes]
...     for v in vals:
...         try:
...             x.remove(v)
...         except ValueError: # Already removed
...             pass
...     return x
>>> [remove_indices(n, (-3, 3)) for n in range(4, 10)]
[[0, 2], [0, 1, 4], [0, 1, 2, 4, 5], [0, 1, 2, 5, 6], [0, 1, 2, 4, 6, 7], [0, 1, 2, 4, 5, 7, 8]]
>>> [len(remove_indices(n, (-3, 3))) for n in range(4, 10)]
[2, 3, 5, 5, 6, 7]

At the same time, if the goal of expand_dims is for the axes to refer to the dimensions after unsqueezing/expanding, then it's not exactly trivial to do it as a sequence of expand_dims, because if you apply the expansion in the wrong order you will break the position of previous dimensions (the correct logic is not hard, but it's the sort of thing that's easy to get wrong). So I think there is value in having native support for multiple axes.
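As a rough illustration of why the order matters, the sketch below emulates a tuple of axes with repeated single-axis calls. It assumes the axes refer to positions in the final array; expand_dims_sequential is a hypothetical name, and NumPy is used only for comparison.

```python
import numpy as np

def expand_dims_sequential(x, axes):
    """Emulate multi-axis expand_dims using only single-axis calls.

    Hypothetical helper: axes are normalized against the FINAL rank and
    applied in ascending order, so later insertions never shift earlier ones.
    """
    out_ndim = x.ndim + len(axes)
    if any(not -out_ndim <= ax < out_ndim for ax in axes):
        raise ValueError("axis out of bounds")
    normalized = sorted(ax % out_ndim for ax in axes)
    if len(set(normalized)) != len(normalized):
        raise ValueError("repeated axis")
    for ax in normalized:
        x = np.expand_dims(x, ax)
    return x

x = np.empty((2, 3))

# Applying the axes in the order given moves the first inserted dimension.
naive = np.expand_dims(np.expand_dims(x, 2), 0)
print(naive.shape)                              # 1s end up at 0 and 3

# Sorting the normalized axes first keeps both 1s where they were requested.
print(expand_dims_sequential(x, (2, 0)).shape)  # 1s at 0 and 2
print(np.expand_dims(x, (2, 0)).shape)          # NumPy's native tuple support
```

Ascending order works because inserting at a higher position never moves a dimension that sits at a lower one.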

@kgryte kgryte changed the title expand_dims for tuple of axes RFC: add support for a tuple of axes in expand_dims Apr 4, 2024
@kgryte kgryte added RFC Request for comments. Feature requests and proposed changes. topic: Manipulation Array manipulation and transformation. Needs Discussion Needs further discussion. labels Apr 4, 2024
@kgryte kgryte added this to the v2024 milestone Apr 4, 2024
@Micky774
Contributor

Regarding removing ambiguity, I think it would suffice to impose an order in which the dims are expanded, right? For example, if we specify "negative indices get resolved first", then, borrowing your example above, it could be resolved as

x = np.empty((2, 3, 4, 5))
xp.expand_dims(x, (3, -3)) == np.expand_dims(np.expand_dims(x, -3), 3)

so that the final output shape is (2, 3, 1, 1, 4, 5), which seems reasonable.

Still, I'm not sure if it is worth it since in the first place users could do it in a two-step expansion (albeit with some more thought), and the resolution order (+ or - indices first?) is rather arbitrary.

@asmeurer
Member

When you do repeated expand_dims, the inserted dimensions in the final shape won't necessarily be in the indices you initially specified (that's the whole point of this feature request, that you need a way to do them all at once). (2, 3, 1, 1, 4, 5) has 1s at indices 2 and -3 (remember 0-based indexing), because the 1 that was at index -3 got shifted over.
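A quick check of the shifted positions, using NumPy as a stand-in for any conforming library (the variable names are just for illustration):

```python
import numpy as np

x = np.empty((2, 3, 4, 5))

# "Negative indices first": expand at -3, then at 3.
y = np.expand_dims(np.expand_dims(x, -3), 3)
print(y.shape)

# Positions that actually hold a size-1 dimension in the result.
ones_at = [i for i, n in enumerate(y.shape) if n == 1]

# The requested axes (3, -3), normalized against the final rank.
requested = sorted({ax % y.ndim for ax in (3, -3)})

print(ones_at, requested)
```

Against the final rank of 6, both requested axes normalize to the same position, while the result carries its 1s at two adjacent positions.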

@lucascolley
Member

In case it affects whether this makes v2024, this is now available as https://data-apis.org/array-api-extra/generated/array_api_extra.expand_dims.html

@kgryte
Contributor

kgryte commented Jan 20, 2025

Given the ambiguity of supporting a tuple of axes in expand_dims, I wonder if there is room for an alternative API which avoids the ambiguity altogether. Namely,

def spread_dims(x: array, ndims: int, axes: Tuple[int, ...]) -> array

which expands the shape of an input array x to have ndims and where the current dimensions are explicitly mapped to a unique list of axes in the resulting array. All unspecified axes must be singleton dimensions.

This essentially flips the problem into one in which you specify where you want the non-singleton dimensions, rather than where you want to insert the singleton dimensions.
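A minimal sketch of how spread_dims might behave, written against NumPy. The proposal only gives a signature, so the details here are assumptions: axes[i] is taken to be the output position of input dimension i, negative axes are normalized against ndims, and a non-monotonic axes tuple is handled by permuting the input first.

```python
import numpy as np

def spread_dims(x, ndims, axes):
    """Place the dimensions of `x` at `axes` in an `ndims`-dimensional result.

    Sketch of the proposed API; semantics beyond the signature are assumed.
    """
    if len(axes) != x.ndim:
        raise ValueError("axes must name one output position per input dimension")
    if any(not -ndims <= ax < ndims for ax in axes):
        raise ValueError("axis out of bounds")
    normalized = [ax % ndims for ax in axes]
    if len(set(normalized)) != len(normalized):
        raise ValueError("repeated axis")

    # Reorder input dimensions so their target positions are ascending
    # (assumption: the proposal may instead require sorted axes).
    order = sorted(range(x.ndim), key=lambda i: normalized[i])
    x = np.transpose(x, order)

    # Every unspecified output position is a singleton dimension.
    shape = [1] * ndims
    for i, ax in enumerate(sorted(normalized)):
        shape[ax] = x.shape[i]
    return np.reshape(x, shape)

x = np.arange(6).reshape(2, 3)
print(spread_dims(x, 4, (1, 3)).shape)
print(spread_dims(x, 3, (2, 0)).shape)
```

Because the caller states the output rank and where the existing dimensions go, there is nothing left to normalize against an unknown rank, which is the ambiguity the proposal aims to avoid.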

@lucascolley
Member

That sounds like a good idea if anybody takes issue with the interpretation chosen for xpx.expand_dims. But the milestone can probably be bumped (or removed until that happens).

@kgryte kgryte modified the milestones: v2024, v2025 Jan 20, 2025
@jakevdp

jakevdp commented Feb 6, 2025

It seems to me the behavior of expand_dims with multiple axes could be specified without any of the ambiguities mentioned above. Basically, for y = expand_dims(x, axes) when axes is a tuple:

  • the output y must have dimension x.ndim + len(axes)
  • each entry of axes must be unique, with negative indices normalized in relation to y.ndim (not x.ndim), or else a ValueError is raised.
  • for each entry axis of axes, y.shape[axis] must be 1.
  • remaining dimensions of y consist of the dimensions of x in order.

This basically describes the existing behavior of NumPy, and handles all the ambiguities mentioned above:

  • if x has shape (2, 3, 4, 5), then expand_dims(x, (3, -3)) fails because the axes list has duplicate entries (they are normalized to (3, 3)); a ValueError should be raised.
  • if x has shape (2,), then expand_dims(x, (-1, 1)) will have shape (2, 1, 1), as the indices are normalized to (2, 1)

This behavior is semantically equivalent to calling expand_dims repeatedly with a single axis only when the axes tuple is normalized to nonnegative values using the final shape, is sorted, and contains no duplicates.
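The four rules can be checked mechanically against NumPy's existing behavior. The sketch below is one way to do that; check_rules is a hypothetical helper, and it leans on np.expand_dims itself to raise for duplicate or out-of-range axes.

```python
import numpy as np

def check_rules(x, axes):
    """Assert the four proposed rules against np.expand_dims; return y.shape."""
    y = np.expand_dims(x, axes)  # raises ValueError on repeated axes

    # Rule 1: the output rank grows by len(axes).
    assert y.ndim == x.ndim + len(axes)

    # Rule 2: axes are unique once normalized against y.ndim (not x.ndim).
    normalized = [ax % y.ndim for ax in axes]
    assert len(set(normalized)) == len(normalized)

    # Rule 3: every requested axis is a singleton in the output.
    assert all(y.shape[ax] == 1 for ax in axes)

    # Rule 4: the remaining dimensions are those of x, in order.
    rest = tuple(n for i, n in enumerate(y.shape) if i not in normalized)
    assert rest == x.shape
    return y.shape

print(check_rules(np.empty((2,)), (-1, 1)))

try:
    check_rules(np.empty((2, 3, 4, 5)), (3, -3))
except ValueError as exc:
    print("ValueError:", exc)
```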

Projects
Status: Stage 0
Development

No branches or pull requests

6 participants