|
28 | 28 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
29 | 29 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
30 | 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 31 | +from __future__ import annotations |
| 32 | + |
| 33 | +from typing import Any |
31 | 34 |
|
32 | 35 | try:
|
33 | 36 | # requires numpy>=2.0
|
34 | 37 | from numpy import isdtype # type: ignore[attr-defined,unused-ignore]
|
35 | 38 | except ImportError:
|
36 | 39 | import numpy as np
|
| 40 | + from numpy.typing import DTypeLike |
37 | 41 |
|
38 |
| - dtype_kinds = { |
| 42 | + kind_mapping = { |
39 | 43 | "bool": np.bool_,
|
40 | 44 | "signed integer": np.signedinteger,
|
41 | 45 | "unsigned integer": np.unsignedinteger,
|
|
45 | 49 | "numeric": np.number,
|
46 | 50 | }
|
47 | 51 |
|
48 |
| - def isdtype(dtype, kind): |
| 52 | + def isdtype( |
| 53 | + dtype: np.dtype[Any] | type[Any], kind: DTypeLike | tuple[DTypeLike, ...] |
| 54 | + ) -> bool: |
49 | 55 | kinds = kind if isinstance(kind, tuple) else (kind,)
|
| 56 | + str_kinds = {k for k in kinds if isinstance(k, str)} |
| 57 | + type_kinds = {k.type for k in kinds if isinstance(k, np.dtype)} |
50 | 58 |
|
51 |
| - unknown_dtypes = [kind for kind in kinds if kind not in dtype_kinds] |
52 |
| - if unknown_dtypes: |
53 |
| - raise ValueError(f"unknown dtype kinds: {unknown_dtypes}") |
| 59 | + if unknown_kind_types := set(kinds) - str_kinds - type_kinds: |
| 60 | + raise TypeError( |
| 61 | + f"kind must be str, np.dtype or a tuple of these, got {unknown_kind_types}" |
| 62 | + ) |
| 63 | + if unknown_kinds := {k for k in str_kinds if k not in kind_mapping}: |
| 64 | + raise ValueError( |
| 65 | + f"unknown kind: {unknown_kinds}, must be a np.dtype or one of {list(kind_mapping)}" |
| 66 | + ) |
54 | 67 |
|
55 | 68 | # verified the dtypes already, no need to check again
|
56 |
| - translated_kinds = [dtype_kinds[kind] for kind in kinds] |
| 69 | + translated_kinds = {kind_mapping[k] for k in str_kinds} | type_kinds |
57 | 70 | if isinstance(dtype, np.generic):
|
58 |
| - return any(isinstance(dtype, kind) for kind in translated_kinds) |
| 71 | + return isinstance(dtype, translated_kinds) |
59 | 72 | else:
|
60 |
| - return any(np.issubdtype(dtype, kind) for kind in translated_kinds) |
| 73 | + return any(np.issubdtype(dtype, k) for k in translated_kinds) |
0 commit comments