Skip to content

Commit 7e2aa42

Browse files
authored
BUG: name retention in Index.intersection (#38111)
1 parent aed5dba commit 7e2aa42

File tree

9 files changed

+75
-15
lines changed

9 files changed

+75
-15
lines changed

doc/source/whatsnew/v1.2.0.rst

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -814,6 +814,7 @@ Other
814814
- Fixed bug in metadata propagation incorrectly copying DataFrame columns as metadata when the column name overlaps with the metadata name (:issue:`37037`)
815815
- Fixed metadata propagation in the :class:`Series.dt`, :class:`Series.str` accessors, :class:`DataFrame.duplicated`, :class:`DataFrame.stack`, :class:`DataFrame.unstack`, :class:`DataFrame.pivot`, :class:`DataFrame.append`, :class:`DataFrame.diff`, :class:`DataFrame.applymap` and :class:`DataFrame.update` methods (:issue:`28283`, :issue:`37381`)
816816
- Fixed metadata propagation when selecting columns with ``DataFrame.__getitem__`` (:issue:`28283`)
817+
- Bug in :meth:`Index.intersection` with non-:class:`Index` failing to set the correct name on the returned :class:`Index` (:issue:`38111`)
817818
- Bug in :meth:`Index.union` behaving differently depending on whether operand is an :class:`Index` or other list-like (:issue:`36384`)
818819
- Bug in :meth:`Index.intersection` with non-matching numeric dtypes casting to ``object`` dtype instead of minimal common dtype (:issue:`38122`)
819820
- Passing an array with 2 or more dimensions to the :class:`Series` constructor now raises the more specific ``ValueError`` rather than a bare ``Exception`` (:issue:`35744`)

pandas/core/indexes/base.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2821,7 +2821,7 @@ def intersection(self, other, sort=False):
28212821
"""
28222822
self._validate_sort_keyword(sort)
28232823
self._assert_can_do_setop(other)
2824-
other = ensure_index(other)
2824+
other, _ = self._convert_can_do_setop(other)
28252825

28262826
if self.equals(other) and not self.has_duplicates:
28272827
return self._get_reconciled_name_object(other)

pandas/core/indexes/datetimelike.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -686,10 +686,17 @@ def intersection(self, other, sort=False):
686686
"""
687687
self._validate_sort_keyword(sort)
688688
self._assert_can_do_setop(other)
689+
other, _ = self._convert_can_do_setop(other)
689690

690691
if self.equals(other):
691692
return self._get_reconciled_name_object(other)
692693

694+
return self._intersection(other, sort=sort)
695+
696+
def _intersection(self, other: Index, sort=False) -> Index:
697+
"""
698+
intersection specialized to the case with matching dtypes.
699+
"""
693700
if len(self) == 0:
694701
return self.copy()._get_reconciled_name_object(other)
695702
if len(other) == 0:
@@ -704,10 +711,11 @@ def intersection(self, other, sort=False):
704711
return result
705712

706713
elif not self._can_fast_intersect(other):
707-
result = Index.intersection(self, other, sort=sort)
708-
# We need to invalidate the freq because Index.intersection
714+
result = Index._intersection(self, other, sort=sort)
715+
# We need to invalidate the freq because Index._intersection
709716
# uses _shallow_copy on a view of self._data, which will preserve
710717
# self.freq if we're not careful.
718+
result = self._wrap_setop_result(other, result)
711719
return result._with_freq(None)._with_freq("infer")
712720

713721
# to make our life easier, "sort" the two ranges

pandas/core/indexes/interval.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -124,7 +124,11 @@ def setop_check(method):
124124
def wrapped(self, other, sort=False):
125125
self._validate_sort_keyword(sort)
126126
self._assert_can_do_setop(other)
127-
other = ensure_index(other)
127+
other, _ = self._convert_can_do_setop(other)
128+
129+
if op_name == "intersection":
130+
if self.equals(other):
131+
return self._get_reconciled_name_object(other)
128132

129133
if not isinstance(other, IntervalIndex):
130134
result = getattr(self.astype(object), op_name)(other)

pandas/core/indexes/multi.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3603,7 +3603,12 @@ def intersection(self, other, sort=False):
36033603
if self.equals(other):
36043604
if self.has_duplicates:
36053605
return self.unique().rename(result_names)
3606-
return self.rename(result_names)
3606+
return self._get_reconciled_name_object(other)
3607+
3608+
return self._intersection(other, sort=sort)
3609+
3610+
def _intersection(self, other, sort=False):
3611+
other, result_names = self._convert_can_do_setop(other)
36073612

36083613
if not is_object_dtype(other.dtype):
36093614
# The intersection is empty
@@ -3721,7 +3726,7 @@ def _convert_can_do_setop(self, other):
37213726
else:
37223727
msg = "other must be a MultiIndex or a list of tuples"
37233728
try:
3724-
other = MultiIndex.from_tuples(other)
3729+
other = MultiIndex.from_tuples(other, names=self.names)
37253730
except (ValueError, TypeError) as err:
37263731
# ValueError raised by tuples_to_object_array if we
37273732
# have non-object dtype

pandas/core/indexes/period.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -639,15 +639,19 @@ def _setop(self, other, sort, opname: str):
639639
def intersection(self, other, sort=False):
640640
self._validate_sort_keyword(sort)
641641
self._assert_can_do_setop(other)
642-
other = ensure_index(other)
642+
other, _ = self._convert_can_do_setop(other)
643643

644644
if self.equals(other):
645645
return self._get_reconciled_name_object(other)
646646

647-
elif is_object_dtype(other.dtype):
647+
return self._intersection(other, sort=sort)
648+
649+
def _intersection(self, other, sort=False):
650+
651+
if is_object_dtype(other.dtype):
648652
return self.astype("O").intersection(other, sort=sort)
649653

650-
elif not is_dtype_equal(self.dtype, other.dtype):
654+
elif not self._is_comparable_dtype(other.dtype):
651655
# We can infer that the intersection is empty.
652656
# assert_can_do_setop ensures that this is not just a mismatched freq
653657
this = self[:0].astype("O")

pandas/core/indexes/range.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from pandas.core.dtypes.common import (
1616
ensure_platform_int,
1717
ensure_python_int,
18+
is_dtype_equal,
1819
is_float,
1920
is_integer,
2021
is_list_like,
@@ -504,11 +505,21 @@ def intersection(self, other, sort=False):
504505
intersection : Index
505506
"""
506507
self._validate_sort_keyword(sort)
508+
self._assert_can_do_setop(other)
509+
other, _ = self._convert_can_do_setop(other)
507510

508511
if self.equals(other):
509512
return self._get_reconciled_name_object(other)
510513

514+
return self._intersection(other, sort=sort)
515+
516+
def _intersection(self, other, sort=False):
517+
511518
if not isinstance(other, RangeIndex):
519+
if is_dtype_equal(other.dtype, self.dtype):
520+
# Int64Index
521+
result = super()._intersection(other, sort=sort)
522+
return self._wrap_setop_result(other, result)
512523
return super().intersection(other, sort=sort)
513524

514525
if not len(self) or not len(other):

pandas/tests/indexes/datetimes/test_setops.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -471,10 +471,11 @@ def test_intersection_bug(self):
471471

472472
def test_intersection_list(self):
473473
# GH#35876
474+
# values is not an Index -> no name -> retain "a"
474475
values = [pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01")]
475476
idx = DatetimeIndex(values, name="a")
476477
res = idx.intersection(values)
477-
tm.assert_index_equal(res, idx.rename(None))
478+
tm.assert_index_equal(res, idx)
478479

479480
def test_month_range_union_tz_pytz(self, sort):
480481
from pytz import timezone

pandas/tests/indexes/test_setops.py

Lines changed: 31 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -98,13 +98,20 @@ def test_compatible_inconsistent_pairs(idx_fact1, idx_fact2):
9898
("Period[D]", "float64", "object"),
9999
],
100100
)
101-
def test_union_dtypes(left, right, expected):
101+
@pytest.mark.parametrize("names", [("foo", "foo", "foo"), ("foo", "bar", None)])
102+
def test_union_dtypes(left, right, expected, names):
102103
left = pandas_dtype(left)
103104
right = pandas_dtype(right)
104-
a = pd.Index([], dtype=left)
105-
b = pd.Index([], dtype=right)
106-
result = a.union(b).dtype
107-
assert result == expected
105+
a = pd.Index([], dtype=left, name=names[0])
106+
b = pd.Index([], dtype=right, name=names[1])
107+
result = a.union(b)
108+
assert result.dtype == expected
109+
assert result.name == names[2]
110+
111+
# Testing name retention
112+
# TODO: pin down desired dtype; do we want it to be commutative?
113+
result = a.intersection(b)
114+
assert result.name == names[2]
108115

109116

110117
def test_dunder_inplace_setops_deprecated(index):
@@ -388,6 +395,25 @@ def test_intersect_unequal(self, index, fname, sname, expected_name):
388395
expected = index[1:].set_names(expected_name).sort_values()
389396
tm.assert_index_equal(intersect, expected)
390397

398+
def test_intersection_name_retention_with_nameless(self, index):
399+
if isinstance(index, MultiIndex):
400+
index = index.rename(list(range(index.nlevels)))
401+
else:
402+
index = index.rename("foo")
403+
404+
other = np.asarray(index)
405+
406+
result = index.intersection(other)
407+
assert result.name == index.name
408+
409+
# empty other, same dtype
410+
result = index.intersection(other[:0])
411+
assert result.name == index.name
412+
413+
# empty `self`
414+
result = index[:0].intersection(other)
415+
assert result.name == index.name
416+
391417
def test_difference_preserves_type_empty(self, index, sort):
392418
# GH#20040
393419
# If taking difference of a set and itself, it

0 commit comments

Comments
 (0)