Skip to content

Commit c2fc924

Browse files
authored
PERF: avoid copies in lib.infer_dtype (#45057)
1 parent 7ac3361 commit c2fc924

File tree

8 files changed

+94
-57
lines changed

8 files changed

+94
-57
lines changed

pandas/_libs/lib.pyx

+53-38
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ from pandas._libs.missing cimport (
9797
is_matching_na,
9898
is_null_datetime64,
9999
is_null_timedelta64,
100-
isnaobj,
101100
)
102101
from pandas._libs.tslibs.conversion cimport convert_to_tsobject
103102
from pandas._libs.tslibs.nattype cimport (
@@ -1454,6 +1453,7 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
14541453
ndarray values
14551454
bint seen_pdnat = False
14561455
bint seen_val = False
1456+
flatiter it
14571457

14581458
if util.is_array(value):
14591459
values = value
@@ -1491,24 +1491,22 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
14911491
# This should not be reached
14921492
values = values.astype(object)
14931493

1494-
# for f-contiguous array 1000 x 1000, passing order="K" gives 5000x speedup
1495-
values = values.ravel(order="K")
1496-
1497-
if skipna:
1498-
values = values[~isnaobj(values)]
1499-
15001494
n = cnp.PyArray_SIZE(values)
15011495
if n == 0:
15021496
return "empty"
15031497

15041498
# Iterate until we find our first valid value. We will use this
15051499
# value to decide which of the is_foo_array functions to call.
1500+
it = PyArray_IterNew(values)
15061501
for i in range(n):
1507-
val = values[i]
1502+
# The PyArray_GETITEM and PyArray_ITER_NEXT are faster
1503+
# equivalents to `val = values[i]`
1504+
val = PyArray_GETITEM(values, PyArray_ITER_DATA(it))
1505+
PyArray_ITER_NEXT(it)
15081506

15091507
# do not use checknull to keep
15101508
# np.datetime64('nat') and np.timedelta64('nat')
1511-
if val is None or util.is_nan(val):
1509+
if val is None or util.is_nan(val) or val is C_NA:
15121510
pass
15131511
elif val is NaT:
15141512
seen_pdnat = True
@@ -1520,23 +1518,25 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
15201518
if seen_val is False and seen_pdnat is True:
15211519
return "datetime"
15221520
# float/object nan is handled in latter logic
1521+
if seen_val is False and skipna:
1522+
return "empty"
15231523

15241524
if util.is_datetime64_object(val):
1525-
if is_datetime64_array(values):
1525+
if is_datetime64_array(values, skipna=skipna):
15261526
return "datetime64"
15271527

15281528
elif is_timedelta(val):
1529-
if is_timedelta_or_timedelta64_array(values):
1529+
if is_timedelta_or_timedelta64_array(values, skipna=skipna):
15301530
return "timedelta"
15311531

15321532
elif util.is_integer_object(val):
15331533
# ordering matters here; this check must come after the is_timedelta
15341534
# check otherwise numpy timedelta64 objects would come through here
15351535

1536-
if is_integer_array(values):
1536+
if is_integer_array(values, skipna=skipna):
15371537
return "integer"
1538-
elif is_integer_float_array(values):
1539-
if is_integer_na_array(values):
1538+
elif is_integer_float_array(values, skipna=skipna):
1539+
if is_integer_na_array(values, skipna=skipna):
15401540
return "integer-na"
15411541
else:
15421542
return "mixed-integer-float"
@@ -1557,7 +1557,7 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
15571557
return "time"
15581558

15591559
elif is_decimal(val):
1560-
if is_decimal_array(values):
1560+
if is_decimal_array(values, skipna=skipna):
15611561
return "decimal"
15621562

15631563
elif util.is_complex_object(val):
@@ -1567,8 +1567,8 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
15671567
elif util.is_float_object(val):
15681568
if is_float_array(values):
15691569
return "floating"
1570-
elif is_integer_float_array(values):
1571-
if is_integer_na_array(values):
1570+
elif is_integer_float_array(values, skipna=skipna):
1571+
if is_integer_na_array(values, skipna=skipna):
15721572
return "integer-na"
15731573
else:
15741574
return "mixed-integer-float"
@@ -1586,15 +1586,18 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
15861586
return "bytes"
15871587

15881588
elif is_period_object(val):
1589-
if is_period_array(values):
1589+
if is_period_array(values, skipna=skipna):
15901590
return "period"
15911591

15921592
elif is_interval(val):
15931593
if is_interval_array(values):
15941594
return "interval"
15951595

1596+
cnp.PyArray_ITER_RESET(it)
15961597
for i in range(n):
1597-
val = values[i]
1598+
val = PyArray_GETITEM(values, PyArray_ITER_DATA(it))
1599+
PyArray_ITER_NEXT(it)
1600+
15981601
if util.is_integer_object(val):
15991602
return "mixed-integer"
16001603

@@ -1823,10 +1826,11 @@ cdef class IntegerValidator(Validator):
18231826

18241827

18251828
# Note: only python-exposed for tests
1826-
cpdef bint is_integer_array(ndarray values):
1829+
cpdef bint is_integer_array(ndarray values, bint skipna=True):
18271830
cdef:
18281831
IntegerValidator validator = IntegerValidator(len(values),
1829-
values.dtype)
1832+
values.dtype,
1833+
skipna=skipna)
18301834
return validator.validate(values)
18311835

18321836

@@ -1837,10 +1841,10 @@ cdef class IntegerNaValidator(Validator):
18371841
or (util.is_nan(value) and util.is_float_object(value)))
18381842

18391843

1840-
cdef bint is_integer_na_array(ndarray values):
1844+
cdef bint is_integer_na_array(ndarray values, bint skipna=True):
18411845
cdef:
18421846
IntegerNaValidator validator = IntegerNaValidator(len(values),
1843-
values.dtype)
1847+
values.dtype, skipna=skipna)
18441848
return validator.validate(values)
18451849

18461850

@@ -1853,10 +1857,11 @@ cdef class IntegerFloatValidator(Validator):
18531857
return issubclass(self.dtype.type, np.integer)
18541858

18551859

1856-
cdef bint is_integer_float_array(ndarray values):
1860+
cdef bint is_integer_float_array(ndarray values, bint skipna=True):
18571861
cdef:
18581862
IntegerFloatValidator validator = IntegerFloatValidator(len(values),
1859-
values.dtype)
1863+
values.dtype,
1864+
skipna=skipna)
18601865
return validator.validate(values)
18611866

18621867

@@ -1900,9 +1905,11 @@ cdef class DecimalValidator(Validator):
19001905
return is_decimal(value)
19011906

19021907

1903-
cdef bint is_decimal_array(ndarray values):
1908+
cdef bint is_decimal_array(ndarray values, bint skipna=False):
19041909
cdef:
1905-
DecimalValidator validator = DecimalValidator(len(values), values.dtype)
1910+
DecimalValidator validator = DecimalValidator(
1911+
len(values), values.dtype, skipna=skipna
1912+
)
19061913
return validator.validate(values)
19071914

19081915

@@ -1997,10 +2004,10 @@ cdef class Datetime64Validator(DatetimeValidator):
19972004

19982005

19992006
# Note: only python-exposed for tests
2000-
cpdef bint is_datetime64_array(ndarray values):
2007+
cpdef bint is_datetime64_array(ndarray values, bint skipna=True):
20012008
cdef:
20022009
Datetime64Validator validator = Datetime64Validator(len(values),
2003-
skipna=True)
2010+
skipna=skipna)
20042011
return validator.validate(values)
20052012

20062013

@@ -2012,10 +2019,10 @@ cdef class AnyDatetimeValidator(DatetimeValidator):
20122019
)
20132020

20142021

2015-
cdef bint is_datetime_or_datetime64_array(ndarray values):
2022+
cdef bint is_datetime_or_datetime64_array(ndarray values, bint skipna=True):
20162023
cdef:
20172024
AnyDatetimeValidator validator = AnyDatetimeValidator(len(values),
2018-
skipna=True)
2025+
skipna=skipna)
20192026
return validator.validate(values)
20202027

20212028

@@ -2069,13 +2076,13 @@ cdef class AnyTimedeltaValidator(TimedeltaValidator):
20692076

20702077

20712078
# Note: only python-exposed for tests
2072-
cpdef bint is_timedelta_or_timedelta64_array(ndarray values):
2079+
cpdef bint is_timedelta_or_timedelta64_array(ndarray values, bint skipna=True):
20732080
"""
20742081
Infer with timedeltas and/or nat/none.
20752082
"""
20762083
cdef:
20772084
AnyTimedeltaValidator validator = AnyTimedeltaValidator(len(values),
2078-
skipna=True)
2085+
skipna=skipna)
20792086
return validator.validate(values)
20802087

20812088

@@ -2105,20 +2112,28 @@ cpdef bint is_time_array(ndarray values, bint skipna=False):
21052112
return validator.validate(values)
21062113

21072114

2108-
cdef bint is_period_array(ndarray[object] values):
2115+
# FIXME: actually use skipna
2116+
cdef bint is_period_array(ndarray values, bint skipna=True):
21092117
"""
21102118
Is this an ndarray of Period objects (or NaT) with a single `freq`?
21112119
"""
2120+
# values should be object-dtype, but ndarray[object] assumes 1D, while
2121+
# this _may_ be 2D.
21122122
cdef:
2113-
Py_ssize_t i, n = len(values)
2123+
Py_ssize_t i, N = values.size
21142124
int dtype_code = -10000 # i.e. c_FreqGroup.FR_UND
21152125
object val
2126+
flatiter it
21162127

2117-
if len(values) == 0:
2128+
if N == 0:
21182129
return False
21192130

2120-
for i in range(n):
2121-
val = values[i]
2131+
it = PyArray_IterNew(values)
2132+
for i in range(N):
2133+
# The PyArray_GETITEM and PyArray_ITER_NEXT are faster
2134+
# equivalents to `val = values[i]`
2135+
val = PyArray_GETITEM(values, PyArray_ITER_DATA(it))
2136+
PyArray_ITER_NEXT(it)
21222137

21232138
if is_period_object(val):
21242139
if dtype_code == -10000:

pandas/conftest.py

+1
Original file line numberDiff line numberDiff line change
@@ -1549,6 +1549,7 @@ def any_numpy_dtype(request):
15491549
_any_skipna_inferred_dtype = [
15501550
("string", ["a", np.nan, "c"]),
15511551
("string", ["a", pd.NA, "c"]),
1552+
("mixed", ["a", pd.NaT, "c"]), # pd.NaT not considered valid by is_string_array
15521553
("bytes", [b"a", np.nan, b"c"]),
15531554
("empty", [np.nan, np.nan, np.nan]),
15541555
("empty", []),

pandas/core/arrays/floating.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,7 @@ def coerce_to_array(
116116
inferred_type = lib.infer_dtype(values, skipna=True)
117117
if inferred_type == "empty":
118118
pass
119-
elif inferred_type not in [
120-
"floating",
121-
"integer",
122-
"mixed-integer",
123-
"integer-na",
124-
"mixed-integer-float",
125-
]:
119+
elif inferred_type == "boolean":
126120
raise TypeError(f"{values.dtype} cannot be converted to a FloatingDtype")
127121

128122
elif is_bool_dtype(values) and is_float_dtype(dtype):

pandas/core/arrays/integer.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -166,16 +166,8 @@ def coerce_to_array(
166166
inferred_type = lib.infer_dtype(values, skipna=True)
167167
if inferred_type == "empty":
168168
pass
169-
elif inferred_type not in [
170-
"floating",
171-
"integer",
172-
"mixed-integer",
173-
"integer-na",
174-
"mixed-integer-float",
175-
"string",
176-
"unicode",
177-
]:
178-
raise TypeError(f"{values.dtype} cannot be converted to an IntegerDtype")
169+
elif inferred_type == "boolean":
170+
raise TypeError(f"{values.dtype} cannot be converted to a FloatingDtype")
179171

180172
elif is_bool_dtype(values) and is_integer_dtype(dtype):
181173
values = np.array(values, dtype=int, copy=copy)

pandas/tests/arrays/floating/test_construction.py

+1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def test_to_array_error(values):
131131
"cannot be converted to a FloatingDtype",
132132
"values must be a 1D list-like",
133133
"Cannot pass scalar",
134+
r"float\(\) argument must be a string or a (real )?number, not 'dict'",
134135
]
135136
)
136137
with pytest.raises((TypeError, ValueError), match=msg):

pandas/tests/arrays/integer/test_construction.py

+1
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def test_to_integer_array_error(values):
139139
r"invalid literal for int\(\) with base 10:",
140140
r"values must be a 1D list-like",
141141
r"Cannot pass scalar",
142+
r"int\(\) argument must be a string",
142143
]
143144
)
144145
with pytest.raises((ValueError, TypeError), match=msg):

pandas/tests/arrays/string_/test_string.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -266,12 +266,26 @@ def test_constructor_raises(cls):
266266
with pytest.raises(ValueError, match=msg):
267267
cls(np.array([]))
268268

269-
with pytest.raises(ValueError, match=msg):
270-
cls(np.array(["a", np.datetime64("nat")], dtype=object))
269+
if cls is pd.arrays.StringArray:
270+
# GH#45057 np.nan and None do NOT raise, as they are considered valid NAs
271+
# for string dtype
272+
cls(np.array(["a", np.nan], dtype=object))
273+
cls(np.array(["a", None], dtype=object))
274+
else:
275+
with pytest.raises(ValueError, match=msg):
276+
cls(np.array(["a", np.nan], dtype=object))
277+
with pytest.raises(ValueError, match=msg):
278+
cls(np.array(["a", None], dtype=object))
271279

272280
with pytest.raises(ValueError, match=msg):
273281
cls(np.array(["a", pd.NaT], dtype=object))
274282

283+
with pytest.raises(ValueError, match=msg):
284+
cls(np.array(["a", np.datetime64("NaT", "ns")], dtype=object))
285+
286+
with pytest.raises(ValueError, match=msg):
287+
cls(np.array(["a", np.timedelta64("NaT", "ns")], dtype=object))
288+
275289

276290
@pytest.mark.parametrize("na", [np.nan, np.float64("nan"), float("nan"), None, pd.NA])
277291
def test_constructor_nan_like(na):

pandas/tests/dtypes/test_inference.py

+19
Original file line numberDiff line numberDiff line change
@@ -1134,10 +1134,20 @@ def test_unicode(self):
11341134
# This could also return "string" or "mixed-string"
11351135
assert result == "mixed"
11361136

1137+
# even though we use skipna, we are only skipping those NAs that are
1138+
# considered matching by is_string_array
11371139
arr = ["a", np.nan, "c"]
11381140
result = lib.infer_dtype(arr, skipna=True)
11391141
assert result == "string"
11401142

1143+
arr = ["a", pd.NA, "c"]
1144+
result = lib.infer_dtype(arr, skipna=True)
1145+
assert result == "string"
1146+
1147+
arr = ["a", pd.NaT, "c"]
1148+
result = lib.infer_dtype(arr, skipna=True)
1149+
assert result == "mixed"
1150+
11411151
arr = ["a", "c"]
11421152
result = lib.infer_dtype(arr, skipna=False)
11431153
assert result == "string"
@@ -1544,15 +1554,24 @@ def test_is_string_array(self):
15441554
assert lib.is_string_array(
15451555
np.array(["foo", "bar", pd.NA], dtype=object), skipna=True
15461556
)
1557+
# we allow NaN/None in the StringArray constructor, so its allowed here
15471558
assert lib.is_string_array(
15481559
np.array(["foo", "bar", None], dtype=object), skipna=True
15491560
)
15501561
assert lib.is_string_array(
15511562
np.array(["foo", "bar", np.nan], dtype=object), skipna=True
15521563
)
1564+
# But not e.g. datetimelike or Decimal NAs
15531565
assert not lib.is_string_array(
15541566
np.array(["foo", "bar", pd.NaT], dtype=object), skipna=True
15551567
)
1568+
assert not lib.is_string_array(
1569+
np.array(["foo", "bar", np.datetime64("NaT")], dtype=object), skipna=True
1570+
)
1571+
assert not lib.is_string_array(
1572+
np.array(["foo", "bar", Decimal("NaN")], dtype=object), skipna=True
1573+
)
1574+
15561575
assert not lib.is_string_array(
15571576
np.array(["foo", "bar", None], dtype=object), skipna=False
15581577
)

0 commit comments

Comments
 (0)