7
7
from typing import (
8
8
TYPE_CHECKING ,
9
9
Any ,
10
+ TypeVar ,
10
11
)
11
12
12
13
import numpy as np
26
27
if TYPE_CHECKING :
27
28
from pandas .core .arrays import ExtensionArray
28
29
30
+ # To parameterize on same ExtensionDtype
31
+ E = TypeVar ("E" , bound = "ExtensionDtype" )
32
+
29
33
30
34
class ExtensionDtype :
31
35
"""
@@ -151,7 +155,7 @@ def na_value(self) -> object:
151
155
return np .nan
152
156
153
157
@property
154
- def type (self ) -> type [Any ]:
158
+ def type (self ) -> type_t [Any ]:
155
159
"""
156
160
The scalar type for the array, e.g. ``int``
157
161
@@ -209,7 +213,7 @@ def construct_array_type(cls) -> type_t[ExtensionArray]:
209
213
raise NotImplementedError
210
214
211
215
@classmethod
212
- def construct_from_string (cls , string : str ):
216
+ def construct_from_string (cls , string : str ) -> ExtensionDtype :
213
217
r"""
214
218
Construct this type from a string.
215
219
@@ -364,7 +368,7 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
364
368
return None
365
369
366
370
367
- def register_extension_dtype (cls : type [ExtensionDtype ]) -> type [ExtensionDtype ]:
371
+ def register_extension_dtype (cls : type [E ]) -> type [E ]:
368
372
"""
369
373
Register an ExtensionType with pandas as class decorator.
370
374
@@ -420,7 +424,7 @@ def register(self, dtype: type[ExtensionDtype]) -> None:
420
424
421
425
self .dtypes .append (dtype )
422
426
423
- def find (self , dtype : type [ExtensionDtype ] | str ) -> type [ExtensionDtype ] | None :
427
+ def find (self , dtype : type [E ] | E | str ) -> type [E ] | E | ExtensionDtype | None :
424
428
"""
425
429
Parameters
426
430
----------
@@ -431,8 +435,9 @@ def find(self, dtype: type[ExtensionDtype] | str) -> type[ExtensionDtype] | None
431
435
return the first matching dtype, otherwise return None
432
436
"""
433
437
if not isinstance (dtype , str ):
434
- dtype_type = dtype
435
- if not isinstance (dtype , type ):
438
+ if isinstance (dtype , type ):
439
+ dtype_type = dtype
440
+ else :
436
441
dtype_type = type (dtype )
437
442
if issubclass (dtype_type , ExtensionDtype ):
438
443
return dtype
0 commit comments