Skip to content

Commit 5d9d984

Browse files
authored
fix fallback isdtype method (#9250)
1 parent 71fce9b commit 5d9d984

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

xarray/core/npcompat.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,18 @@
2828
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
2929
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
3030
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
from __future__ import annotations
32+
33+
from typing import Any
3134

3235
try:
3336
# requires numpy>=2.0
3437
from numpy import isdtype # type: ignore[attr-defined,unused-ignore]
3538
except ImportError:
3639
import numpy as np
40+
from numpy.typing import DTypeLike
3741

38-
dtype_kinds = {
42+
kind_mapping = {
3943
"bool": np.bool_,
4044
"signed integer": np.signedinteger,
4145
"unsigned integer": np.unsignedinteger,
@@ -45,16 +49,25 @@
4549
"numeric": np.number,
4650
}
4751

48-
def isdtype(dtype, kind):
52+
def isdtype(
53+
dtype: np.dtype[Any] | type[Any], kind: DTypeLike | tuple[DTypeLike, ...]
54+
) -> bool:
4955
kinds = kind if isinstance(kind, tuple) else (kind,)
56+
str_kinds = {k for k in kinds if isinstance(k, str)}
57+
type_kinds = {k.type for k in kinds if isinstance(k, np.dtype)}
5058

51-
unknown_dtypes = [kind for kind in kinds if kind not in dtype_kinds]
52-
if unknown_dtypes:
53-
raise ValueError(f"unknown dtype kinds: {unknown_dtypes}")
59+
if unknown_kind_types := set(kinds) - str_kinds - type_kinds:
60+
raise TypeError(
61+
f"kind must be str, np.dtype or a tuple of these, got {unknown_kind_types}"
62+
)
63+
if unknown_kinds := {k for k in str_kinds if k not in kind_mapping}:
64+
raise ValueError(
65+
f"unknown kind: {unknown_kinds}, must be a np.dtype or one of {list(kind_mapping)}"
66+
)
5467

5568
# verified the dtypes already, no need to check again
56-
translated_kinds = [dtype_kinds[kind] for kind in kinds]
69+
translated_kinds = {kind_mapping[k] for k in str_kinds} | type_kinds
5770
if isinstance(dtype, np.generic):
58-
return any(isinstance(dtype, kind) for kind in translated_kinds)
71+
return isinstance(dtype, translated_kinds)
5972
else:
60-
return any(np.issubdtype(dtype, kind) for kind in translated_kinds)
73+
return any(np.issubdtype(dtype, k) for k in translated_kinds)

0 commit comments

Comments
 (0)