Skip to content

Commit 1be9f58

Browse files
committed
TYP: Fix some typehints for ExtensionDtype
1 parent 3513f59 commit 1be9f58

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

pandas/core/dtypes/base.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import (
88
TYPE_CHECKING,
99
Any,
10+
TypeVar,
1011
)
1112

1213
import numpy as np
@@ -26,6 +27,9 @@
2627
if TYPE_CHECKING:
2728
from pandas.core.arrays import ExtensionArray
2829

30+
# To parameterize on same ExtensionDtype
31+
E = TypeVar("E", bound="ExtensionDtype")
32+
2933

3034
class ExtensionDtype:
3135
"""
@@ -151,7 +155,7 @@ def na_value(self) -> object:
151155
return np.nan
152156

153157
@property
154-
def type(self) -> type[Any]:
158+
def type(self) -> type_t[Any]:
155159
"""
156160
The scalar type for the array, e.g. ``int``
157161
@@ -209,7 +213,7 @@ def construct_array_type(cls) -> type_t[ExtensionArray]:
209213
raise NotImplementedError
210214

211215
@classmethod
212-
def construct_from_string(cls, string: str):
216+
def construct_from_string(cls, string: str) -> ExtensionDtype:
213217
r"""
214218
Construct this type from a string.
215219
@@ -364,7 +368,7 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
364368
return None
365369

366370

367-
def register_extension_dtype(cls: type[ExtensionDtype]) -> type[ExtensionDtype]:
371+
def register_extension_dtype(cls: type[E]) -> type[E]:
368372
"""
369373
Register an ExtensionType with pandas as class decorator.
370374
@@ -420,7 +424,7 @@ def register(self, dtype: type[ExtensionDtype]) -> None:
420424

421425
self.dtypes.append(dtype)
422426

423-
def find(self, dtype: type[ExtensionDtype] | str) -> type[ExtensionDtype] | None:
427+
def find(self, dtype: type[E] | E | str) -> type[E] | E | ExtensionDtype | None:
424428
"""
425429
Parameters
426430
----------
@@ -431,8 +435,9 @@ def find(self, dtype: type[ExtensionDtype] | str) -> type[ExtensionDtype] | None
431435
return the first matching dtype, otherwise return None
432436
"""
433437
if not isinstance(dtype, str):
434-
dtype_type = dtype
435-
if not isinstance(dtype, type):
438+
if isinstance(dtype, type):
439+
dtype_type = dtype
440+
else:
436441
dtype_type = type(dtype)
437442
if issubclass(dtype_type, ExtensionDtype):
438443
return dtype

0 commit comments

Comments
 (0)