Add unstack #487
Comments
This does seem useful indeed. NumPy kinda has this as well, with all the `split` functions:

```python
>>> y = np.arange(9).reshape(3, 3) + 1
>>> y
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])
>>> np.split(y, y.shape[0])
[array([[1, 2, 3]]), array([[4, 5, 6]]), array([[7, 8, 9]])]
```

There are two things wrong with this, though:
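Note that each piece in that `np.split` result still carries the split axis (shape `(1, 3)` rather than `(3,)`) and comes back in a list. A minimal sketch of squeezing the pieces down to what an unstack-style call would return (the explicit `squeeze` step here is just an illustration, not a proposed API):

```python
import numpy as np

y = np.arange(9).reshape(3, 3) + 1

# np.split keeps the split axis: every piece still has shape (1, 3)
pieces = np.split(y, y.shape[0])
print([p.shape for p in pieces])   # [(1, 3), (1, 3), (1, 3)]

# an unstack-style result would drop that axis instead
unstacked = tuple(np.squeeze(p, axis=0) for p in pieces)
print(unstacked)
# (array([1, 2, 3]), array([4, 5, 6]), array([7, 8, 9]))
```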
I agree that
Since NumPy uses iteration along the first axis, I think the "obvious" solution is
As mentioned on the other issue, I also like
The other addition that I like is to add
Another related thing, if we look from the iterator point of view, is the
Tracking an index could also be an optional argument. Although I suppose at this point the array API may want to go with a minimal solution (i.e. wait for library need).
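For illustration, the index tracking mentioned above can already be expressed with plain `enumerate` over an unstacked result; the comprehension below stands in for a hypothetical `unstack` along the first axis:

```python
import numpy as np

x = np.arange(6).reshape(2, 3)

# stand-in for a hypothetical unstack(x): one sub-array per position along axis 0
pieces = tuple(x[i, ...] for i in range(x.shape[0]))

# index tracking via ordinary enumerate, no extra argument needed
for i, piece in enumerate(pieces):
    print(i, piece)
# 0 [0 1 2]
# 1 [3 4 5]
```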
That's kind of circular reasoning: for a new user who doesn't already know it will split along the first axis, I don't think it is possible to predict what that will do.
This is a decent alternative, I think. There are some pros and cons to returning an iterator rather than a tuple, but it's very similar to
This also brings up the question of whether or not
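A rough sketch of the two return styles being weighed here, in plain NumPy (the function names and signatures are placeholders, not proposed API):

```python
import numpy as np

def unstack_tuple(x, axis=0):
    # eager: materialize every sub-array up front
    moved = np.moveaxis(x, axis, 0)
    return tuple(moved[i] for i in range(moved.shape[0]))

def unstack_lazy(x, axis=0):
    # lazy: yield sub-arrays one at a time
    moved = np.moveaxis(x, axis, 0)
    for piece in moved:
        yield piece

x = np.arange(6).reshape(2, 3)
print(unstack_tuple(x, axis=1))       # (array([0, 3]), array([1, 4]), array([2, 5]))
print(list(unstack_lazy(x, axis=1)))  # same pieces, produced on demand
```

The lazy form behaves much like ordinary iteration over the array's first axis, which is presumably why the two options feel so close.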
These APIs are rarely used today, and imho overly complicated. They seem like a poor fit for the standard.
Commenting here, since Leo brought that up and it struck me as a possible usability issue.
As #481 (comment) discusses, the current way of doing this is with a comprehension:

```python
>>> x = np.arange(6).reshape((2, 3))
>>> axis = 0
>>> tuple(x[i, ...] for i in range(x.shape[axis]))
(array([0, 1, 2]), array([3, 4, 5]))
>>> axis = 1
>>> tuple(x[:, i, ...] for i in range(x.shape[axis]))
(array([0, 3]), array([1, 4]), array([2, 5]))
```

Doing this in a generic way with an axis argument takes more machinery; JAX, for example, has helpers along these lines:

```python
def _unstack(x):
  return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]

def index_in_dim(operand: Array, index: int, axis: int = 0,
                 keepdims: bool = True) -> Array:
  """Convenience wrapper around slice to perform int indexing."""
  index, axis = core._canonicalize_dimension(index), int(axis)
  axis_size = operand.shape[axis]
  wrapped_index = index + axis_size if index < 0 else index
  if not 0 <= wrapped_index < axis_size:
    msg = 'index {} is out of bounds for axis {} with size {}'
    raise IndexError(msg.format(index, axis, axis_size))
  result = slice_in_dim(operand, wrapped_index, wrapped_index + 1, 1, axis)
  if keepdims:
    return result
  else:
    return lax.squeeze(result, (axis,))
```
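For comparison, a rough pure-NumPy equivalent of the same slice-then-squeeze approach with an explicit axis argument (the `unstack` name and signature here are only illustrative, not JAX or standard API):

```python
import numpy as np

def unstack(x, axis=0):
    # take a length-1 slice at each position along `axis`, then squeeze that axis away
    axis = axis % x.ndim   # support negative axis values
    return tuple(
        np.squeeze(np.take(x, [i], axis=axis), axis=axis)
        for i in range(x.shape[axis])
    )

x = np.arange(6).reshape((2, 3))
print(unstack(x, axis=0))   # (array([0, 1, 2]), array([3, 4, 5]))
print(unstack(x, axis=1))   # (array([0, 3]), array([1, 4]), array([2, 5]))
```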
@seberg said yesterday that he is fine with
So it looks like adding
This should not be specific to
TensorFlow and PyTorch have `unstack()` functions (Torch calls it `unbind()`) for converting an array into a Python sequence of arrays, unpacked along a dimension. I think this could potentially make sense in the API standard, especially because, unlike in NumPy, you cannot iterate over the first axis of arrays: #481
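For concreteness, the PyTorch spelling of this (assuming `torch` is installed; `tf.unstack` behaves analogously but returns a list):

```python
import torch

t = torch.arange(6).reshape(2, 3)

print(torch.unbind(t, dim=0))
# (tensor([0, 1, 2]), tensor([3, 4, 5]))

print(torch.unbind(t, dim=1))
# (tensor([0, 3]), tensor([1, 4]), tensor([2, 5]))
```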