Skip to content

Commit 0a3994b

Browse files
committed
REF/API: Stricter extension checking.
Removes is_extension_array_dtype's handling of both arrays and dtypes. Now it handles just arrays, and we provide `is_extension_dtype` for checking whether a dtype is an extension dtype. It's the caller's responsibility to know whether they have an array or dtype. Closes pandas-dev#22021
1 parent b975455 commit 0a3994b

File tree

10 files changed

+80
-29
lines changed

10 files changed

+80
-29
lines changed

doc/source/api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2501,6 +2501,8 @@ Dtype introspection
25012501
api.types.is_datetime64_ns_dtype
25022502
api.types.is_datetime64tz_dtype
25032503
api.types.is_extension_type
2504+
api.types.is_extension_array_dtype
2505+
api.types.is_extension_dtype
25042506
api.types.is_float_dtype
25052507
api.types.is_int64_dtype
25062508
api.types.is_integer_dtype

doc/source/whatsnew/v0.24.0.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,8 @@ ExtensionType Changes
323323
- Bug in :meth:`Series.get` for ``Series`` using ``ExtensionArray`` and integer index (:issue:`21257`)
324324
- :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`)
325325
- :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`)
326-
-
326+
- Added :func:`pandas.api.types.is_extension_array_dtype` for testing whether an array is an ExtensionArray and :func:`pandas.api.types.is_extension_dtype` for testing whether a dtype is an ExtensionDtype (:issue:`22021`)
327+
327328

328329
.. _whatsnew_0240.api.incompatibilities:
329330

pandas/core/algorithms.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
is_integer_dtype, is_complex_dtype,
2020
is_object_dtype,
2121
is_extension_array_dtype,
22+
is_extension_dtype,
2223
is_categorical_dtype, is_sparse,
2324
is_period_dtype,
2425
is_numeric_dtype, is_float_dtype,
@@ -153,7 +154,7 @@ def _reconstruct_data(values, dtype, original):
153154
Index for extension types, otherwise ndarray casted to dtype
154155
"""
155156
from pandas import Index
156-
if is_extension_array_dtype(dtype):
157+
if is_extension_dtype(dtype):
157158
values = dtype.construct_array_type()._from_sequence(values)
158159
elif is_datetime64tz_dtype(dtype) or is_period_dtype(dtype):
159160
values = Index(original)._shallow_copy(values, name=None)

pandas/core/dtypes/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from .common import (pandas_dtype,
66
is_dtype_equal,
7+
is_extension_dtype,
8+
is_extension_array_dtype,
79
is_extension_type,
810

911
# categorical

pandas/core/dtypes/cast.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
is_complex, is_datetimetz, is_categorical_dtype,
1313
is_datetimelike,
1414
is_extension_type,
15-
is_extension_array_dtype,
15+
is_extension_dtype,
1616
is_object_dtype,
1717
is_datetime64tz_dtype, is_datetime64_dtype,
1818
is_datetime64_ns_dtype,
@@ -294,7 +294,7 @@ def maybe_promote(dtype, fill_value=np.nan):
294294
elif is_datetimetz(dtype):
295295
if isna(fill_value):
296296
fill_value = iNaT
297-
elif is_extension_array_dtype(dtype) and isna(fill_value):
297+
elif is_extension_dtype(dtype) and isna(fill_value):
298298
fill_value = dtype.na_value
299299
elif is_float(fill_value):
300300
if issubclass(dtype.type, np.bool_):
@@ -332,7 +332,7 @@ def maybe_promote(dtype, fill_value=np.nan):
332332
dtype = np.object_
333333

334334
# in case we have a string that looked like a number
335-
if is_extension_array_dtype(dtype):
335+
if is_extension_dtype(dtype):
336336
pass
337337
elif is_datetimetz(dtype):
338338
pass
@@ -650,7 +650,7 @@ def astype_nansafe(arr, dtype, copy=True):
650650
need to be very careful as the result shape could change! """
651651

652652
# dispatch on extension dtype if needed
653-
if is_extension_array_dtype(dtype):
653+
if is_extension_dtype(dtype):
654654
return dtype.construct_array_type()._from_sequence(
655655
arr, dtype=dtype, copy=copy)
656656

pandas/core/dtypes/common.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,38 +1688,69 @@ def is_extension_type(arr):
16881688
return False
16891689

16901690

1691-
def is_extension_array_dtype(arr_or_dtype):
1692-
"""Check if an object is a pandas extension array type.
1691+
def is_extension_array_dtype(arr):
1692+
"""Check if an array object is a pandas extension array type.
16931693
16941694
Parameters
16951695
----------
1696-
arr_or_dtype : object
1696+
arr : object
16971697
16981698
Returns
16991699
-------
17001700
bool
17011701
17021702
Notes
17031703
-----
1704-
This checks whether an object implements the pandas extension
1704+
This checks whether an array object implements the pandas extension
17051705
array interface. In pandas, this includes:
17061706
17071707
* Categorical
1708+
* Interval
17081709
1709-
Third-party libraries may implement arrays or types satisfying
1710+
Third-party libraries may implement arrays satisfying
17101711
this interface as well.
1711-
"""
1712-
from pandas.core.arrays import ExtensionArray
17131712
1714-
if isinstance(arr_or_dtype, (ABCIndexClass, ABCSeries)):
1715-
arr_or_dtype = arr_or_dtype._values
1713+
See Also
1714+
--------
1715+
is_extension_dtype : Similar method for dtypes.
1716+
"""
1717+
from pandas.core.dtypes.base import ExtensionDtype
17161718

17171719
try:
1718-
arr_or_dtype = pandas_dtype(arr_or_dtype)
1719-
except TypeError:
1720-
pass
1720+
dtype = getattr(arr, 'dtype')
1721+
except AttributeError:
1722+
return False
1723+
1724+
return isinstance(dtype, ExtensionDtype)
17211725

1722-
return isinstance(arr_or_dtype, (ExtensionDtype, ExtensionArray))
1726+
1727+
def is_extension_dtype(dtype):
1728+
"""Check if a dtype object is a pandas extension dtype.
1729+
1730+
Parameters
1731+
----------
1732+
dtype : object
1733+
1734+
Returns
1735+
-------
1736+
bool
1737+
1738+
Notes
1739+
-----
1740+
This checks whether a dtype object implements the pandas extension
1741+
dtype interface. In pandas, this includes:
1742+
1743+
* CategoricalDtype
1744+
* IntervalDtype
1745+
1746+
Third-party libraries may implement dtypes satisfying
1747+
this interface as well.
1748+
1749+
See Also
1750+
--------
1751+
is_extension_array_dtype : Similar method for arrays.
1752+
"""
1753+
return isinstance(dtype, ExtensionDtype)
17231754

17241755

17251756
def is_complex_dtype(arr_or_dtype):

pandas/core/indexes/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
is_datetime64tz_dtype,
4646
is_timedelta64_dtype,
4747
is_extension_array_dtype,
48+
is_extension_dtype,
4849
is_hashable,
4950
is_iterator, is_list_like,
5051
is_scalar)
@@ -275,7 +276,7 @@ def __new__(cls, data=None, dtype=None, copy=False, name=None,
275276
closed=closed)
276277

277278
# extension dtype
278-
elif is_extension_array_dtype(data) or is_extension_array_dtype(dtype):
279+
elif is_extension_array_dtype(data) or is_extension_dtype(dtype):
279280
data = np.asarray(data)
280281
if not (dtype is None or is_object_dtype(dtype)):
281282

@@ -1191,7 +1192,7 @@ def astype(self, dtype, copy=True):
11911192
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
11921193
copy=copy)
11931194

1194-
elif is_extension_array_dtype(dtype):
1195+
elif is_extension_dtype(dtype):
11951196
return Index(np.asarray(self), dtype=dtype, copy=copy)
11961197

11971198
try:

pandas/core/series.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
is_float_dtype,
2323
is_extension_type,
2424
is_extension_array_dtype,
25+
is_extension_dtype,
2526
is_datetimelike,
2627
is_datetime64tz_dtype,
2728
is_timedelta64_dtype,
@@ -4088,7 +4089,7 @@ def _try_cast(arr, take_fast_path):
40884089
# that Categorical is the only array type for 'category'.
40894090
subarr = Categorical(arr, dtype.categories,
40904091
ordered=dtype.ordered)
4091-
elif is_extension_array_dtype(dtype):
4092+
elif is_extension_dtype(dtype):
40924093
# create an extension array from its dtype
40934094
array_type = dtype.construct_array_type()
40944095
subarr = array_type(subarr, dtype=dtype, copy=copy)

pandas/tests/extension/base/interface.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
import pandas as pd
44
from pandas.compat import StringIO
5-
from pandas.core.dtypes.common import is_extension_array_dtype
6-
from pandas.core.dtypes.dtypes import ExtensionDtype
5+
from pandas.api.types import (
6+
is_extension_array_dtype, is_extension_dtype
7+
)
8+
from pandas.api.extensions import ExtensionDtype
79

810
from .base import BaseExtensionTests
911

@@ -58,10 +60,12 @@ def test_dtype_name_in_info(self, data):
5860

5961
def test_is_extension_array_dtype(self, data):
6062
assert is_extension_array_dtype(data)
61-
assert is_extension_array_dtype(data.dtype)
6263
assert is_extension_array_dtype(pd.Series(data))
6364
assert isinstance(data.dtype, ExtensionDtype)
6465

66+
def test_is_extension_dtype(self, data):
67+
assert is_extension_dtype(data.dtype)
68+
6569
def test_no_values_attribute(self, data):
6670
# GH-20735: EA's with .values attribute give problems with internal
6771
# code, disallowing this for now until solved

pandas/tests/extension/test_common.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import pandas as pd
55
import pandas.util.testing as tm
66
from pandas.core.arrays import ExtensionArray
7-
from pandas.core.dtypes.common import is_extension_array_dtype
7+
from pandas.core.dtypes.common import (
8+
is_extension_array_dtype, is_extension_dtype
9+
)
810
from pandas.core.dtypes import dtypes
911

1012

@@ -38,14 +40,20 @@ class TestExtensionArrayDtype(object):
3840

3941
@pytest.mark.parametrize('values', [
4042
pd.Categorical([]),
41-
pd.Categorical([]).dtype,
4243
pd.Series(pd.Categorical([])),
43-
DummyDtype(),
44+
4445
DummyArray(np.array([1, 2])),
4546
])
4647
def test_is_extension_array_dtype(self, values):
4748
assert is_extension_array_dtype(values)
4849

50+
@pytest.mark.parametrize('dtype', [
51+
pd.Categorical([]).dtype,
52+
DummyDtype(),
53+
])
54+
def test_is_extension_dtype(self, dtype):
55+
assert is_extension_dtype(dtype)
56+
4957
@pytest.mark.parametrize('values', [
5058
np.array([]),
5159
pd.Series(np.array([])),
@@ -91,4 +99,4 @@ def test_is_not_extension_array_dtype(dtype):
9199
])
92100
def test_is_extension_array_dtype(dtype):
93101
assert isinstance(dtype, dtypes.ExtensionDtype)
94-
assert is_extension_array_dtype(dtype)
102+
assert is_extension_dtype(dtype)

0 commit comments

Comments
 (0)