Skip to content

Commit f507dc2

Browse files
committed
Add new unstack function to numpy/array_api namespaces
1 parent 2c85ca6 commit f507dc2

File tree

8 files changed

+32
-0
lines changed

8 files changed

+32
-0
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
88

99
## jax 0.4.27
1010

11+
* New Functionality
12+
* Added {func}`jax.numpy.unstack`, following the addition of this function in
13+
the array API 2023 standard, soon to be adopted by NumPy.
14+
1115
* Changes
1216
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
1317
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover

jax/_src/numpy/lax_numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1887,6 +1887,12 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike],
18871887
new_arrays.append(expand_dims(a, axis))
18881888
return concatenate(new_arrays, axis=axis, dtype=dtype)
18891889

1890+
@util.implements(getattr(np, 'unstack', None))
1891+
@partial(jit, static_argnames="axis")
1892+
def unstack(x: np.ndarray | Array, /, *, axis: int = 0) -> tuple[Array, ...]:
1893+
util.check_arraylike("unstack", x)
1894+
return tuple(moveaxis(x, axis, 0))
1895+
18901896
@util.implements(np.tile)
18911897
def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
18921898
util.check_arraylike("tile", A)

jax/experimental/array_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@
173173
roll as roll,
174174
squeeze as squeeze,
175175
stack as stack,
176+
unstack as unstack,
176177
)
177178

178179
from jax.experimental.array_api._searching_functions import (

jax/experimental/array_api/_manipulation_functions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,7 @@ def stack(arrays: tuple[Array, ...] | list[Array], /, *, axis: int = 0) -> Array
7272
"""Joins a sequence of arrays along a new axis."""
7373
dtype = _result_type(*arrays)
7474
return jax.numpy.stack(arrays, axis=axis, dtype=dtype)
75+
76+
def unstack(x: Array, /, *, axis: int = 0) -> tuple[Array, ...]:
77+
"""Splits an array in a sequence of arrays along the given axis."""
78+
return jax.numpy.unstack(x, axis=axis)

jax/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@
253253
unpackbits as unpackbits,
254254
unravel_index as unravel_index,
255255
unsignedinteger as unsignedinteger,
256+
unstack as unstack,
256257
unwrap as unwrap,
257258
vander as vander,
258259
vdot as vdot,

jax/numpy/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,7 @@ def unpackbits(
859859
) -> Array: ...
860860
def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: ...
861861
unsignedinteger = _np.unsignedinteger
862+
def unstack(x: _np.ndarray | Array, /, *, axis: int = ...) -> tuple[Array, ...]: ...
862863
def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = ...,
863864
axis: int = ..., period: ArrayLike = ...) -> Array: ...
864865
def vander(

tests/array_api_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@
163163
'unique_counts',
164164
'unique_inverse',
165165
'unique_values',
166+
'unstack',
166167
'var',
167168
'vecdot',
168169
'where',

tests/lax_numpy_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,20 @@ def f():
173173
for a in out]
174174
return f
175175

176+
177+
@jtu.sample_product(
178+
[dict(shape=shape, axis=axis)
179+
for shape in all_shapes
180+
for axis in list(range(-len(shape), len(shape)))],
181+
dtype=all_dtypes,
182+
)
183+
def testUnstack(self, shape, axis, dtype):
184+
rng = jtu.rand_default(self.rng())
185+
x = rng(shape, dtype)
186+
y = jnp.array(jnp.unstack(x, axis=axis))
187+
self.assertArraysEqual(jnp.moveaxis(y, 0, axis), x)
188+
189+
176190
@parameterized.parameters(
177191
[dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32,
178192
jnp.uint64, jnp.int8, jnp.int16, jnp.int32, jnp.int64,

0 commit comments

Comments
 (0)