diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index 5b7dadac5d914..7dad8c61f4fc7 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -8,6 +8,8 @@ TYPE_CHECKING, Any, TypeVar, + cast, + overload, ) import numpy as np @@ -15,6 +17,7 @@ from pandas._libs.hashtable import object_hash from pandas._typing import ( DtypeObj, + npt, type_t, ) from pandas.errors import AbstractMethodError @@ -29,7 +32,7 @@ from pandas.core.arrays import ExtensionArray # To parameterize on same ExtensionDtype - E = TypeVar("E", bound="ExtensionDtype") + ExtensionDtypeT = TypeVar("ExtensionDtypeT", bound="ExtensionDtype") class ExtensionDtype: @@ -206,7 +209,9 @@ def construct_array_type(cls) -> type_t[ExtensionArray]: raise AbstractMethodError(cls) @classmethod - def construct_from_string(cls, string: str): + def construct_from_string( + cls: type_t[ExtensionDtypeT], string: str + ) -> ExtensionDtypeT: r""" Construct this type from a string. @@ -368,7 +373,7 @@ def _can_hold_na(self) -> bool: return True -def register_extension_dtype(cls: type[E]) -> type[E]: +def register_extension_dtype(cls: type_t[ExtensionDtypeT]) -> type_t[ExtensionDtypeT]: """ Register an ExtensionType with pandas as class decorator. @@ -409,9 +414,9 @@ class Registry: """ def __init__(self): - self.dtypes: list[type[ExtensionDtype]] = [] + self.dtypes: list[type_t[ExtensionDtype]] = [] - def register(self, dtype: type[ExtensionDtype]) -> None: + def register(self, dtype: type_t[ExtensionDtype]) -> None: """ Parameters ---------- @@ -422,22 +427,46 @@ def register(self, dtype: type[ExtensionDtype]) -> None: self.dtypes.append(dtype) - def find(self, dtype: type[ExtensionDtype] | str) -> type[ExtensionDtype] | None: + @overload + def find(self, dtype: type_t[ExtensionDtypeT]) -> type_t[ExtensionDtypeT]: + ... + + @overload + def find(self, dtype: ExtensionDtypeT) -> ExtensionDtypeT: + ... + + @overload + def find(self, dtype: str) -> ExtensionDtype | None: + ... + + @overload + def find( + self, dtype: npt.DTypeLike + ) -> type_t[ExtensionDtype] | ExtensionDtype | None: + ... + + def find( + self, dtype: type_t[ExtensionDtype] | ExtensionDtype | npt.DTypeLike + ) -> type_t[ExtensionDtype] | ExtensionDtype | None: """ Parameters ---------- - dtype : Type[ExtensionDtype] or str + dtype : ExtensionDtype class or instance or str or numpy dtype or python type Returns ------- return the first matching dtype, otherwise return None """ if not isinstance(dtype, str): - dtype_type = dtype + dtype_type: type_t if not isinstance(dtype, type): dtype_type = type(dtype) + else: + dtype_type = dtype if issubclass(dtype_type, ExtensionDtype): - return dtype + # cast needed here as mypy doesn't know we have figured + # out it is an ExtensionDtype or type_t[ExtensionDtype] + return cast("ExtensionDtype | type_t[ExtensionDtype]", dtype) return None diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index 08287cc296006..3a870c2287584 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -1765,9 +1765,7 @@ def pandas_dtype(dtype) -> DtypeObj: # registered extension types result = registry.find(dtype) if result is not None: - # error: Incompatible return value type (got "Type[ExtensionDtype]", - # expected "Union[dtype, ExtensionDtype]") - return result # type: ignore[return-value] + return result # try a numpy dtype # raise a consistent TypeError if failed