
REF: pass arguments to Index._foo_indexer correctly #41024


Merged · 2 commits · Apr 19, 2021
100 changes: 63 additions & 37 deletions pandas/core/indexes/base.py
@@ -302,23 +302,47 @@ class Index(IndexOpsMixin, PandasObject):
# for why we need to wrap these instead of making them class attributes
# Moreover, cython will choose the appropriate-dtyped sub-function
# given the dtypes of the passed arguments
def _left_indexer_unique(self, left: np.ndarray, right: np.ndarray) -> np.ndarray:
return libjoin.left_join_indexer_unique(left, right)

@final
def _left_indexer_unique(self: _IndexT, other: _IndexT) -> np.ndarray:
# -> np.ndarray[np.intp]
# Caller is responsible for ensuring other.dtype == self.dtype
sv = self._get_join_target()
ov = other._get_join_target()
return libjoin.left_join_indexer_unique(sv, ov)

@final
def _left_indexer(
self, left: np.ndarray, right: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
return libjoin.left_join_indexer(left, right)
self: _IndexT, other: _IndexT
) -> tuple[ArrayLike, np.ndarray, np.ndarray]:
# Caller is responsible for ensuring other.dtype == self.dtype
sv = self._get_join_target()
ov = other._get_join_target()
joined_ndarray, lidx, ridx = libjoin.left_join_indexer(sv, ov)
joined = self._from_join_target(joined_ndarray)
return joined, lidx, ridx

@final
def _inner_indexer(
self, left: np.ndarray, right: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
return libjoin.inner_join_indexer(left, right)
self: _IndexT, other: _IndexT
) -> tuple[ArrayLike, np.ndarray, np.ndarray]:
# Caller is responsible for ensuring other.dtype == self.dtype
sv = self._get_join_target()
ov = other._get_join_target()
joined_ndarray, lidx, ridx = libjoin.inner_join_indexer(sv, ov)
joined = self._from_join_target(joined_ndarray)
return joined, lidx, ridx

@final
def _outer_indexer(
self, left: np.ndarray, right: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
return libjoin.outer_join_indexer(left, right)
self: _IndexT, other: _IndexT
) -> tuple[ArrayLike, np.ndarray, np.ndarray]:
# Caller is responsible for ensuring other.dtype == self.dtype
sv = self._get_join_target()
ov = other._get_join_target()
joined_ndarray, lidx, ridx = libjoin.outer_join_indexer(sv, ov)
joined = self._from_join_target(joined_ndarray)
return joined, lidx, ridx

_typ = "index"
_data: ExtensionArray | np.ndarray
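
For reference, a minimal sketch of the new calling convention, using the private helpers only to illustrate the API shape after this change (the toy int64 indexes below are hypothetical, not part of the PR):

```python
import pandas as pd

left = pd.Index([1, 2, 4], dtype="int64")
right = pd.Index([2, 3, 4], dtype="int64")

# The indexer helpers now take the other Index directly and extract the
# join targets themselves, instead of the caller passing raw ndarrays.
joined, lidx, ridx = left._inner_indexer(right)

# joined -> ndarray([2, 4]); lidx/ridx -> intp positions into left/right
assert list(joined) == [2, 4]
assert list(lidx) == [1, 2] and list(ridx) == [0, 2]
```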
@@ -2965,11 +2989,7 @@ def _union(self, other: Index, sort):
):
# Both are unique and monotonic, so can use outer join
try:
# error: Argument 1 to "_outer_indexer" of "Index" has incompatible type
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
# error: Argument 2 to "_outer_indexer" of "Index" has incompatible type
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
return self._outer_indexer(lvals, rvals)[0] # type: ignore[arg-type]
return self._outer_indexer(other)[0]
except (TypeError, IncompatibleFrequency):
# incomparable objects
value_list = list(lvals)
@@ -3090,13 +3110,10 @@ def _intersection(self, other: Index, sort=False):
"""
# TODO(EA): setops-refactor, clean all this up
lvals = self._values
rvals = other._values

if self.is_monotonic and other.is_monotonic:
try:
# error: Argument 1 to "_inner_indexer" of "Index" has incompatible type
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
result = self._inner_indexer(lvals, rvals)[0] # type: ignore[arg-type]
result = self._inner_indexer(other)[0]
except TypeError:
pass
else:
@@ -4095,8 +4112,8 @@ def _join_non_unique(self, other, how="left"):
# We only get here if dtypes match
assert self.dtype == other.dtype

lvalues = self._get_engine_target()
rvalues = other._get_engine_target()
lvalues = self._get_join_target()
rvalues = other._get_join_target()

left_idx, right_idx = get_join_indexers(
[lvalues], [rvalues], how=how, sort=True
@@ -4109,7 +4126,8 @@ def _join_non_unique(self, other, how="left"):
mask = left_idx == -1
np.putmask(join_array, mask, rvalues.take(right_idx))

join_index = self._wrap_joined_index(join_array, other)
join_arraylike = self._from_join_target(join_array)
join_index = self._wrap_joined_index(join_arraylike, other)

return join_index, left_idx, right_idx

@@ -4267,9 +4285,6 @@ def _join_monotonic(self, other: Index, how="left"):
ret_index = other if how == "right" else self
return ret_index, None, None

sv = self._get_engine_target()
ov = other._get_engine_target()

ridx: np.ndarray | None
lidx: np.ndarray | None

@@ -4278,36 +4293,34 @@ def _join_monotonic(self, other: Index, how="left"):
if how == "left":
join_index = self
lidx = None
ridx = self._left_indexer_unique(sv, ov)
ridx = self._left_indexer_unique(other)
elif how == "right":
join_index = other
lidx = self._left_indexer_unique(ov, sv)
lidx = other._left_indexer_unique(self)
ridx = None
elif how == "inner":
join_array, lidx, ridx = self._inner_indexer(sv, ov)
join_array, lidx, ridx = self._inner_indexer(other)
join_index = self._wrap_joined_index(join_array, other)
elif how == "outer":
join_array, lidx, ridx = self._outer_indexer(sv, ov)
join_array, lidx, ridx = self._outer_indexer(other)
join_index = self._wrap_joined_index(join_array, other)
else:
if how == "left":
join_array, lidx, ridx = self._left_indexer(sv, ov)
join_array, lidx, ridx = self._left_indexer(other)
elif how == "right":
join_array, ridx, lidx = self._left_indexer(ov, sv)
join_array, ridx, lidx = other._left_indexer(self)
elif how == "inner":
join_array, lidx, ridx = self._inner_indexer(sv, ov)
join_array, lidx, ridx = self._inner_indexer(other)
elif how == "outer":
join_array, lidx, ridx = self._outer_indexer(sv, ov)
join_array, lidx, ridx = self._outer_indexer(other)

join_index = self._wrap_joined_index(join_array, other)

lidx = None if lidx is None else ensure_platform_int(lidx)
ridx = None if ridx is None else ensure_platform_int(ridx)
return join_index, lidx, ridx

def _wrap_joined_index(
self: _IndexT, joined: np.ndarray, other: _IndexT
) -> _IndexT:
def _wrap_joined_index(self: _IndexT, joined: ArrayLike, other: _IndexT) -> _IndexT:
assert other.dtype == self.dtype

if isinstance(self, ABCMultiIndex):
@@ -4385,6 +4398,19 @@ def _get_engine_target(self) -> np.ndarray:
# ndarray]", expected "ndarray")
return self._values # type: ignore[return-value]

def _get_join_target(self) -> np.ndarray:
"""
Get the ndarray that we will pass to libjoin functions.
"""
return self._get_engine_target()

def _from_join_target(self, result: np.ndarray) -> ArrayLike:
"""
Cast the ndarray returned from one of the libjoin.foo_indexer functions
back to type(self)._data.
"""
return result

@doc(IndexOpsMixin._memory_usage)
def memory_usage(self, deep: bool = False) -> int:
result = self._memory_usage(deep=deep)
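As a quick sanity sketch (again with made-up data): for plain ndarray-backed indexes the new hooks are pass-throughs, so the existing join path is unchanged.

```python
import numpy as np
import pandas as pd

idx = pd.Index([10, 20, 30], dtype="int64")

# Base Index: the join target is just the engine target, and the libjoin
# result comes back untouched.
target = idx._get_join_target()
assert isinstance(target, np.ndarray)
assert idx._from_join_target(target) is target
```
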
52 changes: 11 additions & 41 deletions pandas/core/indexes/datetimelike.py
@@ -20,7 +20,6 @@
NaT,
Timedelta,
iNaT,
join as libjoin,
lib,
)
from pandas._libs.tslibs import (
@@ -75,36 +74,6 @@
_T = TypeVar("_T", bound="DatetimeIndexOpsMixin")


def _join_i8_wrapper(joinf, with_indexers: bool = True):
"""
Create the join wrapper methods.
"""

# error: 'staticmethod' used with a non-method
@staticmethod # type: ignore[misc]
def wrapper(left, right):
# Note: these only get called with left.dtype == right.dtype
orig_left = left

left = left.view("i8")
right = right.view("i8")

results = joinf(left, right)
if with_indexers:

join_index, left_indexer, right_indexer = results
if not isinstance(orig_left, np.ndarray):
# When called from Index._intersection/_union, we have the EA
join_index = join_index.view(orig_left._ndarray.dtype)
join_index = orig_left._from_backing_data(join_index)

return join_index, left_indexer, right_indexer

return results

return wrapper


@inherit_names(
["inferred_freq", "_resolution_obj", "resolution"],
DatetimeLikeArrayMixin,
@@ -603,13 +572,6 @@ def insert(self, loc: int, item):
# --------------------------------------------------------------------
# Join/Set Methods

_inner_indexer = _join_i8_wrapper(libjoin.inner_join_indexer)
_outer_indexer = _join_i8_wrapper(libjoin.outer_join_indexer)
_left_indexer = _join_i8_wrapper(libjoin.left_join_indexer)
_left_indexer_unique = _join_i8_wrapper(
libjoin.left_join_indexer_unique, with_indexers=False
)

def _get_join_freq(self, other):
"""
Get the freq to attach to the result of a join operation.
@@ -621,14 +583,22 @@ def _get_join_freq(self, other):
freq = self.freq if self._can_fast_union(other) else None
return freq

def _wrap_joined_index(self, joined: np.ndarray, other):
def _wrap_joined_index(self, joined, other):
assert other.dtype == self.dtype, (other.dtype, self.dtype)
assert joined.dtype == "i8" or joined.dtype == self.dtype, joined.dtype
joined = joined.view(self._data._ndarray.dtype)
result = super()._wrap_joined_index(joined, other)
result._data._freq = self._get_join_freq(other)
return result

def _get_join_target(self) -> np.ndarray:
return self._data._ndarray.view("i8")

def _from_join_target(self, result: np.ndarray):
# view e.g. i8 back to M8[ns]
result = result.view(self._data._ndarray.dtype)
return self._data._from_backing_data(result)

# --------------------------------------------------------------------

@doc(Index._convert_arr_indexer)
def _convert_arr_indexer(self, keyarr):
try:
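A hedged sketch of the round trip that replaces the removed `_join_i8_wrapper` machinery (the `date_range` input is illustrative only):

```python
import numpy as np
import pandas as pd

dti = pd.date_range("2000-01-01", periods=3)

# Datetime-like indexes now hand libjoin an i8 view of the backing ndarray...
target = dti._get_join_target()
assert target.dtype == np.dtype("i8")

# ...and _from_join_target views the result back to M8[ns] and re-wraps it
# as a DatetimeArray via _from_backing_data.
restored = dti._from_join_target(target)
assert restored.dtype == dti.dtype
```
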
20 changes: 17 additions & 3 deletions pandas/core/indexes/extension.py
@@ -11,6 +11,7 @@

import numpy as np

from pandas._typing import ArrayLike
from pandas.compat.numpy import function as nv
from pandas.errors import AbstractMethodError
from pandas.util._decorators import (
@@ -300,6 +301,11 @@ def searchsorted(self, value, side="left", sorter=None) -> np.ndarray:
def _get_engine_target(self) -> np.ndarray:
return np.asarray(self._data)

def _from_join_target(self, result: np.ndarray) -> ArrayLike:
# ATM this is only for IntervalIndex, implicit assumption
# about _get_engine_target
return type(self._data)._from_sequence(result, dtype=self.dtype)

def delete(self, loc):
"""
Make new Index with passed location(-s) deleted
@@ -410,6 +416,10 @@ def _simple_new(
def _get_engine_target(self) -> np.ndarray:
return self._data._ndarray

def _from_join_target(self, result: np.ndarray) -> ArrayLike:
assert result.dtype == self._data._ndarray.dtype
return self._data._from_backing_data(result)

def insert(self: _T, loc: int, item) -> Index:
"""
Make new Index inserting new item at location. Follows
@@ -458,7 +468,11 @@ def putmask(self, mask, value) -> Index:

return type(self)._simple_new(res_values, name=self.name)

def _wrap_joined_index(self: _T, joined: np.ndarray, other: _T) -> _T:
# error: Argument 1 of "_wrap_joined_index" is incompatible with supertype
# "Index"; supertype defines the argument type as "Union[ExtensionArray, ndarray]"
def _wrap_joined_index( # type: ignore[override]
self: _T, joined: NDArrayBackedExtensionArray, other: _T
) -> _T:
name = get_op_result_name(self, other)
arr = self._data._from_backing_data(joined)
return type(self)._simple_new(arr, name=name)

return type(self)._simple_new(joined, name=name)
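
For the ExtensionIndex path, a small sketch of the round trip; the `ExtensionIndex._from_join_target` added above is noted as only applying to IntervalIndex at the moment, and the data here is made up:

```python
import pandas as pd

ii = pd.interval_range(0, 3)

# _get_engine_target gives an object-dtype ndarray of Interval objects;
# _from_join_target rebuilds an IntervalArray from it via _from_sequence.
target = ii._get_engine_target()
restored = ii._from_join_target(target)
assert restored.dtype == ii.dtype
```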
6 changes: 2 additions & 4 deletions pandas/core/indexes/multi.py
@@ -3613,14 +3613,12 @@ def _maybe_match_names(self, other):

def _intersection(self, other, sort=False) -> MultiIndex:
other, result_names = self._convert_can_do_setop(other)

lvals = self._values
rvals = other._values.astype(object, copy=False)
other = other.astype(object, copy=False)

uniq_tuples = None # flag whether _inner_indexer was successful
if self.is_monotonic and other.is_monotonic:
try:
inner_tuples = self._inner_indexer(lvals, rvals)[0]
inner_tuples = self._inner_indexer(other)[0]
sort = False # inner_tuples is already sorted
except TypeError:
pass
2 changes: 1 addition & 1 deletion pandas/tests/indexes/period/test_join.py
@@ -15,7 +15,7 @@ class TestJoin:
def test_join_outer_indexer(self):
pi = period_range("1/1/2000", "1/20/2000", freq="D")

result = pi._outer_indexer(pi._values, pi._values)
result = pi._outer_indexer(pi)
tm.assert_extension_array_equal(result[0], pi._values)
tm.assert_numpy_array_equal(result[1], np.arange(len(pi), dtype=np.intp))
tm.assert_numpy_array_equal(result[2], np.arange(len(pi), dtype=np.intp))