Skip to content

Commit bb08817

Browse files
authored
ENH: allow EADtype to specify _supports_2d (#54832)
1 parent 75781c2 commit bb08817

File tree

5 files changed

+48
-7
lines changed

5 files changed

+48
-7
lines changed

pandas/core/dtypes/base.py

+27
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,33 @@ def index_class(self) -> type_t[Index]:
418418

419419
return Index
420420

421+
@property
def _supports_2d(self) -> bool:
    """
    Whether ExtensionArrays with this dtype can be two-dimensional.

    Historically ExtensionArrays were limited to 1D. A dtype that returns
    True here tells pandas that its arrays may also be 2D, which can
    improve performance in some cases, particularly operations with
    ``axis=1``.

    Array authors who opt in should make sure that:

    - the array implements ``reshape``
    - ``_concat_same_type`` supports the ``axis`` keyword
    - ``_reduce`` and the reductions support the ``axis`` keyword
    - their test suite subclasses ``Dim2CompatTests`` from
      ``tests.extension.base``

    Returns
    -------
    bool
        False by default, i.e. arrays are treated as 1D-only unless a
        subclass overrides this.
    """
    return False
438+
439+
@property
def _can_fast_transpose(self) -> bool:
    """
    Whether transposing an array with this dtype is zero-copy.

    Only relevant for dtypes where ``_supports_2d`` is True.

    Returns
    -------
    bool
        False by default; subclasses override this to opt in.
    """
    return False
447+
421448

422449
class StorageExtensionDtype(ExtensionDtype):
423450
"""ExtensionDtype that may be backed by more than one implementation."""

pandas/core/dtypes/common.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -1256,13 +1256,7 @@ def is_1d_only_ea_dtype(dtype: DtypeObj | None) -> bool:
12561256
"""
12571257
Analogue to is_extension_array_dtype but excluding dtypes whose arrays support 2D (``dtype._supports_2d``), e.g. DatetimeTZDtype and PeriodDtype.
12581258
"""
1259-
# Note: if other EA dtypes are ever held in HybridBlock, exclude those
1260-
# here too.
1261-
# NB: need to check DatetimeTZDtype and not is_datetime64tz_dtype
1262-
# to exclude ArrowTimestampUSDtype
1263-
return isinstance(dtype, ExtensionDtype) and not isinstance(
1264-
dtype, (DatetimeTZDtype, PeriodDtype)
1265-
)
1259+
return isinstance(dtype, ExtensionDtype) and not dtype._supports_2d
12661260

12671261

12681262
def is_extension_array_dtype(arr_or_dtype) -> bool:

pandas/core/dtypes/dtypes.py

+8
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
213213
base = np.dtype("O")
214214
_metadata = ("categories", "ordered")
215215
_cache_dtypes: dict[str_type, PandasExtensionDtype] = {}
216+
_supports_2d = False
217+
_can_fast_transpose = False
216218

217219
def __init__(self, categories=None, ordered: Ordered = False) -> None:
218220
self._finalize(categories, ordered, fastpath=False)
@@ -730,6 +732,8 @@ class DatetimeTZDtype(PandasExtensionDtype):
730732
_metadata = ("unit", "tz")
731733
_match = re.compile(r"(datetime64|M8)\[(?P<unit>.+), (?P<tz>.+)\]")
732734
_cache_dtypes: dict[str_type, PandasExtensionDtype] = {}
735+
_supports_2d = True
736+
_can_fast_transpose = True
733737

734738
@property
735739
def na_value(self) -> NaTType:
@@ -973,6 +977,8 @@ class PeriodDtype(PeriodDtypeBase, PandasExtensionDtype):
973977
_cache_dtypes: dict[BaseOffset, int] = {} # type: ignore[assignment]
974978
__hash__ = PeriodDtypeBase.__hash__
975979
_freq: BaseOffset
980+
_supports_2d = True
981+
_can_fast_transpose = True
976982

977983
def __new__(cls, freq) -> PeriodDtype: # noqa: PYI034
978984
"""
@@ -1435,6 +1441,8 @@ class NumpyEADtype(ExtensionDtype):
14351441
"""
14361442

14371443
_metadata = ("_dtype",)
1444+
_supports_2d = False
1445+
_can_fast_transpose = False
14381446

14391447
def __init__(self, dtype: npt.DTypeLike | NumpyEADtype | None) -> None:
14401448
if isinstance(dtype, NumpyEADtype):

pandas/tests/extension/base/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class ExtensionTests(
8585
BaseReduceTests,
8686
BaseReshapingTests,
8787
BaseSetitemTests,
88+
Dim2CompatTests,
8889
):
8990
pass
9091

pandas/tests/extension/base/dim2.py

+11
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,17 @@ class Dim2CompatTests:
2020
# Note: these are ONLY for ExtensionArray subclasses that support 2D arrays.
2121
# i.e. not for pyarrow-backed EAs.
2222

23+
@pytest.fixture(autouse=True)
def skip_if_doesnt_support_2d(self, dtype, request):
    """
    Skip the 2D-only tests when ``dtype`` does not support 2D arrays.

    Because the fixture is ``autouse``, it runs for every test of any class
    that inherits it. Only tests whose qualified name starts with
    ``Dim2CompatTests`` are skipped; everything else is left to run.

    Parameters
    ----------
    dtype : ExtensionDtype
        The ``dtype`` fixture of the test class; its ``_supports_2d``
        attribute decides whether anything is skipped.
    request : pytest.FixtureRequest
        Used to look up the test function that is about to run.
    """
    if dtype._supports_2d:
        return

    # In cases where we are mixed in to ExtensionTests, we only want to
    # skip tests that are defined in Dim2CompatTests.
    # ``request.function`` is the public accessor for the test function
    # of a function-scoped request, so no private node attribute is needed.
    test_func = request.function
    if test_func.__qualname__.startswith("Dim2CompatTests"):
        pytest.skip("Test is only for EAs that support 2D.")
33+
2334
def test_transpose(self, data):
2435
arr2d = data.repeat(2).reshape(-1, 2)
2536
shape = arr2d.shape

0 commit comments

Comments
 (0)