Add unstack? #487

Closed
shoyer opened this issue Sep 23, 2022 · 5 comments · Fixed by #604
Labels
API extension Adds new functions or objects to the API.

@shoyer
Contributor

shoyer commented Sep 23, 2022

TensorFlow and PyTorch have unstack() functions (PyTorch calls it unbind()) for converting an array into a Python sequence of arrays, unpacked along a dimension:

>>> import torch
>>> torch.unbind(torch.tensor([[1, 2, 3],
...                            [4, 5, 6],
...                            [7, 8, 9]]))
(tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9]))

I think this could make sense in the API standard, especially because, unlike NumPy arrays, arrays in the standard cannot be iterated over along their first axis: #481

rgommers added the API extension label (Adds new functions or objects to the API.) on Sep 25, 2022
@rgommers
Member

This does seem useful indeed. NumPy kinda has this as well, via all the split APIs (split, array_split, hsplit, dsplit, vsplit):

>>> import numpy as np
>>> y = np.arange(9).reshape(3, 3) + 1
>>> y
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])
>>> np.split(y, y.shape[0])
[array([[1, 2, 3]]), array([[4, 5, 6]]), array([[7, 8, 9]])]

There are two things wrong with np.split for this purpose:

  • The more general "split into equal chunks" behavior isn't very useful here (each piece keeps the split axis, with length 1),
  • It returns a list rather than a tuple
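To make both points concrete, here is a small NumPy sketch: the pieces from np.split keep a length-1 axis and come back in a list, so getting an unstack-like result needs an extra squeeze plus a tuple conversion.

```python
import numpy as np

y = np.arange(9).reshape(3, 3) + 1

# np.split keeps the split axis: three pieces of shape (1, 3), returned as a list
pieces = np.split(y, y.shape[0])
print(type(pieces).__name__, [p.shape for p in pieces])  # list [(1, 3), (1, 3), (1, 3)]

# Dropping the length-1 axis and converting to a tuple gives unstack-like output
unstacked = tuple(np.squeeze(p, axis=0) for p in pieces)
print([u.shape for u in unstacked])  # [(3,), (3,), (3,)]
```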

I agree that unstack is the most logical name here, to match stack.

@seberg
Contributor

seberg commented Sep 26, 2022

Since NumPy iterates along the first axis, I think the "obvious" solution is tuple(arr) or list(arr). These behave differently from splitting: they remove/unpack the first axis completely, rather than splitting/chunking it up.

As mentioned on the other issue, I also like iteraxis (or similar). The difference is that the result is an iterator rather than a tuple/list.
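A lazy variant in the spirit of iteraxis could be a plain generator. The sketch below uses a hypothetical helper (NumPy has no iteraxis function), restricted to axis 0 and using NumPy as a stand-in:

```python
import numpy as np

def iteraxis(x):
    """Hypothetical iteraxis: lazily yield x[0, ...], x[1, ...], ... along axis 0."""
    for i in range(x.shape[0]):
        yield x[i, ...]

x = np.arange(6).reshape(2, 3)
it = iteraxis(x)
first = next(it)  # only the first sub-array has been produced so far
rest = tuple(it)  # materialize the remaining sub-arrays; `it` is now exhausted
```

The trade-off is the usual one for iterators: nothing is built up front, but the result can be consumed only once and has no len().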

The other addition that I like is an axis= keyword. For a single axis this is not particularly important (you can also call moveaxis first, although axis= is still convenient, I think).
The main advantage is making it easier to iterate over/work with multiple axes.
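A sketch of both points with NumPy. The single-axis case is moveaxis followed by NumPy's iteration over the first axis (which the standard does not guarantee). For multiple axes, iter_axes is a hypothetical helper, not an existing API: it moves the requested axes to the front and flattens them into one leading axis.

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)

# Single axis: move axis 1 to the front, then iterate over the new first axis
cols = tuple(np.moveaxis(x, 1, 0))
print(len(cols), cols[0].shape)  # 3 (2, 4)

def iter_axes(x, axes):
    """Hypothetical helper: iterate over every combination of positions in `axes`."""
    front = tuple(range(len(axes)))
    moved = np.moveaxis(x, axes, front)
    # Merge the moved axes into one leading axis, then iterate over it
    return iter(moved.reshape((-1,) + moved.shape[len(axes):]))

pieces = list(iter_axes(x, (0, 2)))
print(len(pieces), pieces[0].shape)  # 8 (3,)
```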

Another related thing, from the iterator point of view, is the np.ndenumerate API, which tracks an index. That name might also suggest nditer as a good name (NumPy already has that, but it is used for a far too complicated/specific API).

Tracking an index could also be an optional argument, although I suppose at this point the array API may want to go with a minimal solution (i.e., wait for library need).
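To illustrate the distinction: np.ndenumerate tracks a full multi-index and yields individual elements, whereas index tracking for axis iteration would pair each position along the axis with a whole sub-array. The second half below is a hypothetical illustration written as a list comprehension, not a proposed API.

```python
import numpy as np

x = np.array([[1, 2], [3, 4]])

# Element level: np.ndenumerate yields (multi-index, element) pairs
elements = [(idx, int(v)) for idx, v in np.ndenumerate(x)]
print(elements)  # [((0, 0), 1), ((0, 1), 2), ((1, 0), 3), ((1, 1), 4)]

# Axis level (hypothetical): pair each position along axis 0 with its sub-array
rows = [(i, x[i, ...]) for i in range(x.shape[0])]
for i, row in rows:
    print(i, row.tolist())  # prints "0 [1, 2]", then "1 [3, 4]"
```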

@rgommers
Member

rgommers commented Sep 27, 2022

> Since NumPy iterates along the first axis, I think the "obvious" solution is tuple(arr) or list(arr).

That's kind of circular reasoning: for a new user who doesn't already know that it splits along the first axis, I don't think it is possible to predict what that will do.

> As mentioned on the other issue, I also like iteraxis (or similar). The difference is that the result is an iterator rather than a tuple/list.

This is a decent alternative I think. There are some pros and cons to returning an iterator rather than a tuple, but it's very similar to unstack otherwise. unstack may still be a good name even if it returns an iterator.

This also brings up the question whether or not stack should accept an iterator as input.

> Another related thing, from the iterator point of view, is the np.ndenumerate API, which tracks an index. That name might also suggest nditer as a good name (NumPy already has that, but it is used for a far too complicated/specific API).

> Tracking an index could also be an optional argument, although I suppose at this point the array API may want to go with a minimal solution (i.e., wait for library need).

These APIs are rarely used today and, imho, overly complicated. They seem like a poor fit for the standard.

@seberg
Contributor

seberg commented Oct 6, 2022

Commenting here, since Leo brought this up and it struck me as a possible usability issue: unstack would return views on libraries that support them, but copies on others (or copy-on-write?).
Are the use-cases we have in mind hampered by not knowing what the actual implementation will do?
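For reference, this is NumPy's behavior: integer indexing along an axis (as in a comprehension-based unstack) returns views, so writing into one piece also modifies the original array. A library that returns copies would behave differently, which is why portable code should not rely on this.

```python
import numpy as np

x = np.arange(6).reshape(2, 3)
parts = tuple(x[i, ...] for i in range(x.shape[0]))

# In NumPy, basic integer indexing returns views that share memory with x
print(np.shares_memory(parts[0], x))  # True

parts[0][0] = 100
print(x[0, 0])  # 100: the write is visible through x
```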

@rgommers
Member

rgommers commented Oct 7, 2022

As #481 (comment) discusses, unstack is a convenience API at this point. It also helps for design symmetry, since we have stack.

The current way of doing this is with a comprehension:

>>> import numpy as np
>>> x = np.arange(6).reshape((2, 3))
>>> axis = 0
>>> tuple(x[i, ...] for i in range(x.shape[axis]))
(array([0, 1, 2]), array([3, 4, 5]))
>>> axis = 1
>>> tuple(x[:, i, ...] for i in range(x.shape[axis]))
(array([0, 3]), array([1, 4]), array([2, 5]))
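For comparison, a minimal sketch of an axis-aware version that uses only integer, slice and Ellipsis indexing, with NumPy as a stand-in. The name and signature (positional-only x, keyword-only axis) are illustrative assumptions, not a settled spec.

```python
import numpy as np

def unstack(x, /, *, axis=0):
    """Illustrative unstack: split x into a tuple of (ndim - 1)-d arrays along axis."""
    ndim = x.ndim
    if not -ndim <= axis < ndim:
        raise IndexError(f"axis {axis} is out of bounds for array of dimension {ndim}")
    axis %= ndim  # normalize a negative axis
    # Index as x[:, ..., :, i, ...] with `axis` leading full slices
    lead = (slice(None),) * axis
    return tuple(x[lead + (i, ...)] for i in range(x.shape[axis]))

x = np.arange(6).reshape((2, 3))
print(unstack(x, axis=1))  # (array([0, 3]), array([1, 4]), array([2, 5]))
```

Re-stacking along the same axis, e.g. np.stack(unstack(x, axis=1), axis=1), recovers x, which is the symmetry with stack.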

Doing this generically with an axis keyword is not a trivial convenience function, though. The JAX unstack implementation gives a good idea of what is involved:

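# Excerpt from JAX source; not standalone, since it relies on jax's lax module
# and on internal helpers such as core._canonicalize_dimension and slice_in_dim.

# Unstack along axis 0 only: index each position and drop the indexed axis.
def _unstack(x):
  return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]

def index_in_dim(operand: Array, index: int, axis: int = 0,
                 keepdims: bool = True) -> Array:
  """Convenience wrapper around slice to perform int indexing."""
  index, axis = core._canonicalize_dimension(index), int(axis)
  axis_size = operand.shape[axis]
  # Support a negative index by wrapping it once
  wrapped_index = index + axis_size if index < 0 else index
  if not 0 <= wrapped_index < axis_size:
    msg = 'index {} is out of bounds for axis {} with size {}'
    raise IndexError(msg.format(index, axis, axis_size))
  # Take a length-1 slice along `axis`...
  result = slice_in_dim(operand, wrapped_index, wrapped_index + 1, 1, axis)
  if keepdims:
    return result
  else:
    # ...and squeeze that axis away when keepdims=False
    return lax.squeeze(result, (axis,))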

> unstack may still be a good name even if it returns an iterator.

@seberg said yesterday that he is fine with unstack. iteraxis is still potentially interesting as a way to iterate over axes, but is probably a separate API.

So it looks like adding unstack is fine to move ahead with, unless there are more concerns?

> Are the use-cases we have in mind hampered by not knowing what the actual implementation will do?

This should not be specific to unstack. A view isn't a separate concept in the standard, and it should remain that way. There has been a lot of previous discussion on this repo about views and mutability, which resulted in https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html.

steff456 mentioned this issue on Feb 21, 2023
rgommers added this to the v2023 milestone on Mar 8, 2023