Skip to content

test_fft.py #196

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Nov 10, 2023
2 changes: 1 addition & 1 deletion array_api_tests/_array_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __repr__(self):
_constants = ["e", "inf", "nan", "pi"]
_funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]
_funcs += ["take", "isdtype", "conj", "imag", "real"] # TODO: bump spec and update array-api-tests to new spec layout
_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS
_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS + ["fft"]

for attr in _top_level_attrs:
try:
Expand Down
9 changes: 8 additions & 1 deletion array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
integers, just, lists, none, one_of,
sampled_from, shared)

from . import _array_module as xp
from . import _array_module as xp, api_version
from . import dtype_helpers as dh
from . import shape_helpers as sh
from . import xps
Expand Down Expand Up @@ -141,6 +141,13 @@ def oneway_broadcastable_shapes(draw) -> SearchStrategy[OnewayBroadcastableShape
return OnewayBroadcastableShapes(input_shape, result_shape)


def all_floating_dtypes() -> SearchStrategy[DataType]:
strat = xps.floating_dtypes()
if api_version >= "2022.12":
strat |= xps.complex_dtypes()
return strat


# shared() allows us to draw either the function or the function name and they
# will both correspond to the same function.

Expand Down
13 changes: 13 additions & 0 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def assert_dtype(
>>> assert_dtype('sum', in_dtype=x, out_dtype=out.dtype, expected=default_int)

"""
__tracebackhide__ = True
in_dtypes = in_dtype if isinstance(in_dtype, Sequence) and not isinstance(in_dtype, str) else [in_dtype]
f_in_dtypes = dh.fmt_types(tuple(in_dtypes))
f_out_dtype = dh.dtype_to_name[out_dtype]
Expand Down Expand Up @@ -149,6 +150,7 @@ def assert_kw_dtype(
>>> assert_kw_dtype('ones', kw_dtype=kw['dtype'], out_dtype=out.dtype)

"""
__tracebackhide__ = True
f_kw_dtype = dh.dtype_to_name[kw_dtype]
f_out_dtype = dh.dtype_to_name[out_dtype]
msg = (
Expand All @@ -166,6 +168,7 @@ def assert_default_float(func_name: str, out_dtype: DataType):
>>> assert_default_float('ones', out.dtype)

"""
__tracebackhide__ = True
f_dtype = dh.dtype_to_name[out_dtype]
f_default = dh.dtype_to_name[dh.default_float]
msg = (
Expand All @@ -183,6 +186,7 @@ def assert_default_complex(func_name: str, out_dtype: DataType):
>>> assert_default_complex('asarray', out.dtype)

"""
__tracebackhide__ = True
f_dtype = dh.dtype_to_name[out_dtype]
f_default = dh.dtype_to_name[dh.default_complex]
msg = (
Expand All @@ -200,6 +204,7 @@ def assert_default_int(func_name: str, out_dtype: DataType):
>>> assert_default_int('full', out.dtype)

"""
__tracebackhide__ = True
f_dtype = dh.dtype_to_name[out_dtype]
f_default = dh.dtype_to_name[dh.default_int]
msg = (
Expand All @@ -217,6 +222,7 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty
>>> assert_default_int('argmax', out.dtype)

"""
__tracebackhide__ = True
f_dtype = dh.dtype_to_name[out_dtype]
msg = (
f"{repr_name}={f_dtype}, should be the default index dtype, "
Expand All @@ -240,6 +246,7 @@ def assert_shape(
>>> assert_shape('ones', out_shape=out.shape, expected=(3, 3, 3))

"""
__tracebackhide__ = True
if isinstance(out_shape, int):
out_shape = (out_shape,)
if isinstance(expected, int):
Expand Down Expand Up @@ -273,6 +280,7 @@ def assert_result_shape(
>>> assert out.shape == (3, 3)

"""
__tracebackhide__ = True
if expected is None:
expected = sh.broadcast_shapes(*in_shapes)
f_in_shapes = " . ".join(str(s) for s in in_shapes)
Expand Down Expand Up @@ -307,6 +315,7 @@ def assert_keepdimable_shape(
>>> assert out2.shape == (1, 1)

"""
__tracebackhide__ = True
if keepdims:
shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape))
else:
Expand Down Expand Up @@ -337,6 +346,7 @@ def assert_0d_equals(
>>> assert res[0] == x[0]

"""
__tracebackhide__ = True
msg = (
f"{out_repr}={out_val}, but should be {x_repr}={x_val} "
f"[{func_name}({fmt_kw(kw)})]"
Expand Down Expand Up @@ -369,6 +379,7 @@ def assert_scalar_equals(
>>> assert int(out) == 5

"""
__tracebackhide__ = True
repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
f_func = f"{func_name}({fmt_kw(kw)})"
if type_ in [bool, int]:
Expand Down Expand Up @@ -401,6 +412,7 @@ def assert_fill(
>>> assert xp.all(out == 42)

"""
__tracebackhide__ = True
msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}"
if cmath.isnan(fill_value):
assert xp.all(xp.isnan(out)), msg
Expand Down Expand Up @@ -443,6 +455,7 @@ def assert_array_elements(
>>> assert xp.all(out == x)

"""
__tracebackhide__ = True
dh.result_type(out.dtype, expected.dtype) # sanity check
assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
f_func = f"[{func_name}({fmt_kw(kw)})]"
Expand Down
6 changes: 4 additions & 2 deletions array_api_tests/shape_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from itertools import product
from typing import Iterator, List, Optional, Tuple, Union
from typing import Iterator, List, Optional, Sequence, Tuple, Union

from ndindex import iter_indices as _iter_indices

Expand Down Expand Up @@ -66,10 +66,12 @@ def broadcast_shapes(*shapes: Shape):


def normalise_axis(
axis: Optional[Union[int, Tuple[int, ...]]], ndim: int
axis: Optional[Union[int, Sequence[int]]], ndim: int
) -> Tuple[int, ...]:
if axis is None:
return tuple(range(ndim))
elif isinstance(axis, Sequence) and not isinstance(axis, tuple):
axis = tuple(axis)
axes = axis if isinstance(axis, tuple) else (axis,)
axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes)
return axes
Expand Down
2 changes: 1 addition & 1 deletion array_api_tests/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
all_funcs.extend(funcs)
name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs}

EXTENSIONS: str = ["linalg"]
EXTENSIONS: List[str] = ["linalg"] # TODO: add "fft" once stubs available
extension_to_funcs: Dict[str, List[FunctionType]] = {}
for ext in EXTENSIONS:
mod = name_to_mod[ext]
Expand Down
Loading