Skip to content

Commit dea0195

Browse files
authored
REF: ExtensionIndex.searchsorted -> IndexOpsMixin.searchsorted (#43653)
1 parent 4b54c53 commit dea0195

File tree

4 files changed

+48
-52
lines changed

4 files changed

+48
-52
lines changed

pandas/core/algorithms.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
)
9797
from pandas.core.arrays import (
9898
DatetimeArray,
99+
ExtensionArray,
99100
TimedeltaArray,
100101
)
101102

@@ -1535,7 +1536,7 @@ def take(
15351536

15361537
def searchsorted(
15371538
arr: ArrayLike,
1538-
value: NumpyValueArrayLike,
1539+
value: NumpyValueArrayLike | ExtensionArray,
15391540
side: Literal["left", "right"] = "left",
15401541
sorter: NumpySorter = None,
15411542
) -> npt.NDArray[np.intp] | np.intp:
@@ -1616,7 +1617,9 @@ def searchsorted(
16161617
# and `value` is a pd.Timestamp, we may need to convert value
16171618
arr = ensure_wrapped_if_datetimelike(arr)
16181619

1619-
return arr.searchsorted(value, side=side, sorter=sorter)
1620+
# Argument 1 to "searchsorted" of "ndarray" has incompatible type
1621+
# "Union[NumpyValueArrayLike, ExtensionArray]"; expected "NumpyValueArrayLike"
1622+
return arr.searchsorted(value, side=side, sorter=sorter) # type: ignore[arg-type]
16201623

16211624

16221625
# ---- #

pandas/core/base.py

+39-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
TypeVar,
1515
cast,
1616
final,
17+
overload,
1718
)
1819

1920
import numpy as np
@@ -1230,14 +1231,50 @@ def factorize(self, sort: bool = False, na_sentinel: int | None = -1):
12301231
0 # wrong result, correct would be 1
12311232
"""
12321233

1234+
# This overload is needed so that the call to searchsorted in
1235+
# pandas.core.resample.TimeGrouper._get_period_bins picks the correct result
1236+
1237+
@overload
1238+
# The following ignore is also present in numpy/__init__.pyi
1239+
# Possibly a mypy bug??
1240+
# error: Overloaded function signatures 1 and 2 overlap with incompatible
1241+
# return types [misc]
1242+
def searchsorted( # type: ignore[misc]
1243+
self,
1244+
value: npt._ScalarLike_co,
1245+
side: Literal["left", "right"] = "left",
1246+
sorter: NumpySorter = None,
1247+
) -> np.intp:
1248+
...
1249+
1250+
@overload
1251+
def searchsorted(
1252+
self,
1253+
value: npt.ArrayLike | ExtensionArray,
1254+
side: Literal["left", "right"] = "left",
1255+
sorter: NumpySorter = None,
1256+
) -> npt.NDArray[np.intp]:
1257+
...
1258+
12331259
@doc(_shared_docs["searchsorted"], klass="Index")
12341260
def searchsorted(
12351261
self,
1236-
value: NumpyValueArrayLike,
1262+
value: NumpyValueArrayLike | ExtensionArray,
12371263
side: Literal["left", "right"] = "left",
12381264
sorter: NumpySorter = None,
12391265
) -> npt.NDArray[np.intp] | np.intp:
1240-
return algorithms.searchsorted(self._values, value, side=side, sorter=sorter)
1266+
1267+
values = self._values
1268+
if not isinstance(values, np.ndarray):
1269+
# Going through EA.searchsorted directly improves performance GH#38083
1270+
return values.searchsorted(value, side=side, sorter=sorter)
1271+
1272+
return algorithms.searchsorted(
1273+
values,
1274+
value,
1275+
side=side,
1276+
sorter=sorter,
1277+
)
12411278

12421279
def drop_duplicates(self, keep="first"):
12431280
duplicated = self._duplicated(keep=keep)

pandas/core/indexes/extension.py

-45
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,8 @@
44
from __future__ import annotations
55

66
from typing import (
7-
TYPE_CHECKING,
87
Hashable,
9-
Literal,
108
TypeVar,
11-
overload,
129
)
1310

1411
import numpy as np
@@ -38,17 +35,9 @@
3835
TimedeltaArray,
3936
)
4037
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
41-
from pandas.core.arrays.base import ExtensionArray
4238
from pandas.core.indexers import deprecate_ndim_indexing
4339
from pandas.core.indexes.base import Index
4440

45-
if TYPE_CHECKING:
46-
47-
from pandas._typing import (
48-
NumpySorter,
49-
NumpyValueArrayLike,
50-
)
51-
5241
_T = TypeVar("_T", bound="NDArrayBackedExtensionIndex")
5342

5443

@@ -207,40 +196,6 @@ def __getitem__(self, key):
207196
deprecate_ndim_indexing(result)
208197
return result
209198

210-
# This overload is needed so that the call to searchsorted in
211-
# pandas.core.resample.TimeGrouper._get_period_bins picks the correct result
212-
213-
@overload
214-
# The following ignore is also present in numpy/__init__.pyi
215-
# Possibly a mypy bug??
216-
# error: Overloaded function signatures 1 and 2 overlap with incompatible
217-
# return types [misc]
218-
def searchsorted( # type: ignore[misc]
219-
self,
220-
value: npt._ScalarLike_co,
221-
side: Literal["left", "right"] = "left",
222-
sorter: NumpySorter = None,
223-
) -> np.intp:
224-
...
225-
226-
@overload
227-
def searchsorted(
228-
self,
229-
value: npt.ArrayLike | ExtensionArray,
230-
side: Literal["left", "right"] = "left",
231-
sorter: NumpySorter = None,
232-
) -> npt.NDArray[np.intp]:
233-
...
234-
235-
def searchsorted(
236-
self,
237-
value: NumpyValueArrayLike | ExtensionArray,
238-
side: Literal["left", "right"] = "left",
239-
sorter: NumpySorter = None,
240-
) -> npt.NDArray[np.intp] | np.intp:
241-
# overriding IndexOpsMixin improves performance GH#38083
242-
return self._data.searchsorted(value, side=side, sorter=sorter)
243-
244199
# ---------------------------------------------------------------------
245200

246201
def delete(self, loc):

pandas/core/series.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -2790,13 +2790,14 @@ def __rmatmul__(self, other):
27902790
return self.dot(np.transpose(other))
27912791

27922792
@doc(base.IndexOpsMixin.searchsorted, klass="Series")
2793-
def searchsorted(
2793+
# Signature of "searchsorted" incompatible with supertype "IndexOpsMixin"
2794+
def searchsorted( # type: ignore[override]
27942795
self,
2795-
value: NumpyValueArrayLike,
2796+
value: NumpyValueArrayLike | ExtensionArray,
27962797
side: Literal["left", "right"] = "left",
27972798
sorter: NumpySorter = None,
27982799
) -> npt.NDArray[np.intp] | np.intp:
2799-
return algorithms.searchsorted(self._values, value, side=side, sorter=sorter)
2800+
return base.IndexOpsMixin.searchsorted(self, value, side=side, sorter=sorter)
28002801

28012802
# -------------------------------------------------------------------
28022803
# Combination

0 commit comments

Comments
 (0)