From 47423b1fbece664828627caf878752357caa5169 Mon Sep 17 00:00:00 2001 From: Brock Date: Sat, 17 Apr 2021 22:08:31 -0700 Subject: [PATCH 1/2] REF: do casting closer to libjoin.foo_indexer --- pandas/core/indexes/base.py | 93 +++++++++++++++--------- pandas/core/indexes/datetimelike.py | 57 ++++----------- pandas/core/indexes/extension.py | 27 ++++++- pandas/core/indexes/multi.py | 6 +- pandas/tests/indexes/period/test_join.py | 2 +- 5 files changed, 101 insertions(+), 84 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 310ee4c3a63e3..c140d7dc8ba0c 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -302,23 +302,47 @@ class Index(IndexOpsMixin, PandasObject): # for why we need to wrap these instead of making them class attributes # Moreover, cython will choose the appropriate-dtyped sub-function # given the dtypes of the passed arguments - def _left_indexer_unique(self, left: np.ndarray, right: np.ndarray) -> np.ndarray: - return libjoin.left_join_indexer_unique(left, right) + @final + def _left_indexer_unique(self: _IndexT, other: _IndexT) -> np.ndarray: + # -> np.ndarray[np.intp] + # Caller is responsible for ensuring other.dtype == self.dtype + sv = self._get_join_target() + ov = other._get_join_target() + return libjoin.left_join_indexer_unique(sv, ov) + + @final def _left_indexer( - self, left: np.ndarray, right: np.ndarray - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - return libjoin.left_join_indexer(left, right) + self: _IndexT, other: _IndexT + ) -> tuple[ArrayLike, np.ndarray, np.ndarray]: + # Caller is responsible for ensuring other.dtype == self.dtype + sv = self._get_join_target() + ov = other._get_join_target() + joined_ndarray, lidx, ridx = libjoin.left_join_indexer(sv, ov) + joined = self._from_join_target(joined_ndarray) + return joined, lidx, ridx + @final def _inner_indexer( - self, left: np.ndarray, right: np.ndarray - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - return libjoin.inner_join_indexer(left, right) + self: _IndexT, other: _IndexT + ) -> tuple[ArrayLike, np.ndarray, np.ndarray]: + # Caller is responsible for ensuring other.dtype == self.dtype + sv = self._get_join_target() + ov = other._get_join_target() + joined_ndarray, lidx, ridx = libjoin.inner_join_indexer(sv, ov) + joined = self._from_join_target(joined_ndarray) + return joined, lidx, ridx + @final def _outer_indexer( - self, left: np.ndarray, right: np.ndarray - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - return libjoin.outer_join_indexer(left, right) + self: _IndexT, other: _IndexT + ) -> tuple[ArrayLike, np.ndarray, np.ndarray]: + # Caller is responsible for ensuring other.dtype == self.dtype + sv = self._get_join_target() + ov = other._get_join_target() + joined_ndarray, lidx, ridx = libjoin.outer_join_indexer(sv, ov) + joined = self._from_join_target(joined_ndarray) + return joined, lidx, ridx _typ = "index" _data: ExtensionArray | np.ndarray @@ -2965,11 +2989,7 @@ def _union(self, other: Index, sort): ): # Both are unique and monotonic, so can use outer join try: - # error: Argument 1 to "_outer_indexer" of "Index" has incompatible type - # "Union[ExtensionArray, ndarray]"; expected "ndarray" - # error: Argument 2 to "_outer_indexer" of "Index" has incompatible type - # "Union[ExtensionArray, ndarray]"; expected "ndarray" - return self._outer_indexer(lvals, rvals)[0] # type: ignore[arg-type] + return self._outer_indexer(other)[0] except (TypeError, IncompatibleFrequency): # incomparable objects value_list = list(lvals) @@ -3090,13 +3110,10 @@ def _intersection(self, other: Index, sort=False): """ # TODO(EA): setops-refactor, clean all this up lvals = self._values - rvals = other._values if self.is_monotonic and other.is_monotonic: try: - # error: Argument 1 to "_inner_indexer" of "Index" has incompatible type - # "Union[ExtensionArray, ndarray]"; expected "ndarray" - result = self._inner_indexer(lvals, rvals)[0] # type: ignore[arg-type] + result = self._inner_indexer(other)[0] except TypeError: pass else: @@ -4267,9 +4284,6 @@ def _join_monotonic(self, other: Index, how="left"): ret_index = other if how == "right" else self return ret_index, None, None - sv = self._get_engine_target() - ov = other._get_engine_target() - ridx: np.ndarray | None lidx: np.ndarray | None @@ -4278,26 +4292,26 @@ def _join_monotonic(self, other: Index, how="left"): if how == "left": join_index = self lidx = None - ridx = self._left_indexer_unique(sv, ov) + ridx = self._left_indexer_unique(other) elif how == "right": join_index = other - lidx = self._left_indexer_unique(ov, sv) + lidx = other._left_indexer_unique(self) ridx = None elif how == "inner": - join_array, lidx, ridx = self._inner_indexer(sv, ov) + join_array, lidx, ridx = self._inner_indexer(other) join_index = self._wrap_joined_index(join_array, other) elif how == "outer": - join_array, lidx, ridx = self._outer_indexer(sv, ov) + join_array, lidx, ridx = self._outer_indexer(other) join_index = self._wrap_joined_index(join_array, other) else: if how == "left": - join_array, lidx, ridx = self._left_indexer(sv, ov) + join_array, lidx, ridx = self._left_indexer(other) elif how == "right": - join_array, ridx, lidx = self._left_indexer(ov, sv) + join_array, ridx, lidx = other._left_indexer(self) elif how == "inner": - join_array, lidx, ridx = self._inner_indexer(sv, ov) + join_array, lidx, ridx = self._inner_indexer(other) elif how == "outer": - join_array, lidx, ridx = self._outer_indexer(sv, ov) + join_array, lidx, ridx = self._outer_indexer(other) join_index = self._wrap_joined_index(join_array, other) @@ -4305,9 +4319,7 @@ def _join_monotonic(self, other: Index, how="left"): ridx = None if ridx is None else ensure_platform_int(ridx) return join_index, lidx, ridx - def _wrap_joined_index( - self: _IndexT, joined: np.ndarray, other: _IndexT - ) -> _IndexT: + def _wrap_joined_index(self: _IndexT, joined: ArrayLike, other: _IndexT) -> _IndexT: assert other.dtype == self.dtype if isinstance(self, ABCMultiIndex): @@ -4385,6 +4397,19 @@ def _get_engine_target(self) -> np.ndarray: # ndarray]", expected "ndarray") return self._values # type: ignore[return-value] + def _get_join_target(self) -> np.ndarray: + """ + Get the ndarray that we will pass to libjoin functions. + """ + return self._get_engine_target() + + def _from_join_target(self, result: np.ndarray) -> ArrayLike: + """ + Cast the ndarray returned from one of the libjoin.foo_indexer functions + back to type(self)._data. + """ + return result + @doc(IndexOpsMixin._memory_usage) def memory_usage(self, deep: bool = False) -> int: result = self._memory_usage(deep=deep) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 7bc0655ea9529..adf80e92e9cba 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -20,7 +20,6 @@ NaT, Timedelta, iNaT, - join as libjoin, lib, ) from pandas._libs.tslibs import ( @@ -29,7 +28,10 @@ Resolution, Tick, ) -from pandas._typing import Callable +from pandas._typing import ( + ArrayLike, + Callable, +) from pandas.compat.numpy import function as nv from pandas.util._decorators import ( Appender, @@ -75,36 +77,6 @@ _T = TypeVar("_T", bound="DatetimeIndexOpsMixin") -def _join_i8_wrapper(joinf, with_indexers: bool = True): - """ - Create the join wrapper methods. - """ - - # error: 'staticmethod' used with a non-method - @staticmethod # type: ignore[misc] - def wrapper(left, right): - # Note: these only get called with left.dtype == right.dtype - orig_left = left - - left = left.view("i8") - right = right.view("i8") - - results = joinf(left, right) - if with_indexers: - - join_index, left_indexer, right_indexer = results - if not isinstance(orig_left, np.ndarray): - # When called from Index._intersection/_union, we have the EA - join_index = join_index.view(orig_left._ndarray.dtype) - join_index = orig_left._from_backing_data(join_index) - - return join_index, left_indexer, right_indexer - - return results - - return wrapper - - @inherit_names( ["inferred_freq", "_resolution_obj", "resolution"], DatetimeLikeArrayMixin, @@ -603,13 +575,6 @@ def insert(self, loc: int, item): # -------------------------------------------------------------------- # Join/Set Methods - _inner_indexer = _join_i8_wrapper(libjoin.inner_join_indexer) - _outer_indexer = _join_i8_wrapper(libjoin.outer_join_indexer) - _left_indexer = _join_i8_wrapper(libjoin.left_join_indexer) - _left_indexer_unique = _join_i8_wrapper( - libjoin.left_join_indexer_unique, with_indexers=False - ) - def _get_join_freq(self, other): """ Get the freq to attach to the result of a join operation. @@ -621,14 +586,22 @@ def _get_join_freq(self, other): freq = self.freq if self._can_fast_union(other) else None return freq - def _wrap_joined_index(self, joined: np.ndarray, other): + def _wrap_joined_index(self, joined: ArrayLike, other): assert other.dtype == self.dtype, (other.dtype, self.dtype) - assert joined.dtype == "i8" or joined.dtype == self.dtype, joined.dtype - joined = joined.view(self._data._ndarray.dtype) result = super()._wrap_joined_index(joined, other) result._data._freq = self._get_join_freq(other) return result + def _get_join_target(self) -> np.ndarray: + return self._data._ndarray.view("i8") + + def _from_join_target(self, result: np.ndarray): + # view e.g. i8 back to M8[ns] + result = result.view(self._data._ndarray.dtype) + return self._data._from_backing_data(result) + + # -------------------------------------------------------------------- + @doc(Index._convert_arr_indexer) def _convert_arr_indexer(self, keyarr): try: diff --git a/pandas/core/indexes/extension.py b/pandas/core/indexes/extension.py index b11ec06120e0c..98439696f196f 100644 --- a/pandas/core/indexes/extension.py +++ b/pandas/core/indexes/extension.py @@ -11,6 +11,7 @@ import numpy as np +from pandas._typing import ArrayLike from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError from pandas.util._decorators import ( @@ -300,6 +301,11 @@ def searchsorted(self, value, side="left", sorter=None) -> np.ndarray: def _get_engine_target(self) -> np.ndarray: return np.asarray(self._data) + def _from_join_target(self, result: np.ndarray) -> ArrayLike: + # ATM this is only for IntervalIndex, implicit assumption + # about _get_engine_target + return type(self._data)._from_sequence(result, dtype=self.dtype) + def delete(self, loc): """ Make new Index with passed location(-s) deleted @@ -410,6 +416,10 @@ def _simple_new( def _get_engine_target(self) -> np.ndarray: return self._data._ndarray + def _from_join_target(self, result: np.ndarray) -> ArrayLike: + assert result.dtype == self._data._ndarray.dtype + return self._data._from_backing_data(result) + def insert(self: _T, loc: int, item) -> Index: """ Make new Index inserting new item at location. Follows @@ -458,7 +468,18 @@ def putmask(self, mask, value) -> Index: return type(self)._simple_new(res_values, name=self.name) - def _wrap_joined_index(self: _T, joined: np.ndarray, other: _T) -> _T: + def _wrap_joined_index(self: _T, joined: ArrayLike, other: _T) -> _T: name = get_op_result_name(self, other) - arr = self._data._from_backing_data(joined) - return type(self)._simple_new(arr, name=name) + + if isinstance(joined, np.ndarray): + # Reached from _join_non_unique, view back e.g. i8 to M8ns/m8ns + joined = joined.view(self._data._ndarray.dtype) + joined_ea = self._data._from_backing_data(joined) + else: + # error: Incompatible types in assignment (expression has type + # "ExtensionArray", variable has type "NDArrayBackedExtensionArray") + joined_ea = joined # type: ignore[assignment] + assert type(joined) is type(self._data) # noqa: E721 + assert joined.dtype == self.dtype + + return type(self)._simple_new(joined_ea, name=name) diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 59ff128713aca..794f13bbfb6b1 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -3613,14 +3613,12 @@ def _maybe_match_names(self, other): def _intersection(self, other, sort=False) -> MultiIndex: other, result_names = self._convert_can_do_setop(other) - - lvals = self._values - rvals = other._values.astype(object, copy=False) + other = other.astype(object, copy=False) uniq_tuples = None # flag whether _inner_indexer was successful if self.is_monotonic and other.is_monotonic: try: - inner_tuples = self._inner_indexer(lvals, rvals)[0] + inner_tuples = self._inner_indexer(other)[0] sort = False # inner_tuples is already sorted except TypeError: pass diff --git a/pandas/tests/indexes/period/test_join.py b/pandas/tests/indexes/period/test_join.py index 77dcd38b239ec..b8b15708466cb 100644 --- a/pandas/tests/indexes/period/test_join.py +++ b/pandas/tests/indexes/period/test_join.py @@ -15,7 +15,7 @@ class TestJoin: def test_join_outer_indexer(self): pi = period_range("1/1/2000", "1/20/2000", freq="D") - result = pi._outer_indexer(pi._values, pi._values) + result = pi._outer_indexer(pi) tm.assert_extension_array_equal(result[0], pi._values) tm.assert_numpy_array_equal(result[1], np.arange(len(pi), dtype=np.intp)) tm.assert_numpy_array_equal(result[2], np.arange(len(pi), dtype=np.intp)) From 46f2e257751138f92bccb5aaeef5027eafa49a37 Mon Sep 17 00:00:00 2001 From: Brock Date: Sun, 18 Apr 2021 09:47:59 -0700 Subject: [PATCH 2/2] Avoid special-casing in wrap_joined_index --- pandas/core/indexes/base.py | 7 ++++--- pandas/core/indexes/datetimelike.py | 7 ++----- pandas/core/indexes/extension.py | 19 ++++++------------- 3 files changed, 12 insertions(+), 21 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index c140d7dc8ba0c..7f969ea5a26af 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -4112,8 +4112,8 @@ def _join_non_unique(self, other, how="left"): # We only get here if dtypes match assert self.dtype == other.dtype - lvalues = self._get_engine_target() - rvalues = other._get_engine_target() + lvalues = self._get_join_target() + rvalues = other._get_join_target() left_idx, right_idx = get_join_indexers( [lvalues], [rvalues], how=how, sort=True @@ -4126,7 +4126,8 @@ def _join_non_unique(self, other, how="left"): mask = left_idx == -1 np.putmask(join_array, mask, rvalues.take(right_idx)) - join_index = self._wrap_joined_index(join_array, other) + join_arraylike = self._from_join_target(join_array) + join_index = self._wrap_joined_index(join_arraylike, other) return join_index, left_idx, right_idx diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index adf80e92e9cba..b2d2c98c08f68 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -28,10 +28,7 @@ Resolution, Tick, ) -from pandas._typing import ( - ArrayLike, - Callable, -) +from pandas._typing import Callable from pandas.compat.numpy import function as nv from pandas.util._decorators import ( Appender, @@ -586,7 +583,7 @@ def _get_join_freq(self, other): freq = self.freq if self._can_fast_union(other) else None return freq - def _wrap_joined_index(self, joined: ArrayLike, other): + def _wrap_joined_index(self, joined, other): assert other.dtype == self.dtype, (other.dtype, self.dtype) result = super()._wrap_joined_index(joined, other) result._data._freq = self._get_join_freq(other) diff --git a/pandas/core/indexes/extension.py b/pandas/core/indexes/extension.py index 98439696f196f..d593ddc640967 100644 --- a/pandas/core/indexes/extension.py +++ b/pandas/core/indexes/extension.py @@ -468,18 +468,11 @@ def putmask(self, mask, value) -> Index: return type(self)._simple_new(res_values, name=self.name) - def _wrap_joined_index(self: _T, joined: ArrayLike, other: _T) -> _T: + # error: Argument 1 of "_wrap_joined_index" is incompatible with supertype + # "Index"; supertype defines the argument type as "Union[ExtensionArray, ndarray]" + def _wrap_joined_index( # type: ignore[override] + self: _T, joined: NDArrayBackedExtensionArray, other: _T + ) -> _T: name = get_op_result_name(self, other) - if isinstance(joined, np.ndarray): - # Reached from _join_non_unique, view back e.g. i8 to M8ns/m8ns - joined = joined.view(self._data._ndarray.dtype) - joined_ea = self._data._from_backing_data(joined) - else: - # error: Incompatible types in assignment (expression has type - # "ExtensionArray", variable has type "NDArrayBackedExtensionArray") - joined_ea = joined # type: ignore[assignment] - assert type(joined) is type(self._data) # noqa: E721 - assert joined.dtype == self.dtype - - return type(self)._simple_new(joined_ea, name=name) + return type(self)._simple_new(joined, name=name)