diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 9a09ffd7..c9f44886 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -16,4 +16,4 @@ jobs: pip install ruff # Update output format to enable automatic inline annotations. - name: Run Ruff - run: ruff check --output-format=github --select F822,PLC0414,RUF022 --preview . + run: ruff check --output-format=github . diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index 9073dd52..170a1ff9 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -5,7 +5,6 @@ from functools import wraps from inspect import signature - def get_xp(xp): """ Decorator to automatically replace xp with the corresponding array module. @@ -45,31 +44,3 @@ def wrapped_f(*args, **kwargs): return wrapped_f return inner - - -def _get_all_public_members(module, exclude=None, extend_all=False): - """Get all public members of a module. - - Parameters - ---------- - module : module - The module to get members from. - exclude : callable, optional - A callable that takes a name and returns True if the name should be - excluded from the list of members. - extend_all : bool, optional - If True, extend the module's __all__ attribute with the members of the - module derived from dir(module). To be used for libraries that do not have a complete __all__ list. - """ - members = getattr(module, "__all__", []) - - if members and not extend_all: - return members - - if exclude is None: - exclude = lambda name: name.startswith("_") # noqa: E731 - - members = members + [_ for _ in dir(module) if not exclude(_)] - - # remove duplicates - return list(set(members)) diff --git a/array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py index 3317899b..91ab1c40 100644 --- a/array_api_compat/common/__init__.py +++ b/array_api_compat/common/__init__.py @@ -1,27 +1 @@ -from ._helpers import ( - array_namespace, - device, - get_namespace, - is_array_api_obj, - is_cupy_array, - is_dask_array, - is_jax_array, - is_numpy_array, - is_torch_array, - size, - to_device, -) - -__all__ = [ - "array_namespace", - "device", - "get_namespace", - "is_array_api_obj", - "is_cupy_array", - "is_dask_array", - "is_jax_array", - "is_numpy_array", - "is_torch_array", - "size", - "to_device", -] +from ._helpers import * # noqa: F403 diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index a184a0c3..167c7c0f 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -146,6 +146,9 @@ def zeros_like( # The functions here return namedtuples (np.unique() returns a normal # tuple). + +# Note that these named tuples aren't actually part of the standard namespace, +# but I don't see any issue with exporting the names here regardless. class UniqueAllResult(NamedTuple): values: ndarray indices: ndarray @@ -545,3 +548,11 @@ def isdtype( # more strict here to match the type annotation? Note that the # array_api_strict implementation will be very strict. return dtype == kind + +__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', + 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', + 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', + 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', + 'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort', + 'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc', + 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 5e59c7ea..2fa963f8 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -288,3 +288,19 @@ def size(x): if None in x.shape: return None return math.prod(x.shape) + +__all__ = [ + "array_namespace", + "device", + "get_namespace", + "is_array_api_obj", + "is_cupy_array", + "is_dask_array", + "is_jax_array", + "is_numpy_array", + "is_torch_array", + "size", + "to_device", +] + +_all_ignore = ['sys', 'math', 'inspect'] diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 0708b76a..3b17417d 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -11,7 +11,7 @@ else: from numpy.core.numeric import normalize_axis_tuple -from ._aliases import matrix_transpose, isdtype +from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype from .._internal import get_xp # These are in the main NumPy namespace but not in numpy.linalg @@ -149,4 +149,10 @@ def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarra dtype = xp.float64 elif x.dtype == xp.complex64: dtype = xp.complex128 - return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) \ No newline at end of file + return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) + +__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', + 'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet', + 'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm', + 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', + 'trace'] diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index f20d085c..07f3850d 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -20,4 +20,4 @@ def __len__(self, /) -> int: ... SupportsBufferProtocol = Any Array = Any -Device = Any \ No newline at end of file +Device = Any diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index b5eb5eea..148691f5 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -1,153 +1,14 @@ -import cupy as _cp -from cupy import * # noqa: F401, F403 +from cupy import * # noqa: F403 # from cupy import * doesn't overwrite these builtin names -from cupy import abs, max, min, round - -from .._internal import _get_all_public_members -from ..common._helpers import ( - array_namespace, - device, - get_namespace, - is_array_api_obj, - size, - to_device, -) +from cupy import abs, max, min, round # noqa: F401 # These imports may overwrite names from the import * above. -from ._aliases import ( - UniqueAllResult, - UniqueCountsResult, - UniqueInverseResult, - acos, - acosh, - arange, - argsort, - asarray, - asarray_cupy, - asin, - asinh, - astype, - atan, - atan2, - atanh, - bitwise_invert, - bitwise_left_shift, - bitwise_right_shift, - bool, - ceil, - concat, - empty, - empty_like, - eye, - floor, - full, - full_like, - isdtype, - linspace, - matmul, - matrix_transpose, - nonzero, - ones, - ones_like, - permute_dims, - pow, - prod, - reshape, - sort, - std, - sum, - tensordot, - trunc, - unique_all, - unique_counts, - unique_inverse, - unique_values, - var, - vecdot, - zeros, - zeros_like, -) - -__all__ = [] - -__all__ += _get_all_public_members(_cp) - -__all__ += [ - "abs", - "max", - "min", - "round", -] - -__all__ += [ - "array_namespace", - "device", - "get_namespace", - "is_array_api_obj", - "size", - "to_device", -] - -__all__ += [ - "UniqueAllResult", - "UniqueCountsResult", - "UniqueInverseResult", - "acos", - "acosh", - "arange", - "argsort", - "asarray", - "asarray_cupy", - "asin", - "asinh", - "astype", - "atan", - "atan2", - "atanh", - "bitwise_invert", - "bitwise_left_shift", - "bitwise_right_shift", - "bool", - "ceil", - "concat", - "empty", - "empty_like", - "eye", - "floor", - "full", - "full_like", - "isdtype", - "linspace", - "matmul", - "matrix_transpose", - "nonzero", - "ones", - "ones_like", - "permute_dims", - "pow", - "prod", - "reshape", - "sort", - "std", - "sum", - "tensordot", - "trunc", - "unique_all", - "unique_counts", - "unique_inverse", - "unique_values", - "var", - "zeros", - "zeros_like", -] - -__all__ += [ - "matrix_transpose", - "vecdot", -] +from ._aliases import * # noqa: F403 # See the comment in the numpy __init__.py -__import__(__package__ + ".linalg") +__import__(__package__ + '.linalg') + +from ..common._helpers import * # noqa: F401,F403 -__array_api_version__ = "2022.12" +__array_api_version__ = '2022.12' diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 71ffadce..968b974b 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -5,12 +5,11 @@ import cupy as cp from ..common import _aliases -from ..common import _linalg - from .._internal import get_xp asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy') asarray.__doc__ = _aliases._asarray.__doc__ +del partial bool = cp.bool_ @@ -74,28 +73,7 @@ else: isdtype = get_xp(cp)(_aliases.isdtype) - -cross = get_xp(cp)(_linalg.cross) -outer = get_xp(cp)(_linalg.outer) -EighResult = _linalg.EighResult -QRResult = _linalg.QRResult -SlogdetResult = _linalg.SlogdetResult -SVDResult = _linalg.SVDResult -eigh = get_xp(cp)(_linalg.eigh) -qr = get_xp(cp)(_linalg.qr) -slogdet = get_xp(cp)(_linalg.slogdet) -svd = get_xp(cp)(_linalg.svd) -cholesky = get_xp(cp)(_linalg.cholesky) -matrix_rank = get_xp(cp)(_linalg.matrix_rank) -pinv = get_xp(cp)(_linalg.pinv) -matrix_norm = get_xp(cp)(_linalg.matrix_norm) -svdvals = get_xp(cp)(_linalg.svdvals) -diagonal = get_xp(cp)(_linalg.diagonal) -trace = get_xp(cp)(_linalg.trace) - -# These functions are completely new here. If the library already has them -# (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(cp.linalg, 'vector_norm'): - vector_norm = cp.linalg.vector_norm -else: - vector_norm = get_xp(cp)(_linalg.vector_norm) \ No newline at end of file +__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos', + 'acosh', 'asin', 'asinh', 'atan', 'atan2', + 'atanh', 'bitwise_left_shift', 'bitwise_invert', + 'bitwise_right_shift', 'concat', 'pow'] diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index b867ca94..f3d9aab6 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -1,9 +1,9 @@ from __future__ import annotations __all__ = [ + "ndarray", "Device", "Dtype", - "ndarray", ] import sys diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index cef74183..7fcdd498 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -1,62 +1,49 @@ -import cupy as _cp +from cupy.linalg import * # noqa: F403 +# cupy.linalg doesn't have __all__. If it is added, replace this with +# +# from cupy.linalg import __all__ as linalg_all +_n = {} +exec('from cupy.linalg import *', _n) +del _n['__builtins__'] +linalg_all = list(_n) +del _n -from .._internal import _get_all_public_members +from ..common import _linalg +from .._internal import get_xp -_cupy_linalg_all = _get_all_public_members(_cp.linalg) +import cupy as cp -for _name in _cupy_linalg_all: - globals()[_name] = getattr(_cp.linalg, _name) +# These functions are in both the main and linalg namespaces +from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 -from ._aliases import ( # noqa: E402 - EighResult, - QRResult, - SlogdetResult, - SVDResult, - cholesky, - cross, - diagonal, - eigh, - matmul, - matrix_norm, - matrix_rank, - matrix_transpose, - outer, - pinv, - qr, - slogdet, - svd, - svdvals, - tensordot, - trace, - vecdot, - vector_norm, -) +cross = get_xp(cp)(_linalg.cross) +outer = get_xp(cp)(_linalg.outer) +EighResult = _linalg.EighResult +QRResult = _linalg.QRResult +SlogdetResult = _linalg.SlogdetResult +SVDResult = _linalg.SVDResult +eigh = get_xp(cp)(_linalg.eigh) +qr = get_xp(cp)(_linalg.qr) +slogdet = get_xp(cp)(_linalg.slogdet) +svd = get_xp(cp)(_linalg.svd) +cholesky = get_xp(cp)(_linalg.cholesky) +matrix_rank = get_xp(cp)(_linalg.matrix_rank) +pinv = get_xp(cp)(_linalg.pinv) +matrix_norm = get_xp(cp)(_linalg.matrix_norm) +svdvals = get_xp(cp)(_linalg.svdvals) +diagonal = get_xp(cp)(_linalg.diagonal) +trace = get_xp(cp)(_linalg.trace) -__all__ = [] +# These functions are completely new here. If the library already has them +# (i.e., numpy 2.0), use the library version instead of our wrapper. +if hasattr(cp.linalg, 'vector_norm'): + vector_norm = cp.linalg.vector_norm +else: + vector_norm = get_xp(cp)(_linalg.vector_norm) -__all__ += _cupy_linalg_all +__all__ = linalg_all + _linalg.__all__ -__all__ += [ - "EighResult", - "QRResult", - "SVDResult", - "SlogdetResult", - "cholesky", - "cross", - "diagonal", - "eigh", - "matmul", - "matrix_norm", - "matrix_rank", - "matrix_transpose", - "outer", - "pinv", - "qr", - "slogdet", - "svd", - "svdvals", - "tensordot", - "trace", - "vecdot", - "vector_norm", -] +del get_xp +del cp +del linalg_all +del _linalg diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index d6b5e94e..03e0cd72 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,210 +1,8 @@ -import dask.array as _da -from dask.array import * # noqa: F401, F403 -from dask.array import ( - # Element wise aliases - arccos as acos, -) -from dask.array import ( - arccosh as acosh, -) -from dask.array import ( - arcsin as asin, -) -from dask.array import ( - arcsinh as asinh, -) -from dask.array import ( - arctan as atan, -) -from dask.array import ( - arctan2 as atan2, -) -from dask.array import ( - arctanh as atanh, -) -from dask.array import ( - bool_ as bool, -) -from dask.array import ( - # Other - concatenate as concat, -) -from dask.array import ( - invert as bitwise_invert, -) -from dask.array import ( - left_shift as bitwise_left_shift, -) -from dask.array import ( - power as pow, -) -from dask.array import ( - right_shift as bitwise_right_shift, -) +from dask.array import * # noqa: F403 # These imports may overwrite names from the import * above. -from numpy import ( - can_cast, - complex64, - complex128, - e, - finfo, - float32, - float64, - iinfo, - inf, - int8, - int16, - int32, - int64, - nan, - newaxis, - pi, - result_type, - uint8, - uint16, - uint32, - uint64, -) +from ._aliases import * # noqa: F403 -from ..common._helpers import ( - array_namespace, - device, - get_namespace, - is_array_api_obj, - size, - to_device, -) -from ..internal import _get_all_public_members -from ._aliases import ( - UniqueAllResult, - UniqueCountsResult, - UniqueInverseResult, - arange, - asarray, - astype, - ceil, - empty, - empty_like, - eye, - floor, - full, - full_like, - isdtype, - linspace, - matmul, - matrix_transpose, - nonzero, - ones, - ones_like, - permute_dims, - prod, - reshape, - std, - sum, - tensordot, - trunc, - unique_all, - unique_counts, - unique_inverse, - unique_values, - var, - vecdot, - zeros, - zeros_like, -) +__array_api_version__ = '2022.12' -__all__ = [] - -__all__ += _get_all_public_members(_da) - -__all__ += [ - "can_cast", - "complex64", - "complex128", - "e", - "finfo", - "float32", - "float64", - "iinfo", - "inf", - "int8", - "int16", - "int32", - "int64", - "nan", - "newaxis", - "pi", - "result_type", - "uint8", - "uint16", - "uint32", - "uint64", -] - -__all__ += [ - "array_namespace", - "device", - "get_namespace", - "is_array_api_obj", - "size", - "to_device", -] - -# 'sort', 'argsort' are unsupported by dask.array - -__all__ += [ - "UniqueAllResult", - "UniqueCountsResult", - "UniqueInverseResult", - "acos", - "acosh", - "arange", - "asarray", - "asin", - "asinh", - "astype", - "atan", - "atan2", - "atanh", - "bitwise_invert", - "bitwise_left_shift", - "bitwise_right_shift", - "bool", - "ceil", - "concat", - "empty", - "empty_like", - "eye", - "floor", - "full", - "full_like", - "isdtype", - "linspace", - "matmul", - "matrix_transpose", - "nonzero", - "ones", - "ones_like", - "permute_dims", - "pow", - "prod", - "reshape", - "std", - "sum", - "tensordot", - "trunc", - "unique_all", - "unique_counts", - "unique_inverse", - "unique_values", - "var", - "vecdot", - "zeros", - "zeros_like", -] - - -__array_api_version__ = "2022.12" - -__import__(__package__ + ".linalg") +__import__(__package__ + '.linalg') diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 14b27070..844cdf91 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,18 +1,42 @@ from __future__ import annotations -from functools import partial -from typing import TYPE_CHECKING - -import numpy as np +from ...common import _aliases +from ...common._helpers import _check_device from ..._internal import get_xp -from ...common import _aliases, _linalg -from ...common._helpers import _check_device -if TYPE_CHECKING: - from typing import Optional, Tuple, Union +import numpy as np +from numpy import ( + # Constants + e, + inf, + nan, + pi, + newaxis, + # Dtypes + bool_ as bool, + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + complex64, + complex128, + iinfo, + finfo, + can_cast, + result_type, +) - from ...common._typing import Device, Dtype, ndarray +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Optional, Union + from ...common._typing import ndarray, Device, Dtype import dask.array as da @@ -25,9 +49,8 @@ # not pass stop/step as keyword arguments, which will cause # an error with dask - # TODO: delete the xp stuff, it shouldn't be necessary -def dask_arange( +def _dask_arange( start: Union[int, float], /, stop: Optional[Union[int, float]] = None, @@ -49,11 +72,11 @@ def dask_arange( args.append(step) return xp.arange(*args, dtype=dtype, **kwargs) - -arange = get_xp(da)(dask_arange) +arange = get_xp(da)(_dask_arange) eye = get_xp(da)(_aliases.eye) -asarray = partial(_aliases._asarray, namespace="dask.array") +from functools import partial +asarray = partial(_aliases._asarray, namespace='dask.array') asarray.__doc__ = _aliases._asarray.__doc__ linspace = get_xp(da)(_aliases.linspace) @@ -89,22 +112,34 @@ def dask_arange( matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) - -EighResult = _linalg.EighResult -QRResult = _linalg.QRResult -SlogdetResult = _linalg.SlogdetResult -SVDResult = _linalg.SVDResult -qr = get_xp(da)(_linalg.qr) -cholesky = get_xp(da)(_linalg.cholesky) -matrix_rank = get_xp(da)(_linalg.matrix_rank) -matrix_norm = get_xp(da)(_linalg.matrix_norm) - - -def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]: - # TODO: can't avoid computing U or V for dask - _, s, _ = da.linalg.svd(x) - return s - - -vector_norm = get_xp(da)(_linalg.vector_norm) -diagonal = get_xp(da)(_linalg.diagonal) +from dask.array import ( + # Element wise aliases + arccos as acos, + arccosh as acosh, + arcsin as asin, + arcsinh as asinh, + arctan as atan, + arctan2 as atan2, + arctanh as atanh, + left_shift as bitwise_left_shift, + right_shift as bitwise_right_shift, + invert as bitwise_invert, + power as pow, + # Other + concatenate as concat, +) + +# exclude these from all since +_da_unsupported = ['sort', 'argsort'] + +common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported] + +__all__ = common_aliases + ['asarray', 'bool', 'acos', + 'acosh', 'asin', 'asinh', 'atan', 'atan2', + 'atanh', 'bitwise_left_shift', 'bitwise_invert', + 'bitwise_right_shift', 'concat', 'pow', + 'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8', + 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', + 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type'] + +_all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np'] diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index cc9ac880..f2dd80cd 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -1,50 +1,56 @@ -import dask.array as _da -from dask.array import ( - matmul, - outer, - tensordot, - trace, -) -from dask.array.linalg import * # noqa: F401, F403 - -from .._internal import _get_all_public_members -from ._aliases import ( - EighResult, - QRResult, - SlogdetResult, - SVDResult, - cholesky, - diagonal, - matrix_norm, - matrix_rank, - matrix_transpose, - qr, - svdvals, - vecdot, - vector_norm, -) - -__all__ = [ - "matmul", - "outer", - "tensordot", - "trace", -] - -__all__ += _get_all_public_members(_da.linalg) - -__all__ += [ - "EighResult", - "QRResult", - "SVDResult", - "SlogdetResult", - "cholesky", - "diagonal", - "matrix_norm", - "matrix_rank", - "matrix_transpose", - "qr", - "svdvals", - "vecdot", - "vector_norm", -] +from __future__ import annotations + +from dask.array.linalg import svd +from ...common import _linalg +from ..._internal import get_xp + +# Exports +from dask.array.linalg import * # noqa: F403 +from dask.array import trace, outer + +# These functions are in both the main and linalg namespaces +from dask.array import matmul, tensordot +from ._aliases import matrix_transpose, vecdot + +import dask.array as da + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Union, Tuple + from ...common._typing import ndarray + +# cupy.linalg doesn't have __all__. If it is added, replace this with +# +# from cupy.linalg import __all__ as linalg_all +_n = {} +exec('from dask.array.linalg import *', _n) +del _n['__builtins__'] +if 'annotations' in _n: + del _n['annotations'] +linalg_all = list(_n) +del _n + +EighResult = _linalg.EighResult +QRResult = _linalg.QRResult +SlogdetResult = _linalg.SlogdetResult +SVDResult = _linalg.SVDResult +qr = get_xp(da)(_linalg.qr) +cholesky = get_xp(da)(_linalg.cholesky) +matrix_rank = get_xp(da)(_linalg.matrix_rank) +matrix_norm = get_xp(da)(_linalg.matrix_norm) + +def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]: + # TODO: can't avoid computing U or V for dask + _, s, _ = svd(x) + return s + +vector_norm = get_xp(da)(_linalg.vector_norm) +diagonal = get_xp(da)(_linalg.diagonal) + +__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot", + "matrix_transpose", "vecdot", "EighResult", + "QRResult", "SlogdetResult", "SVDResult", "qr", + "cholesky", "matrix_rank", "matrix_norm", "svdvals", + "vector_norm", "diagonal"] + +_all_ignore = ['get_xp', 'da', 'linalg_all'] diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 8ee9f711..adf20191 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,150 +1,10 @@ -from numpy import * # noqa: F401, F403 -from numpy import __all__ as _numpy_all +from numpy import * # noqa: F403 # from numpy import * doesn't overwrite these builtin names -from numpy import abs, max, min, round - -from ..common._helpers import ( - array_namespace, - device, - get_namespace, - is_array_api_obj, - size, - to_device, -) +from numpy import abs, max, min, round # noqa: F401 # These imports may overwrite names from the import * above. -from ._aliases import ( - UniqueAllResult, - UniqueCountsResult, - UniqueInverseResult, - acos, - acosh, - arange, - argsort, - asarray, - asarray_numpy, - asin, - asinh, - astype, - atan, - atan2, - atanh, - bitwise_invert, - bitwise_left_shift, - bitwise_right_shift, - bool, - ceil, - concat, - empty, - empty_like, - eye, - floor, - full, - full_like, - isdtype, - linspace, - matmul, - matrix_transpose, - nonzero, - ones, - ones_like, - permute_dims, - pow, - prod, - reshape, - sort, - std, - sum, - tensordot, - trunc, - unique_all, - unique_counts, - unique_inverse, - unique_values, - var, - vecdot, - zeros, - zeros_like, -) - -__all__ = [] - -__all__ += _numpy_all - -__all__ += [ - "abs", - "max", - "min", - "round", -] - -__all__ += [ - "array_namespace", - "device", - "get_namespace", - "is_array_api_obj", - "size", - "to_device", -] - -__all__ += [ - "UniqueAllResult", - "UniqueCountsResult", - "UniqueInverseResult", - "acos", - "acosh", - "arange", - "argsort", - "asarray", - "asarray_numpy", - "asin", - "asinh", - "astype", - "atan", - "atan2", - "atanh", - "bitwise_invert", - "bitwise_left_shift", - "bitwise_right_shift", - "bool", - "ceil", - "concat", - "empty", - "empty_like", - "eye", - "floor", - "full", - "full_like", - "isdtype", - "linspace", - "matmul", - "matrix_transpose", - "nonzero", - "ones", - "ones_like", - "permute_dims", - "pow", - "prod", - "reshape", - "sort", - "std", - "sum", - "tensordot", - "trunc", - "unique_all", - "unique_counts", - "unique_inverse", - "unique_values", - "var", - "zeros", - "zeros_like", -] - -__all__ += [ - "matrix_transpose", - "vecdot", -] +from ._aliases import * # noqa: F403 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do @@ -153,6 +13,10 @@ # # It doesn't overwrite np.linalg from above. The import is generated # dynamically so that the library can be vendored. -__import__(__package__ + ".linalg") +__import__(__package__ + '.linalg') + +from .linalg import matrix_transpose, vecdot # noqa: F401 + +from ..common._helpers import * # noqa: F403 -__array_api_version__ = "2022.12" +__array_api_version__ = '2022.12' diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index ee1c1557..1201d798 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -2,14 +2,15 @@ from functools import partial -import numpy as np +from ..common import _aliases from .._internal import get_xp -from ..common import _aliases, _linalg -asarray = asarray_numpy = partial(_aliases._asarray, namespace="numpy") +asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy') asarray.__doc__ = _aliases._asarray.__doc__ +del partial +import numpy as np bool = np.bool_ # Basic renames @@ -63,37 +64,18 @@ # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(np, "vecdot"): +if hasattr(np, 'vecdot'): vecdot = np.vecdot else: vecdot = get_xp(np)(_aliases.vecdot) -if hasattr(np, "isdtype"): +if hasattr(np, 'isdtype'): isdtype = np.isdtype else: isdtype = get_xp(np)(_aliases.isdtype) +__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos', + 'acosh', 'asin', 'asinh', 'atan', 'atan2', + 'atanh', 'bitwise_left_shift', 'bitwise_invert', + 'bitwise_right_shift', 'concat', 'pow'] -cross = get_xp(np)(_linalg.cross) -outer = get_xp(np)(_linalg.outer) -EighResult = _linalg.EighResult -QRResult = _linalg.QRResult -SlogdetResult = _linalg.SlogdetResult -SVDResult = _linalg.SVDResult -eigh = get_xp(np)(_linalg.eigh) -qr = get_xp(np)(_linalg.qr) -slogdet = get_xp(np)(_linalg.slogdet) -svd = get_xp(np)(_linalg.svd) -cholesky = get_xp(np)(_linalg.cholesky) -matrix_rank = get_xp(np)(_linalg.matrix_rank) -pinv = get_xp(np)(_linalg.pinv) -matrix_norm = get_xp(np)(_linalg.matrix_norm) -svdvals = get_xp(np)(_linalg.svdvals) -diagonal = get_xp(np)(_linalg.diagonal) -trace = get_xp(np)(_linalg.trace) - -# These functions are completely new here. If the library already has them -# (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(np.linalg, "vector_norm"): - vector_norm = np.linalg.vector_norm -else: - vector_norm = get_xp(np)(_linalg.vector_norm) +_all_ignore = ['np', 'get_xp'] diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index 53585f70..c5ebb5ab 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -1,9 +1,9 @@ from __future__ import annotations __all__ = [ + "ndarray", "Device", "Dtype", - "ndarray", ] import sys diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index fbe37cd7..42a998c7 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -1,63 +1,42 @@ -import numpy as _np - -from .._internal import _get_all_public_members - -_numpy_linalg_all = _get_all_public_members(_np.linalg) - -for _name in _numpy_linalg_all: - globals()[_name] = getattr(_np.linalg, _name) - - -from ._aliases import ( # noqa: E402 - EighResult, - QRResult, - SlogdetResult, - SVDResult, - cholesky, - cross, - diagonal, - eigh, - matmul, - matrix_norm, - matrix_rank, - matrix_transpose, - outer, - pinv, - qr, - slogdet, - svd, - svdvals, - tensordot, - trace, - vecdot, - vector_norm, -) - -__all__ = [] - -__all__ += _numpy_linalg_all - -__all__ += [ - "EighResult", - "QRResult", - "SVDResult", - "SlogdetResult", - "cholesky", - "cross", - "diagonal", - "eigh", - "matmul", - "matrix_norm", - "matrix_rank", - "matrix_transpose", - "outer", - "pinv", - "qr", - "slogdet", - "svd", - "svdvals", - "tensordot", - "trace", - "vecdot", - "vector_norm", -] +from numpy.linalg import * # noqa: F403 +from numpy.linalg import __all__ as linalg_all + +from ..common import _linalg +from .._internal import get_xp + +# These functions are in both the main and linalg namespaces +from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 + +import numpy as np + +cross = get_xp(np)(_linalg.cross) +outer = get_xp(np)(_linalg.outer) +EighResult = _linalg.EighResult +QRResult = _linalg.QRResult +SlogdetResult = _linalg.SlogdetResult +SVDResult = _linalg.SVDResult +eigh = get_xp(np)(_linalg.eigh) +qr = get_xp(np)(_linalg.qr) +slogdet = get_xp(np)(_linalg.slogdet) +svd = get_xp(np)(_linalg.svd) +cholesky = get_xp(np)(_linalg.cholesky) +matrix_rank = get_xp(np)(_linalg.matrix_rank) +pinv = get_xp(np)(_linalg.pinv) +matrix_norm = get_xp(np)(_linalg.matrix_norm) +svdvals = get_xp(np)(_linalg.svdvals) +diagonal = get_xp(np)(_linalg.diagonal) +trace = get_xp(np)(_linalg.trace) + +# These functions are completely new here. If the library already has them +# (i.e., numpy 2.0), use the library version instead of our wrapper. +if hasattr(np.linalg, 'vector_norm'): + vector_norm = np.linalg.vector_norm +else: + vector_norm = get_xp(np)(_linalg.vector_norm) + +__all__ = linalg_all + _linalg.__all__ + +del get_xp +del np +del linalg_all +del _linalg diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 6492839f..59898aab 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -1,189 +1,22 @@ -# Several names are not included in the above import * -import torch as _torch -from torch import * # noqa: F401, F403 - -from .._internal import _get_all_public_members - - -def exlcude(name): - if ( - name.startswith("_") - or name.endswith("_") - or "cuda" in name - or "cpu" in name - or "backward" in name - ): - return True - return False - - -_torch_all = _get_all_public_members(_torch, exclude=exlcude, extend_all=True) - -for _name in _torch_all: - globals()[_name] = getattr(_torch, _name) - +from torch import * # noqa: F403 -from ..common._helpers import ( # noqa: E402 - array_namespace, - device, - get_namespace, - is_array_api_obj, - size, - to_device, -) +# Several names are not included in the above import * +import torch +for n in dir(torch): + if (n.startswith('_') + or n.endswith('_') + or 'cuda' in n + or 'cpu' in n + or 'backward' in n): + continue + exec(n + ' = torch.' + n) # These imports may overwrite names from the import * above. -from ._aliases import ( # noqa: E402 - add, - all, - any, - arange, - astype, - atan2, - bitwise_and, - bitwise_invert, - bitwise_left_shift, - bitwise_or, - bitwise_right_shift, - bitwise_xor, - broadcast_arrays, - broadcast_to, - can_cast, - concat, - divide, - empty, - equal, - expand_dims, - eye, - flip, - floor_divide, - full, - greater, - greater_equal, - isdtype, - less, - less_equal, - linspace, - logaddexp, - matmul, - matrix_transpose, - max, - mean, - min, - multiply, - newaxis, - nonzero, - not_equal, - ones, - permute_dims, - pow, - prod, - remainder, - reshape, - result_type, - roll, - sort, - squeeze, - std, - subtract, - sum, - take, - tensordot, - tril, - triu, - unique_all, - unique_counts, - unique_inverse, - unique_values, - var, - vecdot, - where, - zeros, -) - -__all__ = [] - -__all__ += _torch_all - -__all__ += [ - "array_namespace", - "device", - "get_namespace", - "is_array_api_obj", - "size", - "to_device", -] - -__all__ += [ - "add", - "all", - "any", - "arange", - "astype", - "atan2", - "bitwise_and", - "bitwise_invert", - "bitwise_left_shift", - "bitwise_or", - "bitwise_right_shift", - "bitwise_xor", - "broadcast_arrays", - "broadcast_to", - "can_cast", - "concat", - "divide", - "empty", - "equal", - "expand_dims", - "eye", - "flip", - "floor_divide", - "full", - "greater", - "greater_equal", - "isdtype", - "less", - "less_equal", - "linspace", - "logaddexp", - "matmul", - "matrix_transpose", - "max", - "mean", - "min", - "multiply", - "newaxis", - "nonzero", - "not_equal", - "ones", - "permute_dims", - "pow", - "prod", - "remainder", - "reshape", - "result_type", - "roll", - "sort", - "squeeze", - "std", - "subtract", - "sum", - "take", - "tensordot", - "tril", - "triu", - "unique_all", - "unique_counts", - "unique_inverse", - "unique_values", - "var", - "vecdot", - "where", - "zeros", -] - +from ._aliases import * # noqa: F403 # See the comment in the numpy __init__.py -__import__(__package__ + ".linalg") +__import__(__package__ + '.linalg') + +from ..common._helpers import * # noqa: F403 -__array_api_version__ = "2022.12" +__array_api_version__ = '2022.12' diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 23cd5219..bfa7610b 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -1,23 +1,19 @@ from __future__ import annotations -from builtins import all as builtin_all -from builtins import any as builtin_any -from functools import wraps -from typing import TYPE_CHECKING - -import torch +from functools import wraps as _wraps +from builtins import all as _builtin_all, any as _builtin_any +from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose, + vecdot as _aliases_vecdot) from .._internal import get_xp -from ..common._aliases import UniqueAllResult, UniqueCountsResult, UniqueInverseResult -from ..common._aliases import matrix_transpose as _aliases_matrix_transpose -from ..common._aliases import vecdot as _aliases_vecdot +import torch + +from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import List, Optional, Sequence, Tuple, Union - - from torch import dtype as Dtype - from ..common._typing import Device + from torch import dtype as Dtype array = torch.Tensor @@ -88,7 +84,7 @@ def _two_arg(f): - @wraps(f) + @_wraps(f) def _f(x1, x2, /, **kwargs): x1, x2 = _fix_promotion(x1, x2) return f(x1, x2, **kwargs) @@ -511,7 +507,7 @@ def arange(start: Union[int, float], start, stop = 0, start if step > 0 and stop <= start or step < 0 and stop >= start: if dtype is None: - if builtin_all(isinstance(i, int) for i in [start, stop, step]): + if _builtin_all(isinstance(i, int) for i in [start, stop, step]): dtype = torch.int64 else: dtype = torch.float32 @@ -603,6 +599,11 @@ def broadcast_arrays(*arrays: array) -> List[array]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) return [torch.broadcast_to(a, shape) for a in arrays] +# Note that these named tuples aren't actually part of the standard namespace, +# but I don't see any issue with exporting the names here regardless. +from ..common._aliases import (UniqueAllResult, UniqueCountsResult, + UniqueInverseResult) + # https://github.com/pytorch/pytorch/issues/70920 def unique_all(x: array) -> UniqueAllResult: # torch.unique doesn't support returning indices. @@ -667,7 +668,7 @@ def isdtype( for more details """ if isinstance(kind, tuple) and _tuple: - return builtin_any(isdtype(dtype, k, _tuple=False) for k in kind) + return _builtin_any(isdtype(dtype, k, _tuple=False) for k in kind) elif isinstance(kind, str): if kind == 'bool': return dtype == torch.bool @@ -695,42 +696,19 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - axis = 0 return torch.index_select(x, axis, indices, **kwargs) - - -# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the -# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 -def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) - return torch.linalg.cross(x1, x2, dim=axis) - -def vecdot_linalg(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: - from ._aliases import isdtype - - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) - - # torch.linalg.vecdot doesn't support integer dtypes - if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'): - if kwargs: - raise RuntimeError("vecdot kwargs not supported for integral dtypes") - ndim = max(x1.ndim, x2.ndim) - x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) - x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) - if x1_shape[axis] != x2_shape[axis]: - raise ValueError("x1 and x2 must have the same size along the given axis") - - x1_, x2_ = torch.broadcast_tensors(x1, x2) - x1_ = torch.moveaxis(x1_, axis, -1) - x2_ = torch.moveaxis(x2_, axis, -1) - - res = x1_[..., None, :] @ x2_[..., None] - return res[..., 0, 0] - return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) - -def solve(x1: array, x2: array, /, **kwargs) -> array: - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) - return torch.linalg.solve(x1, x2, **kwargs) - -# torch.trace doesn't support the offset argument and doesn't support stacking -def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: - # Use our wrapped sum to make sure it does upcasting correctly - return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) \ No newline at end of file +__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', + 'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', + 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide', + 'equal', 'floor_divide', 'greater', 'greater_equal', 'less', + 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow', + 'remainder', 'subtract', 'max', 'min', 'sort', 'prod', 'sum', + 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', + 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', + 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', + 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays', + 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', + 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', + 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', + 'take'] + +_all_ignore = ['torch', 'get_xp'] diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 160f074b..63f0135b 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,34 +1,67 @@ -import torch as _torch +from __future__ import annotations -from .._internal import _get_all_public_members +from typing import TYPE_CHECKING +if TYPE_CHECKING: + import torch + array = torch.Tensor + from torch import dtype as Dtype + from typing import Optional -_torch_linalg_all = _get_all_public_members(_torch.linalg) +from ._aliases import _fix_promotion, sum -for _name in _torch_linalg_all: - globals()[_name] = getattr(_torch.linalg, _name) +from torch.linalg import * # noqa: F403 + +# torch.linalg doesn't define __all__ +# from torch.linalg import __all__ as linalg_all +from torch import linalg as torch_linalg +linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] # outer is implemented in torch but aren't in the linalg namespace -outer = _torch.outer - -from ._aliases import ( # noqa: E402 - matrix_transpose, - solve, - sum, - tensordot, - trace, - vecdot_linalg as vecdot, -) - -__all__ = [] - -__all__ += _torch_linalg_all - -__all__ += [ - "matrix_transpose", - "outer", - "solve", - "sum", - "tensordot", - "trace", - "vecdot", -] +from torch import outer +# These functions are in both the main and linalg namespaces +from ._aliases import matmul, matrix_transpose, tensordot + +# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the +# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 +def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return torch_linalg.cross(x1, x2, dim=axis) + +def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: + from ._aliases import isdtype + + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + + # torch.linalg.vecdot doesn't support integer dtypes + if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'): + if kwargs: + raise RuntimeError("vecdot kwargs not supported for integral dtypes") + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same size along the given axis") + + x1_, x2_ = torch.broadcast_tensors(x1, x2) + x1_ = torch.moveaxis(x1_, axis, -1) + x2_ = torch.moveaxis(x2_, axis, -1) + + res = x1_[..., None, :] @ x2_[..., None] + return res[..., 0, 0] + return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) + +def solve(x1: array, x2: array, /, **kwargs) -> array: + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return torch.linalg.solve(x1, x2, **kwargs) + +# torch.trace doesn't support the offset argument and doesn't support stacking +def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: + # Use our wrapped sum to make sure it does upcasting correctly + return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) + +__all__ = linalg_all + ['outer', 'trace', 'matmul', 'matrix_transpose', 'tensordot', + 'vecdot', 'solve'] + +_all_ignore = ['torch_linalg', 'sum'] + +del linalg_all diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 00000000..53e8596c --- /dev/null +++ b/ruff.toml @@ -0,0 +1,13 @@ +[lint] +preview = true +select = [ +# Defaults +"E4", "E7", "E9", "F", +# Undefined export +"F822", +# Useless import alias +"PLC0414" +] + +# Ignore module import not at top of file +ignore = ["E402"] diff --git a/tests/_helpers.py b/tests/_helpers.py index e05ae86c..23cb5db9 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -8,7 +8,7 @@ def import_(library, wrapper=False): if library == 'cupy': return pytest.importorskip(library) - if 'jax' in library and sys.version_info <= (3, 8): + if 'jax' in library and sys.version_info < (3, 9): pytest.skip('JAX array API support does not support Python 3.8') if wrapper: diff --git a/tests/test_all.py b/tests/test_all.py new file mode 100644 index 00000000..5b49fa14 --- /dev/null +++ b/tests/test_all.py @@ -0,0 +1,42 @@ +""" +Test that files that define __all__ aren't missing any exports. + +You can add names that shouldn't be exported to _all_ignore, like + +_all_ignore = ['sys'] + +This is preferable to del-ing the names as this will break any name that is +used inside of a function. Note that names starting with an underscore are automatically ignored. +""" + + +import sys + +from ._helpers import import_ + +import pytest + +@pytest.mark.parametrize("library", ["common", "cupy", "numpy", "torch", "dask.array"]) +def test_all(library): + import_(library, wrapper=True) + + for mod_name in sys.modules: + if 'array_api_compat.' + library not in mod_name: + continue + + module = sys.modules[mod_name] + + # TODO: We should define __all__ in the __init__.py files and test it + # there too. + if not hasattr(module, '__all__'): + continue + + dir_names = [n for n in dir(module) if not n.startswith('_')] + ignore_all_names = getattr(module, '_all_ignore', []) + ignore_all_names += ['annotations', 'TYPE_CHECKING'] + dir_names = set(dir_names) - set(ignore_all_names) + all_names = module.__all__ + + if set(dir_names) != set(all_names): + assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}" + assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"