diff --git a/spec/draft/API_specification/manipulation_functions.rst b/spec/draft/API_specification/manipulation_functions.rst index fc0e752b9..7eb7fa8b0 100644 --- a/spec/draft/API_specification/manipulation_functions.rst +++ b/spec/draft/API_specification/manipulation_functions.rst @@ -23,6 +23,7 @@ Objects in API concat expand_dims flip + moveaxis permute_dims reshape roll diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 74b62a689..2bc929134 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -4,6 +4,7 @@ "concat", "expand_dims", "flip", + "moveaxis", "permute_dims", "reshape", "roll", @@ -114,6 +115,31 @@ def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> """ +def moveaxis( + x: array, + source: Union[int, Tuple[int, ...]], + destination: Union[int, Tuple[int, ...]], + /, +) -> array: + """ + Moves array axes (dimensions) to new positions, while leaving other axes in their original positions. + + Parameters + ---------- + x: array + input array. + source: Union[int, Tuple[int, ...]] + Axes to move. Provided axes must be unique. If ``x`` has rank (i.e, number of dimensions) ``N``, a valid axis must reside on the half-open interval ``[-N, N)``. + destination: Union[int, Tuple[int, ...]] + indices defining the desired positions for each respective ``source`` axis index. Provided indices must be unique. If ``x`` has rank (i.e, number of dimensions) ``N``, a valid axis must reside on the half-open interval ``[-N, N)``. + + Returns + ------- + out: array + an array containing reordered axes. The returned array must have the same data type as ``x``. + """ + + def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: """ Permutes the axes (dimensions) of an array ``x``.