From da8e3748f868e581134fa0e4f44b3491618d40d3 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 6 Sep 2023 11:05:30 +0100 Subject: [PATCH 01/13] Rudimentary `test_fft` --- array_api_tests/_array_module.py | 2 +- array_api_tests/stubs.py | 2 +- array_api_tests/test_fft.py | 38 ++++++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) create mode 100644 array_api_tests/test_fft.py diff --git a/array_api_tests/_array_module.py b/array_api_tests/_array_module.py index 899a2591..8a7c7887 100644 --- a/array_api_tests/_array_module.py +++ b/array_api_tests/_array_module.py @@ -63,7 +63,7 @@ def __repr__(self): _constants = ["e", "inf", "nan", "pi"] _funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs] _funcs += ["take", "isdtype", "conj", "imag", "real"] # TODO: bump spec and update array-api-tests to new spec layout -_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS +_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS + ["fft"] for attr in _top_level_attrs: try: diff --git a/array_api_tests/stubs.py b/array_api_tests/stubs.py index 69ec886d..0134765b 100644 --- a/array_api_tests/stubs.py +++ b/array_api_tests/stubs.py @@ -52,7 +52,7 @@ all_funcs.extend(funcs) name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs} -EXTENSIONS: str = ["linalg"] +EXTENSIONS: List[str] = ["linalg"] # TODO: add "fft" once stubs available extension_to_funcs: Dict[str, List[FunctionType]] = {} for ext in EXTENSIONS: mod = name_to_mod[ext] diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py new file mode 100644 index 00000000..4c6acd77 --- /dev/null +++ b/array_api_tests/test_fft.py @@ -0,0 +1,38 @@ +import math + +import pytest +from hypothesis import given + +from array_api_tests.typing import DataType + +from . import _array_module as xp +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph +from . import xps + +pytestmark = [ + pytest.mark.ci, + pytest.mark.xp_extension("fft"), + pytest.mark.min_version("draft"), +] + + +fft_shapes_strat = hh.shapes(min_dims=1).filter(lambda s: math.prod(s) > 1) + + +def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType): + if in_dtype == xp.float32: + expected = xp.complex64 + else: + assert in_dtype == xp.float64 # sanity check + expected = xp.complex128 + ph.assert_dtype( + func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected + ) + + +@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat)) +def test_fft(x): + out = xp.fft.fft(x) + assert_fft_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("fft", out_shape=out.shape, expected=x.shape) From 7c315973a6aa5e06ba09eceec127ac6ea06550bb Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 6 Sep 2023 13:39:10 +0100 Subject: [PATCH 02/13] Rudimentary tests for `ifft`, `fftn` and `ifftn` --- array_api_tests/test_fft.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py index 4c6acd77..823d3c3d 100644 --- a/array_api_tests/test_fft.py +++ b/array_api_tests/test_fft.py @@ -36,3 +36,24 @@ def test_fft(x): out = xp.fft.fft(x) assert_fft_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("fft", out_shape=out.shape, expected=x.shape) + + +@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat)) +def test_ifft(x): + out = xp.fft.ifft(x) + assert_fft_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("ifft", out_shape=out.shape, expected=x.shape) + + +@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat)) +def test_fftn(x): + out = xp.fft.fftn(x) + assert_fft_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("fftn", out_shape=out.shape, expected=x.shape) + + +@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat)) +def test_ifftn(x): + out = xp.fft.ifftn(x) + assert_fft_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("ifftn", out_shape=out.shape, expected=x.shape) From 22f9815aa1f1cad47285c4de636be2ff23c79152 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 6 Sep 2023 14:36:23 +0100 Subject: [PATCH 03/13] Move `all_floating_dtypes()` into `hypothesis_helpers.py` And use it in `test_fft.py` --- array_api_tests/hypothesis_helpers.py | 9 +++- array_api_tests/test_fft.py | 15 ++++--- ...est_operators_and_elementwise_functions.py | 45 ++++++++----------- 3 files changed, 36 insertions(+), 33 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 31f1e153..c4235ba1 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -11,7 +11,7 @@ integers, just, lists, none, one_of, sampled_from, shared) -from . import _array_module as xp +from . import _array_module as xp, api_version from . import dtype_helpers as dh from . import shape_helpers as sh from . import xps @@ -141,6 +141,13 @@ def oneway_broadcastable_shapes(draw) -> SearchStrategy[OnewayBroadcastableShape return OnewayBroadcastableShapes(input_shape, result_shape) +def all_floating_dtypes() -> SearchStrategy[DataType]: + strat = xps.floating_dtypes() + if api_version >= "2022.12": + strat |= xps.complex_dtypes() + return strat + + # shared() allows us to draw either the function or the function name and they # will both correspond to the same function. diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py index 823d3c3d..b7a82589 100644 --- a/array_api_tests/test_fft.py +++ b/array_api_tests/test_fft.py @@ -6,6 +6,7 @@ from array_api_tests.typing import DataType from . import _array_module as xp +from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps @@ -23,36 +24,38 @@ def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType): if in_dtype == xp.float32: expected = xp.complex64 - else: - assert in_dtype == xp.float64 # sanity check + elif in_dtype == xp.float64: expected = xp.complex128 + else: + assert dh.is_float_dtype(in_dtype) # sanity check + expected = in_dtype ph.assert_dtype( func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected ) -@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat)) +@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat)) def test_fft(x): out = xp.fft.fft(x) assert_fft_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("fft", out_shape=out.shape, expected=x.shape) -@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat)) +@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat)) def test_ifft(x): out = xp.fft.ifft(x) assert_fft_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("ifft", out_shape=out.shape, expected=x.shape) -@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat)) +@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat)) def test_fftn(x): out = xp.fft.fftn(x) assert_fft_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("fftn", out_shape=out.shape, expected=x.shape) -@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat)) +@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat)) def test_ifftn(x): out = xp.fft.ifftn(x) assert_fft_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 4d803bb0..39905456 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -33,13 +33,6 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: return xps.boolean_dtypes() | all_integer_dtypes() -def all_floating_dtypes() -> st.SearchStrategy[DataType]: - strat = xps.floating_dtypes() - if api_version >= "2022.12": - strat |= xps.complex_dtypes() - return strat - - def mock_int_dtype(n: int, dtype: DataType) -> int: """Returns equivalent of `n` that mocks `dtype` behaviour.""" nbits = dh.dtype_nbits[dtype] @@ -714,7 +707,7 @@ def test_abs(ctx, data): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_acos(x): out = xp.acos(x) ph.assert_dtype("acos", in_dtype=x.dtype, out_dtype=out.dtype) @@ -724,7 +717,7 @@ def test_acos(x): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_acosh(x): out = xp.acosh(x) ph.assert_dtype("acosh", in_dtype=x.dtype, out_dtype=out.dtype) @@ -748,7 +741,7 @@ def test_add(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_asin(x): out = xp.asin(x) ph.assert_dtype("asin", in_dtype=x.dtype, out_dtype=out.dtype) @@ -758,7 +751,7 @@ def test_asin(x): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_asinh(x): out = xp.asinh(x) ph.assert_dtype("asinh", in_dtype=x.dtype, out_dtype=out.dtype) @@ -766,7 +759,7 @@ def test_asinh(x): unary_assert_against_refimpl("asinh", x, out, math.asinh) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_atan(x): out = xp.atan(x) ph.assert_dtype("atan", in_dtype=x.dtype, out_dtype=out.dtype) @@ -782,7 +775,7 @@ def test_atan2(x1, x2): binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_atanh(x): out = xp.atanh(x) ph.assert_dtype("atanh", in_dtype=x.dtype, out_dtype=out.dtype) @@ -932,7 +925,7 @@ def test_conj(x): unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate")) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_cos(x): out = xp.cos(x) ph.assert_dtype("cos", in_dtype=x.dtype, out_dtype=out.dtype) @@ -940,7 +933,7 @@ def test_cos(x): unary_assert_against_refimpl("cos", x, out, math.cos) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_cosh(x): out = xp.cosh(x) ph.assert_dtype("cosh", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1001,7 +994,7 @@ def test_equal(ctx, data): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_exp(x): out = xp.exp(x) ph.assert_dtype("exp", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1009,7 +1002,7 @@ def test_exp(x): unary_assert_against_refimpl("exp", x, out, math.exp) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_expm1(x): out = xp.expm1(x) ph.assert_dtype("expm1", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1158,7 +1151,7 @@ def test_less_equal(ctx, data): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log(x): out = xp.log(x) ph.assert_dtype("log", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1168,7 +1161,7 @@ def test_log(x): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log1p(x): out = xp.log1p(x) ph.assert_dtype("log1p", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1178,7 +1171,7 @@ def test_log1p(x): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log2(x): out = xp.log2(x) ph.assert_dtype("log2", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1188,7 +1181,7 @@ def test_log2(x): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log10(x): out = xp.log10(x) ph.assert_dtype("log10", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1379,7 +1372,7 @@ def test_sign(x): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_sin(x): out = xp.sin(x) ph.assert_dtype("sin", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1387,7 +1380,7 @@ def test_sin(x): unary_assert_against_refimpl("sin", x, out, math.sin) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_sinh(x): out = xp.sinh(x) ph.assert_dtype("sinh", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1405,7 +1398,7 @@ def test_square(x): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_sqrt(x): out = xp.sqrt(x) ph.assert_dtype("sqrt", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1429,7 +1422,7 @@ def test_subtract(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "-", operator.sub) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_tan(x): out = xp.tan(x) ph.assert_dtype("tan", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1437,7 +1430,7 @@ def test_tan(x): unary_assert_against_refimpl("tan", x, out, math.tan) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_tanh(x): out = xp.tanh(x) ph.assert_dtype("tanh", in_dtype=x.dtype, out_dtype=out.dtype) From 97904669bf25012e3261516d11d258d2434b28f2 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 13 Sep 2023 16:37:52 +0100 Subject: [PATCH 04/13] Test `ifft`, `rfft` and `irfft` --- array_api_tests/test_fft.py | 109 +++++++++++++++++++++++++++++------- 1 file changed, 89 insertions(+), 20 deletions(-) diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py index b7a82589..68d3cf72 100644 --- a/array_api_tests/test_fft.py +++ b/array_api_tests/test_fft.py @@ -1,15 +1,17 @@ import math +from typing import Optional import pytest from hypothesis import given +from hypothesis import strategies as st -from array_api_tests.typing import DataType +from array_api_tests.typing import Array, DataType -from . import _array_module as xp from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps +from ._array_module import mod as xp pytestmark = [ pytest.mark.ci, @@ -21,6 +23,22 @@ fft_shapes_strat = hh.shapes(min_dims=1).filter(lambda s: math.prod(s) > 1) +def n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple: + size = math.prod(x.shape) + n = data.draw(st.none() | st.integers(size // 2, size * 2), label="n") + axis = data.draw(st.integers(-1, x.ndim - 1), label="axis") + norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm") + kwargs = data.draw( + hh.specified_kwargs( + ("n", n, None), + ("axis", axis, -1), + ("norm", norm, "backward"), + ), + label="kwargs", + ) + return n, axis, norm, kwargs + + def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType): if in_dtype == xp.float32: expected = xp.complex64 @@ -34,29 +52,80 @@ def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType) ) -@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat)) -def test_fft(x): - out = xp.fft.fft(x) +def assert_n_axis_shape( + func_name: str, *, x: Array, n: Optional[int], axis: int, out: Array +): + if n is None: + expected_shape = x.shape + else: + _axis = len(x.shape) - 1 if axis == -1 else axis + expected_shape = x.shape[:_axis] + (n,) + x.shape[_axis + 1 :] + ph.assert_shape(func_name, out_shape=out.shape, expected=expected_shape) + + +@given( + x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_fft(x, data): + n, axis, norm, kwargs = n_axis_norm_kwargs(x, data) + + out = xp.fft.fft(x, **kwargs) + assert_fft_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("fft", out_shape=out.shape, expected=x.shape) + assert_n_axis_shape("fft", x=x, n=n, axis=axis, out=out) + +@given( + x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_ifft(x, data): + n, axis, norm, kwargs = n_axis_norm_kwargs(x, data) + + out = xp.fft.ifft(x, **kwargs) -@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat)) -def test_ifft(x): - out = xp.fft.ifft(x) assert_fft_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("ifft", out_shape=out.shape, expected=x.shape) + assert_n_axis_shape("ifft", x=x, n=n, axis=axis, out=out) + + +# TODO: +# test_fftn +# test_ifftn + + +@given( + x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_rfft(x, data): + n, axis, norm, kwargs = n_axis_norm_kwargs(x, data) + + out = xp.fft.rfft(x, **kwargs) + + assert_fft_dtype("rfft", in_dtype=x.dtype, out_dtype=out.dtype) + assert_n_axis_shape("rfft", x=x, n=n, axis=axis, out=out) + + +@given( + x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_irfft(x, data): + n, axis, norm, kwargs = n_axis_norm_kwargs(x, data) + out = xp.fft.irfft(x, **kwargs) -@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat)) -def test_fftn(x): - out = xp.fft.fftn(x) - assert_fft_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("fftn", out_shape=out.shape, expected=x.shape) + assert_fft_dtype("irfft", in_dtype=x.dtype, out_dtype=out.dtype) + # TODO: assert shape -@given(x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat)) -def test_ifftn(x): - out = xp.fft.ifftn(x) - assert_fft_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("ifftn", out_shape=out.shape, expected=x.shape) +# TODO: +# test_rfftn +# test_irfftn +# test_hfft +# test_ihfft +# fftfreq +# rfftfreq +# fftshift +# ifftshift From 952b9c35010e13b13819d65392c1938f01859105 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 6 Oct 2023 15:42:01 +0100 Subject: [PATCH 05/13] Implement `test_fftn` and `test_ifftn` --- array_api_tests/shape_helpers.py | 6 ++- array_api_tests/test_fft.py | 88 ++++++++++++++++++++++++++++---- 2 files changed, 82 insertions(+), 12 deletions(-) diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py index ba7d994e..6a0bdfde 100644 --- a/array_api_tests/shape_helpers.py +++ b/array_api_tests/shape_helpers.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import Iterator, List, Optional, Tuple, Union +from typing import Iterator, List, Optional, Sequence, Tuple, Union from ndindex import iter_indices as _iter_indices @@ -66,10 +66,12 @@ def broadcast_shapes(*shapes: Shape): def normalise_axis( - axis: Optional[Union[int, Tuple[int, ...]]], ndim: int + axis: Optional[Union[int, Sequence[int]]], ndim: int ) -> Tuple[int, ...]: if axis is None: return tuple(range(ndim)) + elif isinstance(axis, Sequence) and not isinstance(axis, tuple): + axis = tuple(axis) axes = axis if isinstance(axis, tuple) else (axis,) axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes) return axes diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py index 68d3cf72..c230bdd6 100644 --- a/array_api_tests/test_fft.py +++ b/array_api_tests/test_fft.py @@ -1,5 +1,5 @@ import math -from typing import Optional +from typing import List, Optional import pytest from hypothesis import given @@ -10,6 +10,7 @@ from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph +from . import shape_helpers as sh from . import xps from ._array_module import mod as xp @@ -23,9 +24,9 @@ fft_shapes_strat = hh.shapes(min_dims=1).filter(lambda s: math.prod(s) > 1) -def n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple: +def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple: size = math.prod(x.shape) - n = data.draw(st.none() | st.integers(size // 2, size * 2), label="n") + n = data.draw(st.none() | st.integers((size // 2), math.ceil(size * 1.5)), label="n") axis = data.draw(st.integers(-1, x.ndim - 1), label="axis") norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm") kwargs = data.draw( @@ -39,6 +40,32 @@ def n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple: return n, axis, norm, kwargs +def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject) -> tuple: + all_axes = list(range(x.ndim)) + axes = data.draw( + st.none() | st.lists(st.sampled_from(all_axes), min_size=1, unique=True), + label="axes", + ) + _axes = all_axes if axes is None else axes + axes_sides = [x.shape[axis] for axis in _axes] + s_strat = st.tuples( + *[st.integers(max(side // 2, 1), math.ceil(side * 1.5)) for side in axes_sides] + ) + if axes is None: + s_strat = st.none() | s_strat + s = data.draw(s_strat, label="s") + norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm") + kwargs = data.draw( + hh.specified_kwargs( + ("s", s, None), + ("axes", axes, None), + ("norm", norm, "backward"), + ), + label="kwargs", + ) + return s, axes, norm, kwargs + + def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType): if in_dtype == xp.float32: expected = xp.complex64 @@ -63,12 +90,32 @@ def assert_n_axis_shape( ph.assert_shape(func_name, out_shape=out.shape, expected=expected_shape) +def assert_s_axes_shape( + func_name: str, + *, + x: Array, + s: Optional[List[int]], + axes: Optional[List[int]], + out: Array, +): + _axes = sh.normalise_axis(axes, x.ndim) + _s = x.shape if s is None else s + expected = [] + for i in range(x.ndim): + if i in _axes: + side = _s[_axes.index(i)] + else: + side = x.shape[i] + expected.append(side) + ph.assert_shape(func_name, out_shape=out.shape, expected=tuple(expected)) + + @given( x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), data=st.data(), ) def test_fft(x, data): - n, axis, norm, kwargs = n_axis_norm_kwargs(x, data) + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) out = xp.fft.fft(x, **kwargs) @@ -81,7 +128,7 @@ def test_fft(x, data): data=st.data(), ) def test_ifft(x, data): - n, axis, norm, kwargs = n_axis_norm_kwargs(x, data) + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) out = xp.fft.ifft(x, **kwargs) @@ -89,9 +136,30 @@ def test_ifft(x, data): assert_n_axis_shape("ifft", x=x, n=n, axis=axis, out=out) -# TODO: -# test_fftn -# test_ifftn +@given( + x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_fftn(x, data): + s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data) + + out = xp.fft.fftn(x, **kwargs) + + assert_fft_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype) + assert_s_axes_shape("fftn", x=x, s=s, axes=axes, out=out) + + +@given( + x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_ifftn(x, data): + s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data) + + out = xp.fft.ifftn(x, **kwargs) + + assert_fft_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype) + assert_s_axes_shape("ifftn", x=x, s=s, axes=axes, out=out) @given( @@ -99,7 +167,7 @@ def test_ifft(x, data): data=st.data(), ) def test_rfft(x, data): - n, axis, norm, kwargs = n_axis_norm_kwargs(x, data) + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) out = xp.fft.rfft(x, **kwargs) @@ -112,7 +180,7 @@ def test_rfft(x, data): data=st.data(), ) def test_irfft(x, data): - n, axis, norm, kwargs = n_axis_norm_kwargs(x, data) + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) out = xp.fft.irfft(x, **kwargs) From d0d60fc3a778a702a7e51d507d8a9e3b70cc6f20 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 6 Oct 2023 15:51:24 +0100 Subject: [PATCH 06/13] Hide traceback of assert helpers --- array_api_tests/pytest_helpers.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index c51b14a6..f411ba71 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -1,7 +1,8 @@ import cmath import math +from functools import wraps from inspect import getfullargspec -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union from . import _array_module as xp from . import dtype_helpers as dh @@ -122,6 +123,7 @@ def assert_dtype( >>> assert_dtype('sum', in_dtype=x, out_dtype=out.dtype, expected=default_int) """ + __tracebackhide__ = True in_dtypes = in_dtype if isinstance(in_dtype, Sequence) and not isinstance(in_dtype, str) else [in_dtype] f_in_dtypes = dh.fmt_types(tuple(in_dtypes)) f_out_dtype = dh.dtype_to_name[out_dtype] @@ -149,6 +151,7 @@ def assert_kw_dtype( >>> assert_kw_dtype('ones', kw_dtype=kw['dtype'], out_dtype=out.dtype) """ + __tracebackhide__ = True f_kw_dtype = dh.dtype_to_name[kw_dtype] f_out_dtype = dh.dtype_to_name[out_dtype] msg = ( @@ -166,6 +169,7 @@ def assert_default_float(func_name: str, out_dtype: DataType): >>> assert_default_float('ones', out.dtype) """ + __tracebackhide__ = True f_dtype = dh.dtype_to_name[out_dtype] f_default = dh.dtype_to_name[dh.default_float] msg = ( @@ -183,6 +187,7 @@ def assert_default_complex(func_name: str, out_dtype: DataType): >>> assert_default_complex('asarray', out.dtype) """ + __tracebackhide__ = True f_dtype = dh.dtype_to_name[out_dtype] f_default = dh.dtype_to_name[dh.default_complex] msg = ( @@ -200,6 +205,7 @@ def assert_default_int(func_name: str, out_dtype: DataType): >>> assert_default_int('full', out.dtype) """ + __tracebackhide__ = True f_dtype = dh.dtype_to_name[out_dtype] f_default = dh.dtype_to_name[dh.default_int] msg = ( @@ -217,6 +223,7 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty >>> assert_default_int('argmax', out.dtype) """ + __tracebackhide__ = True f_dtype = dh.dtype_to_name[out_dtype] msg = ( f"{repr_name}={f_dtype}, should be the default index dtype, " @@ -240,6 +247,7 @@ def assert_shape( >>> assert_shape('ones', out_shape=out.shape, expected=(3, 3, 3)) """ + __tracebackhide__ = True if isinstance(out_shape, int): out_shape = (out_shape,) if isinstance(expected, int): @@ -273,6 +281,7 @@ def assert_result_shape( >>> assert out.shape == (3, 3) """ + __tracebackhide__ = True if expected is None: expected = sh.broadcast_shapes(*in_shapes) f_in_shapes = " . ".join(str(s) for s in in_shapes) @@ -307,6 +316,7 @@ def assert_keepdimable_shape( >>> assert out2.shape == (1, 1) """ + __tracebackhide__ = True if keepdims: shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape)) else: @@ -337,6 +347,7 @@ def assert_0d_equals( >>> assert res[0] == x[0] """ + __tracebackhide__ = True msg = ( f"{out_repr}={out_val}, but should be {x_repr}={x_val} " f"[{func_name}({fmt_kw(kw)})]" @@ -369,6 +380,7 @@ def assert_scalar_equals( >>> assert int(out) == 5 """ + __tracebackhide__ = True repr_name = repr_name if idx == () else f"{repr_name}[{idx}]" f_func = f"{func_name}({fmt_kw(kw)})" if type_ in [bool, int]: @@ -401,6 +413,7 @@ def assert_fill( >>> assert xp.all(out == 42) """ + __tracebackhide__ = True msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}" if cmath.isnan(fill_value): assert xp.all(xp.isnan(out)), msg @@ -443,6 +456,7 @@ def assert_array_elements( >>> assert xp.all(out == x) """ + __tracebackhide__ = True dh.result_type(out.dtype, expected.dtype) # sanity check assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check f_func = f"[{func_name}({fmt_kw(kw)})]" @@ -470,3 +484,18 @@ def assert_array_elements( assert xp.all( out == expected ), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}" + + +def _make_wrapped_assert_helper(assert_helper: Callable) -> Callable: + @wraps(assert_helper) + def wrapped_assert_helper(*args, **kwargs): + __tracebackhide__ = True + assert_helper(*args, **kwargs) + + return wrapped_assert_helper + + +for func_name in __all__: + if func_name.startswith("assert"): + assert_helper = globals()[func_name] + globals()[func_name] = _make_wrapped_assert_helper(assert_helper) From 671e07e9575802a6ac460165486e1494895a2d96 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 6 Oct 2023 18:31:54 +0100 Subject: [PATCH 07/13] Add tests for (i)rfftn and (i)hfft And rudimentary support for `2*(m-1)` size rules --- array_api_tests/test_fft.py | 84 +++++++++++++++++++++++++++++++------ 1 file changed, 72 insertions(+), 12 deletions(-) diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py index c230bdd6..57a25248 100644 --- a/array_api_tests/test_fft.py +++ b/array_api_tests/test_fft.py @@ -2,7 +2,7 @@ from typing import List, Optional import pytest -from hypothesis import given +from hypothesis import assume, given from hypothesis import strategies as st from array_api_tests.typing import Array, DataType @@ -24,10 +24,15 @@ fft_shapes_strat = hh.shapes(min_dims=1).filter(lambda s: math.prod(s) > 1) -def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple: +def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple: size = math.prod(x.shape) - n = data.draw(st.none() | st.integers((size // 2), math.ceil(size * 1.5)), label="n") + n = data.draw( + st.none() | st.integers((size // 2), math.ceil(size * 1.5)), label="n" + ) axis = data.draw(st.integers(-1, x.ndim - 1), label="axis") + if size_gt_1: + _axis = x.ndim - 1 if axis == -1 else axis + assume(x.shape[_axis] > 1) norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm") kwargs = data.draw( hh.specified_kwargs( @@ -40,7 +45,7 @@ def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple: return n, axis, norm, kwargs -def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject) -> tuple: +def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple: all_axes = list(range(x.ndim)) axes = data.draw( st.none() | st.lists(st.sampled_from(all_axes), min_size=1, unique=True), @@ -54,6 +59,14 @@ def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject) -> tuple: if axes is None: s_strat = st.none() | s_strat s = data.draw(s_strat, label="s") + if size_gt_1: + _s = x.shape if s is None else s + for i in range(x.ndim): + if i in _axes: + side = _s[_axes.index(i)] + else: + side = x.shape[i] + assume(side > 1) norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm") kwargs = data.draw( hh.specified_kwargs( @@ -163,7 +176,7 @@ def test_ifftn(x, data): @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat), + x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), data=st.data(), ) def test_rfft(x, data): @@ -176,11 +189,11 @@ def test_rfft(x, data): @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat), + x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), data=st.data(), ) def test_irfft(x, data): - n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True) out = xp.fft.irfft(x, **kwargs) @@ -188,11 +201,58 @@ def test_irfft(x, data): # TODO: assert shape -# TODO: -# test_rfftn -# test_irfftn -# test_hfft -# test_ihfft +@given( + x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_rfftn(x, data): + s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data) + + out = xp.fft.rfftn(x, **kwargs) + + assert_fft_dtype("rfftn", in_dtype=x.dtype, out_dtype=out.dtype) + assert_s_axes_shape("rfftn", x=x, s=s, axes=axes, out=out) + + +@given( + x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_irfftn(x, data): + s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data, size_gt_1=True) + + out = xp.fft.irfftn(x, **kwargs) + + assert_fft_dtype("irfftn", in_dtype=x.dtype, out_dtype=out.dtype) + assert_s_axes_shape("irfftn", x=x, s=s, axes=axes, out=out) + + +@given( + x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_hfft(x, data): + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True) + + out = xp.fft.hfft(x, **kwargs) + + assert_fft_dtype("hfft", in_dtype=x.dtype, out_dtype=out.dtype) + # TODO: shape + + +@given( + x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_ihfft(x, data): + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) + + out = xp.fft.ihfft(x, **kwargs) + + assert_fft_dtype("ihfft", in_dtype=x.dtype, out_dtype=out.dtype) + # TODO: shape + + # fftfreq # rfftfreq # fftshift From cd803b8006dc6b6dc4a76de206f4b4955a86c162 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 23 Oct 2023 11:55:25 +0100 Subject: [PATCH 08/13] `size_gt_1` testing in `assert_n_axis_shape()` --- array_api_tests/test_fft.py | 43 +++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py index 57a25248..f9a94a50 100644 --- a/array_api_tests/test_fft.py +++ b/array_api_tests/test_fft.py @@ -93,14 +93,24 @@ def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType) def assert_n_axis_shape( - func_name: str, *, x: Array, n: Optional[int], axis: int, out: Array + func_name: str, + *, + x: Array, + n: Optional[int], + axis: int, + out: Array, + size_gt_1=False, ): + _axis = len(x.shape) - 1 if axis == -1 else axis if n is None: - expected_shape = x.shape + if size_gt_1: + axis_side = 2 * (x.shape[_axis] - 1) + else: + axis_side = x.shape[_axis] else: - _axis = len(x.shape) - 1 if axis == -1 else axis - expected_shape = x.shape[:_axis] + (n,) + x.shape[_axis + 1 :] - ph.assert_shape(func_name, out_shape=out.shape, expected=expected_shape) + axis_side = n + expected = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :] + ph.assert_shape(func_name, out_shape=out.shape, expected=expected) def assert_s_axes_shape( @@ -198,7 +208,14 @@ def test_irfft(x, data): out = xp.fft.irfft(x, **kwargs) assert_fft_dtype("irfft", in_dtype=x.dtype, out_dtype=out.dtype) - # TODO: assert shape + + _axis = x.ndim - 1 if axis == -1 else axis + if n is None: + axis_side = 2 * (x.shape[_axis] - 1) + else: + axis_side = n + expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :] + ph.assert_shape("irfft", out_shape=out.shape, expected=expected_shape) @given( @@ -224,7 +241,7 @@ def test_irfftn(x, data): out = xp.fft.irfftn(x, **kwargs) assert_fft_dtype("irfftn", in_dtype=x.dtype, out_dtype=out.dtype) - assert_s_axes_shape("irfftn", x=x, s=s, axes=axes, out=out) + # TODO: shape @given( @@ -237,7 +254,14 @@ def test_hfft(x, data): out = xp.fft.hfft(x, **kwargs) assert_fft_dtype("hfft", in_dtype=x.dtype, out_dtype=out.dtype) - # TODO: shape + + _axis = x.ndim - 1 if axis == -1 else axis + if n is None: + axis_side = 2 * (x.shape[_axis] - 1) + else: + axis_side = n + expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :] + ph.assert_shape("hfft", out_shape=out.shape, expected=expected_shape) @given( @@ -250,9 +274,10 @@ def test_ihfft(x, data): out = xp.fft.ihfft(x, **kwargs) assert_fft_dtype("ihfft", in_dtype=x.dtype, out_dtype=out.dtype) - # TODO: shape + assert_n_axis_shape("ihfft", x=x, n=n, axis=axis, out=out, size_gt_1=True) +# TODO: # fftfreq # rfftfreq # fftshift From ef95ba1887d019144d9d694a95d4ac33f19dc94b Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 23 Oct 2023 13:39:55 +0100 Subject: [PATCH 09/13] Valid shapes for `test_irfftn` --- array_api_tests/test_fft.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py index f9a94a50..4634fd5c 100644 --- a/array_api_tests/test_fft.py +++ b/array_api_tests/test_fft.py @@ -232,7 +232,9 @@ def test_rfftn(x, data): @given( - x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), + x=xps.arrays( + dtype=xps.complex_dtypes(), shape=fft_shapes_strat.filter(lambda s: s[-1] > 1) + ), data=st.data(), ) def test_irfftn(x, data): From 10b4683108d68b712028b8724c4630721493477d Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 23 Oct 2023 18:14:04 +0100 Subject: [PATCH 10/13] `size_gt_1` support for `assert_s_axes_shape()` --- array_api_tests/test_fft.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py index 4634fd5c..52645ba9 100644 --- a/array_api_tests/test_fft.py +++ b/array_api_tests/test_fft.py @@ -99,7 +99,7 @@ def assert_n_axis_shape( n: Optional[int], axis: int, out: Array, - size_gt_1=False, + size_gt_1: bool = False, ): _axis = len(x.shape) - 1 if axis == -1 else axis if n is None: @@ -120,6 +120,7 @@ def assert_s_axes_shape( s: Optional[List[int]], axes: Optional[List[int]], out: Array, + size_gt_1: bool = False, ): _axes = sh.normalise_axis(axes, x.ndim) _s = x.shape if s is None else s @@ -130,6 +131,10 @@ def assert_s_axes_shape( else: side = x.shape[i] expected.append(side) + if size_gt_1: + last_axis = _axes[-1] + expected[last_axis] = 2 * (expected[last_axis] - 1) + assume(expected[last_axis] > 0) # TODO: generate valid examples ph.assert_shape(func_name, out_shape=out.shape, expected=tuple(expected)) @@ -243,7 +248,7 @@ def test_irfftn(x, data): out = xp.fft.irfftn(x, **kwargs) assert_fft_dtype("irfftn", in_dtype=x.dtype, out_dtype=out.dtype) - # TODO: shape + assert_s_axes_shape("rfftn", x=x, s=s, axes=axes, out=out, size_gt_1=True) @given( From 21954fc47401121dd94ac515e2f0c6ac3cf4f575 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 10 Nov 2023 15:40:09 +0000 Subject: [PATCH 11/13] Fix min version for `test_fft.py` --- array_api_tests/test_fft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py index 52645ba9..e01caa80 100644 --- a/array_api_tests/test_fft.py +++ b/array_api_tests/test_fft.py @@ -17,7 +17,7 @@ pytestmark = [ pytest.mark.ci, pytest.mark.xp_extension("fft"), - pytest.mark.min_version("draft"), + pytest.mark.min_version("2022.12"), ] From 6998b0151d7ee358e0315a437fc5998425a1388f Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 10 Nov 2023 15:45:40 +0000 Subject: [PATCH 12/13] Remove unnecessary `__tracebackhide__`-decorating magic --- array_api_tests/pytest_helpers.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index f411ba71..e6ede7b2 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -1,8 +1,7 @@ import cmath import math -from functools import wraps from inspect import getfullargspec -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Union from . import _array_module as xp from . import dtype_helpers as dh @@ -484,18 +483,3 @@ def assert_array_elements( assert xp.all( out == expected ), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}" - - -def _make_wrapped_assert_helper(assert_helper: Callable) -> Callable: - @wraps(assert_helper) - def wrapped_assert_helper(*args, **kwargs): - __tracebackhide__ = True - assert_helper(*args, **kwargs) - - return wrapped_assert_helper - - -for func_name in __all__: - if func_name.startswith("assert"): - assert_helper = globals()[func_name] - globals()[func_name] = _make_wrapped_assert_helper(assert_helper) From dc2d4b9c2d070dd4379e813dfa0e6cd50933c344 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 10 Nov 2023 16:15:08 +0000 Subject: [PATCH 13/13] Hack to allow top-level `xps.complex_dtypes()` for `2012.12` --- array_api_tests/test_fft.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py index e01caa80..7dc70d56 100644 --- a/array_api_tests/test_fft.py +++ b/array_api_tests/test_fft.py @@ -1,5 +1,6 @@ import math from typing import List, Optional +from unittest.mock import MagicMock import pytest from hypothesis import assume, given @@ -7,6 +8,7 @@ from array_api_tests.typing import Array, DataType +from . import api_version from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph @@ -21,6 +23,11 @@ ] +# Using xps.complex_dtypes() raises an AttributeError for 2021.12 instances of +# xps, hence this hack. TODO: figure out a better way to manage this! +if api_version < "2022.12": + xps = MagicMock(xps) + fft_shapes_strat = hh.shapes(min_dims=1).filter(lambda s: math.prod(s) > 1)