From 7d4531d61b101548da1b2099fb9f7872d0a00a67 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Wed, 28 Apr 2021 14:50:02 -0400 Subject: [PATCH 1/9] TYP: Fix typing in ExtensionDtype registry --- pandas/core/dtypes/base.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index 9671c340a0a92..226755011b233 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -28,7 +28,7 @@ from pandas.core.arrays import ExtensionArray # To parameterize on same ExtensionDtype - E = TypeVar("E", bound="ExtensionDtype") + ExtensionDtypeT = TypeVar("ExtensionDtypeT", bound="ExtensionDtype") class ExtensionDtype: @@ -368,7 +368,7 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: return None -def register_extension_dtype(cls: type[E]) -> type[E]: +def register_extension_dtype(cls: type[ExtensionDtypeT]) -> type[ExtensionDtypeT]: """ Register an ExtensionType with pandas as class decorator. @@ -424,20 +424,23 @@ def register(self, dtype: type[ExtensionDtype]) -> None: self.dtypes.append(dtype) - def find(self, dtype: type[ExtensionDtype] | str) -> type[ExtensionDtype] | None: + def find( + self, dtype: type[ExtensionDtype] | ExtensionDtype | str + ) -> type[ExtensionDtype] | ExtensionDtype | None: """ Parameters ---------- - dtype : Type[ExtensionDtype] or str + dtype : Type[ExtensionDtype] or ExtensionDtype or str Returns ------- return the first matching dtype, otherwise return None """ if not isinstance(dtype, str): - dtype_type = dtype if not isinstance(dtype, type): dtype_type = type(dtype) + else: + dtype_type = dtype if issubclass(dtype_type, ExtensionDtype): return dtype From 41c5355faa14803bd338376d1b5b531882a478ae Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Wed, 28 Apr 2021 15:05:30 -0400 Subject: [PATCH 2/9] fix construct_from_string return type --- pandas/core/dtypes/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index 226755011b233..16cc9291f19b2 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -213,7 +213,7 @@ def construct_array_type(cls) -> type_t[ExtensionArray]: raise NotImplementedError @classmethod - def construct_from_string(cls, string: str): + def construct_from_string(cls, string: str) -> ExtensionDtype: r""" Construct this type from a string. From 03d4e6682afdc64c1cded22376d604ab60e3663a Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Wed, 28 Apr 2021 15:31:10 -0400 Subject: [PATCH 3/9] overloads for find. typing cls in ExtensionDtype --- pandas/core/dtypes/base.py | 23 +++++++++++++++++++---- pandas/core/dtypes/common.py | 4 +--- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index 16cc9291f19b2..8c3c51359c524 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -8,6 +8,7 @@ TYPE_CHECKING, Any, TypeVar, + overload, ) import numpy as np @@ -202,7 +203,7 @@ def names(self) -> list[str] | None: return None @classmethod - def construct_array_type(cls) -> type_t[ExtensionArray]: + def construct_array_type(cls: type_t[ExtensionDtypeT]) -> type_t[ExtensionArray]: """ Return the array type associated with this dtype. @@ -213,7 +214,9 @@ def construct_array_type(cls) -> type_t[ExtensionArray]: raise NotImplementedError @classmethod - def construct_from_string(cls, string: str) -> ExtensionDtype: + def construct_from_string( + cls: type_t[ExtensionDtypeT], string: str + ) -> ExtensionDtype: r""" Construct this type from a string. @@ -268,7 +271,7 @@ def construct_from_string(cls, string: str) -> ExtensionDtype: return cls() @classmethod - def is_dtype(cls, dtype: object) -> bool: + def is_dtype(cls: type_t[ExtensionDtypeT], dtype: object) -> bool: """ Check if we match 'dtype'. @@ -424,13 +427,25 @@ def register(self, dtype: type[ExtensionDtype]) -> None: self.dtypes.append(dtype) + @overload + def find(self, dtype: type[ExtensionDtype]) -> ExtensionDtype: + ... + + @overload + def find(self, dtype: ExtensionDtype) -> ExtensionDtype: + ... + + @overload + def find(self, dtype: str) -> ExtensionDtype | None: + ... + def find( self, dtype: type[ExtensionDtype] | ExtensionDtype | str ) -> type[ExtensionDtype] | ExtensionDtype | None: """ Parameters ---------- - dtype : Type[ExtensionDtype] or ExtensionDtype or str + dtype : ExtensionDtype class or instance or str Returns ------- diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index 593e42f7ed749..ddc389934daaa 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -1789,9 +1789,7 @@ def pandas_dtype(dtype) -> DtypeObj: # registered extension types result = registry.find(dtype) if result is not None: - # error: Incompatible return value type (got "Type[ExtensionDtype]", - # expected "Union[dtype, ExtensionDtype]") - return result # type: ignore[return-value] + return result # try a numpy dtype # raise a consistent TypeError if failed From 1a71e7113039edcb6447cebda44a0204709406de Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sun, 23 May 2021 11:18:18 -0400 Subject: [PATCH 4/9] fixes to use more of ExtensionDtypeT --- pandas/core/dtypes/base.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index 8c3c51359c524..1b9ea6711301e 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -8,6 +8,7 @@ TYPE_CHECKING, Any, TypeVar, + cast, overload, ) @@ -15,6 +16,7 @@ from pandas._typing import ( DtypeObj, + NpDtype, type_t, ) from pandas.errors import AbstractMethodError @@ -203,7 +205,7 @@ def names(self) -> list[str] | None: return None @classmethod - def construct_array_type(cls: type_t[ExtensionDtypeT]) -> type_t[ExtensionArray]: + def construct_array_type(cls) -> type_t[ExtensionArray]: """ Return the array type associated with this dtype. @@ -216,7 +218,7 @@ def construct_array_type(cls: type_t[ExtensionDtypeT]) -> type_t[ExtensionArray] @classmethod def construct_from_string( cls: type_t[ExtensionDtypeT], string: str - ) -> ExtensionDtype: + ) -> ExtensionDtypeT: r""" Construct this type from a string. @@ -271,7 +273,7 @@ def construct_from_string( return cls() @classmethod - def is_dtype(cls: type_t[ExtensionDtypeT], dtype: object) -> bool: + def is_dtype(cls, dtype: object) -> bool: """ Check if we match 'dtype'. @@ -428,42 +430,51 @@ def register(self, dtype: type[ExtensionDtype]) -> None: self.dtypes.append(dtype) @overload - def find(self, dtype: type[ExtensionDtype]) -> ExtensionDtype: + def find(self, dtype: type[ExtensionDtypeT]) -> ExtensionDtypeT: ... @overload - def find(self, dtype: ExtensionDtype) -> ExtensionDtype: + def find(self, dtype: ExtensionDtypeT) -> ExtensionDtypeT: ... @overload - def find(self, dtype: str) -> ExtensionDtype | None: + def find( + self, dtype: NpDtype | type_t[str | float | int | complex | bool | object] | str + ) -> ExtensionDtype | None: ... def find( - self, dtype: type[ExtensionDtype] | ExtensionDtype | str + self, + dtype: type[ExtensionDtype] + | ExtensionDtype + | NpDtype + | type_t[str | float | int | complex | bool | object], ) -> type[ExtensionDtype] | ExtensionDtype | None: """ Parameters ---------- - dtype : ExtensionDtype class or instance or str + dtype : ExtensionDtype class or instance or str or numpy dtype or python type Returns ------- return the first matching dtype, otherwise return None """ if not isinstance(dtype, str): + dtype_type: type[ExtensionDtype] | type if not isinstance(dtype, type): dtype_type = type(dtype) else: dtype_type = dtype if issubclass(dtype_type, ExtensionDtype): - return dtype + # cast needed here as mypy doesn't know we have figured + # out it is an ExtensionDtype + return cast(ExtensionDtype, dtype) return None - for dtype_type in self.dtypes: + for dtype_loop in self.dtypes: try: - return dtype_type.construct_from_string(dtype) + return dtype_loop.construct_from_string(dtype) except TypeError: pass From 37affba039b42cc95f54059eb82e2a43f24d2c5a Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sat, 3 Jul 2021 14:10:14 -0400 Subject: [PATCH 5/9] use type_t, reduce overloads, use npt.DTypeLike --- pandas/core/dtypes/base.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index 618b565c9fc5b..3e9964179919a 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -17,7 +17,7 @@ from pandas._libs.hashtable import object_hash from pandas._typing import ( DtypeObj, - NpDtype, + npt, type_t, ) from pandas.errors import AbstractMethodError @@ -373,7 +373,7 @@ def _can_hold_na(self) -> bool: return True -def register_extension_dtype(cls: type[ExtensionDtypeT]) -> type[ExtensionDtypeT]: +def register_extension_dtype(cls: type_t[ExtensionDtypeT]) -> type_t[ExtensionDtypeT]: """ Register an ExtensionType with pandas as class decorator. @@ -414,9 +414,9 @@ class Registry: """ def __init__(self): - self.dtypes: list[type[ExtensionDtype]] = [] + self.dtypes: list[type_t[ExtensionDtype]] = [] - def register(self, dtype: type[ExtensionDtype]) -> None: + def register(self, dtype: type_t[ExtensionDtype]) -> None: """ Parameters ---------- @@ -428,26 +428,16 @@ def register(self, dtype: type[ExtensionDtype]) -> None: self.dtypes.append(dtype) @overload - def find(self, dtype: type[ExtensionDtypeT]) -> ExtensionDtypeT: + def find(self, dtype: ExtensionDtypeT | type_t[ExtensionDtypeT]) -> ExtensionDtypeT: ... @overload - def find(self, dtype: ExtensionDtypeT) -> ExtensionDtypeT: + def find(self, dtype: npt.DTypeLike) -> ExtensionDtype | None: ... - @overload def find( - self, dtype: NpDtype | type_t[str | float | int | complex | bool | object] | str + self, dtype: type_t[ExtensionDtype] | ExtensionDtype | npt.DTypeLike ) -> ExtensionDtype | None: - ... - - def find( - self, - dtype: type[ExtensionDtype] - | ExtensionDtype - | NpDtype - | type_t[str | float | int | complex | bool | object], - ) -> type[ExtensionDtype] | ExtensionDtype | None: """ Parameters ---------- @@ -458,7 +448,7 @@ def find( return the first matching dtype, otherwise return None """ if not isinstance(dtype, str): - dtype_type: type[ExtensionDtype] | type + dtype_type: type_t if not isinstance(dtype, type): dtype_type = type(dtype) else: @@ -470,9 +460,9 @@ def find( return None - for dtype_loop in self.dtypes: + for dtype_type in self.dtypes: try: - return dtype_loop.construct_from_string(dtype) + return dtype_type.construct_from_string(dtype) except TypeError: pass From 3174b0fac24560ace5de9cf32962a769a2dc7f90 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sat, 3 Jul 2021 14:35:31 -0400 Subject: [PATCH 6/9] undo use of npt and put back type_t[ExtensionDtype] --- pandas/core/dtypes/base.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index 3e9964179919a..7faa192403dc2 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -17,7 +17,7 @@ from pandas._libs.hashtable import object_hash from pandas._typing import ( DtypeObj, - npt, + NpDtype, type_t, ) from pandas.errors import AbstractMethodError @@ -428,16 +428,26 @@ def register(self, dtype: type_t[ExtensionDtype]) -> None: self.dtypes.append(dtype) @overload - def find(self, dtype: ExtensionDtypeT | type_t[ExtensionDtypeT]) -> ExtensionDtypeT: + def find(self, dtype: type_t[ExtensionDtypeT]) -> type_t[ExtensionDtypeT]: ... @overload - def find(self, dtype: npt.DTypeLike) -> ExtensionDtype | None: + def find(self, dtype: ExtensionDtypeT) -> ExtensionDtypeT: + ... + + @overload + def find( + self, dtype: NpDtype | type_t[str | float | int | complex | bool | object] + ) -> type_t[ExtensionDtypeT] | ExtensionDtype | None: ... def find( - self, dtype: type_t[ExtensionDtype] | ExtensionDtype | npt.DTypeLike - ) -> ExtensionDtype | None: + self, + dtype: type_t[ExtensionDtype] + | ExtensionDtype + | NpDtype + | type_t[str | float | int | complex | bool | object], + ) -> type_t[ExtensionDtype] | ExtensionDtype | None: """ Parameters ---------- From 4d6dbdbd026fe0c054a62f8a9c62b4c17b8ce6f9 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sat, 3 Jul 2021 14:50:10 -0400 Subject: [PATCH 7/9] cast to the union --- pandas/core/dtypes/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index 7faa192403dc2..c27256a2c41d6 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -18,6 +18,7 @@ from pandas._typing import ( DtypeObj, NpDtype, + Union, type_t, ) from pandas.errors import AbstractMethodError @@ -465,8 +466,8 @@ def find( dtype_type = dtype if issubclass(dtype_type, ExtensionDtype): # cast needed here as mypy doesn't know we have figured - # out it is an ExtensionDtype - return cast(ExtensionDtype, dtype) + # out it is an ExtensionDtype or type_t[ExtensionDtype] + return cast(Union[ExtensionDtype, type_t[ExtensionDtype]], dtype) return None From eaefa6c4bf91ecf42185a87ee98819249c437652 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sat, 3 Jul 2021 15:57:01 -0400 Subject: [PATCH 8/9] remove Union in cast --- pandas/core/dtypes/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index c27256a2c41d6..abac3faa97db6 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -18,7 +18,6 @@ from pandas._typing import ( DtypeObj, NpDtype, - Union, type_t, ) from pandas.errors import AbstractMethodError @@ -467,7 +466,7 @@ def find( if issubclass(dtype_type, ExtensionDtype): # cast needed here as mypy doesn't know we have figured # out it is an ExtensionDtype or type_t[ExtensionDtype] - return cast(Union[ExtensionDtype, type_t[ExtensionDtype]], dtype) + return cast("ExtensionDtype | type_t[ExtensionDtype]", dtype) return None From d9651a4fbab096df0865f251675eab9637bfae89 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Sun, 4 Jul 2021 12:26:08 -0400 Subject: [PATCH 9/9] add 4th overload --- pandas/core/dtypes/base.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index abac3faa97db6..7dad8c61f4fc7 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -17,7 +17,7 @@ from pandas._libs.hashtable import object_hash from pandas._typing import ( DtypeObj, - NpDtype, + npt, type_t, ) from pandas.errors import AbstractMethodError @@ -435,18 +435,18 @@ def find(self, dtype: type_t[ExtensionDtypeT]) -> type_t[ExtensionDtypeT]: def find(self, dtype: ExtensionDtypeT) -> ExtensionDtypeT: ... + @overload + def find(self, dtype: str) -> ExtensionDtype | None: + ... + @overload def find( - self, dtype: NpDtype | type_t[str | float | int | complex | bool | object] - ) -> type_t[ExtensionDtypeT] | ExtensionDtype | None: + self, dtype: npt.DTypeLike + ) -> type_t[ExtensionDtype] | ExtensionDtype | None: ... def find( - self, - dtype: type_t[ExtensionDtype] - | ExtensionDtype - | NpDtype - | type_t[str | float | int | complex | bool | object], + self, dtype: type_t[ExtensionDtype] | ExtensionDtype | npt.DTypeLike ) -> type_t[ExtensionDtype] | ExtensionDtype | None: """ Parameters