Backport PR #36927: BUG: Fix duplicates in intersection of multiindexes #38155

Merged
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.1.5.rst
@@ -23,6 +23,7 @@ Fixed regressions
- Fixed regression in :meth:`DataFrame.groupby` aggregation with out-of-bounds datetime objects in an object-dtype column (:issue:`36003`)
- Fixed regression in ``df.groupby(..).rolling(..)`` with the resulting :class:`MultiIndex` when grouping by a label that is in the index (:issue:`37641`)
- Fixed regression in :meth:`DataFrame.fillna` not filling ``NaN`` after other operations such as :meth:`DataFrame.pivot` (:issue:`36495`).
- Fixed regression in :meth:`MultiIndex.intersection` returning duplicates when at least one of the indexes had duplicates (:issue:`36915`)

.. ---------------------------------------------------------------------------

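A minimal reproduction of the regression described in the whatsnew entry above, using made-up tuple values; the behaviour shown is what this backport restores:

```python
import pandas as pd

left = pd.MultiIndex.from_tuples(
    [("a", 1), ("a", 1), ("b", 2)], names=["key", "num"]
)
right = pd.MultiIndex.from_tuples([("a", 1), ("b", 2)], names=["key", "num"])

# Before the fix the duplicated ("a", 1) leaked into the result;
# with the fix each matching tuple appears exactly once.
result = left.intersection(right)
print(result.is_unique)  # True
```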
9 changes: 6 additions & 3 deletions pandas/core/indexes/base.py
@@ -2654,7 +2654,7 @@ def intersection(self, other, sort=False):
self._assert_can_do_setop(other)
other = ensure_index(other)

if self.equals(other):
if self.equals(other) and not self.has_duplicates:
return self._get_reconciled_name_object(other)

if not is_dtype_equal(self.dtype, other.dtype):
@@ -2672,7 +2672,7 @@ def intersection(self, other, sort=False):
except TypeError:
pass
else:
return self._wrap_setop_result(other, result)
return self._wrap_setop_result(other, algos.unique1d(result))

try:
indexer = Index(rvals).get_indexer(lvals)
@@ -2683,13 +2683,16 @@
indexer = algos.unique1d(Index(rvals).get_indexer_non_unique(lvals)[0])
indexer = indexer[indexer != -1]

taken = other.take(indexer)
taken = other.take(indexer).unique()
res_name = get_op_result_name(self, other)

if sort is None:
taken = algos.safe_sort(taken.values)
return self._shallow_copy(taken, name=res_name)

# Intersection has to be unique
assert algos.unique(taken._values).shape == taken._values.shape

taken.name = res_name
return taken

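The flat `Index.intersection` changes above route the result through `algos.unique1d`/`.unique()` and assert uniqueness at the end. A short sketch of the resulting behaviour, mirroring the new test added in pandas/tests/indexes/test_setops.py:

```python
import pandas as pd

a = pd.Index([1, 2, 2, 3])
b = pd.Index([3, 3])

# Both inputs contain duplicates, but the intersection is deduplicated,
# consistent with the assertion added at the end of Index.intersection.
result = a.intersection(b)
print(list(result))  # [3]
```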
8 changes: 6 additions & 2 deletions pandas/core/indexes/multi.py
@@ -3398,6 +3398,8 @@ def intersection(self, other, sort=False):
other, result_names = self._convert_can_do_setop(other)

if self.equals(other):
if self.has_duplicates:
return self.unique()
return self

if not is_object_dtype(other.dtype):
@@ -3416,10 +3418,12 @@
uniq_tuples = None # flag whether _inner_indexer was successful
if self.is_monotonic and other.is_monotonic:
try:
uniq_tuples = self._inner_indexer(lvals, rvals)[0]
sort = False # uniq_tuples is already sorted
inner_tuples = self._inner_indexer(lvals, rvals)[0]
sort = False # inner_tuples is already sorted
except TypeError:
pass
else:
uniq_tuples = algos.unique(inner_tuples)

if uniq_tuples is None:
other_uniq = set(rvals)
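The early return added to `MultiIndex.intersection` covers the case where both indexes are equal but contain duplicates, which the base-class fast path above now skips. A sketch with illustrative values:

```python
import pandas as pd

mi = pd.MultiIndex.from_tuples(
    [("val1", "test1"), ("val1", "test1")], names=["first", "second"]
)

# self.equals(other) is True and self.has_duplicates is True,
# so the patched code returns self.unique() rather than self.
result = mi.intersection(mi)
print(len(result), result.is_unique)  # 1 True
```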
6 changes: 5 additions & 1 deletion pandas/core/ops/__init__.py
@@ -539,7 +539,11 @@ def _should_reindex_frame_op(
if fill_value is None and level is None and axis is default_axis:
# TODO: any other cases we should handle here?
cols = left.columns.intersection(right.columns)
if not (cols.equals(left.columns) and cols.equals(right.columns)):

# Intersection is always unique so we have to check the unique columns
left_uniques = left.columns.unique()
right_uniques = right.columns.unique()
if not (cols.equals(left_uniques) and cols.equals(right_uniques)):
return True

return False
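Because the column intersection is now always unique, `_should_reindex_frame_op` compares it against the deduplicated column labels; comparing against the raw (possibly duplicated) columns would wrongly send aligned frames with repeated labels down the reindex path. A rough illustration with made-up data:

```python
import pandas as pd

left = pd.DataFrame([[1, 2, 3]], columns=["a", "a", "b"])
right = pd.DataFrame([[10, 20, 30]], columns=["a", "a", "b"])

# The intersection of the columns is deduplicated, so it is compared against
# left.columns.unique() and right.columns.unique(); these aligned frames
# therefore keep taking the non-reindexing path.
print(left + right)
```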
9 changes: 7 additions & 2 deletions pandas/core/reshape/merge.py
@@ -1209,7 +1209,9 @@ def _validate_specification(self):
raise MergeError("Must pass left_on or left_index=True")
else:
# use the common columns
common_cols = self.left.columns.intersection(self.right.columns)
left_cols = self.left.columns
right_cols = self.right.columns
common_cols = left_cols.intersection(right_cols)
if len(common_cols) == 0:
raise MergeError(
"No common columns to perform merge on. "
@@ -1218,7 +1220,10 @@
f"left_index={self.left_index}, "
f"right_index={self.right_index}"
)
if not common_cols.is_unique:
if (
not left_cols.join(common_cols, how="inner").is_unique
or not right_cols.join(common_cols, how="inner").is_unique
):
raise MergeError(f"Data columns not unique: {repr(common_cols)}")
self.left_on = self.right_on = common_cols
elif self.on is not None:
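Since `common_cols` is now deduplicated, the old `common_cols.is_unique` check can no longer detect duplicated merge keys; the new check instead joins the common columns back against each frame's own columns. A hedged sketch of the behaviour this preserves, with made-up frames:

```python
import pandas as pd
from pandas.errors import MergeError

left = pd.DataFrame([[1, 2, 3]], columns=["key", "foo", "foo"])
right = pd.DataFrame([[1, 4, 5]], columns=["key", "foo", "foo"])

# "foo" is a shared label that is duplicated inside each frame, so
# left.columns.join(common_cols, how="inner") is not unique and the merge
# still raises, even though common_cols itself is unique after this change.
try:
    pd.merge(left, right)
except MergeError as err:
    print("raised:", type(err).__name__)
```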
23 changes: 23 additions & 0 deletions pandas/tests/indexes/multi/test_setops.py
@@ -375,3 +375,26 @@ def test_setops_disallow_true(method):

with pytest.raises(ValueError, match="The 'sort' keyword only takes"):
getattr(idx1, method)(idx2, sort=True)


@pytest.mark.parametrize(
("tuples", "exp_tuples"),
[
([("val1", "test1")], [("val1", "test1")]),
([("val1", "test1"), ("val1", "test1")], [("val1", "test1")]),
(
[("val2", "test2"), ("val1", "test1")],
[("val2", "test2"), ("val1", "test1")],
),
],
)
def test_intersect_with_duplicates(tuples, exp_tuples):
# GH#36915
left = MultiIndex.from_tuples(tuples, names=["first", "second"])
right = MultiIndex.from_tuples(
[("val1", "test1"), ("val1", "test1"), ("val2", "test2")],
names=["first", "second"],
)
result = left.intersection(right)
expected = MultiIndex.from_tuples(exp_tuples, names=["first", "second"])
tm.assert_index_equal(result, expected)
2 changes: 1 addition & 1 deletion pandas/tests/indexes/test_base.py
@@ -688,7 +688,7 @@ def test_intersection_monotonic(self, index2, keeps_name, sort):

@pytest.mark.parametrize(
"index2,expected_arr",
[(Index(["B", "D"]), ["B"]), (Index(["B", "D", "A"]), ["A", "B", "A"])],
[(Index(["B", "D"]), ["B"]), (Index(["B", "D", "A"]), ["A", "B"])],
)
def test_intersection_non_monotonic_non_unique(self, index2, expected_arr, sort):
# non-monotonic non-unique
10 changes: 10 additions & 0 deletions pandas/tests/indexes/test_setops.py
@@ -95,3 +95,13 @@ def test_union_dtypes(left, right, expected):
b = pd.Index([], dtype=right)
result = (a | b).dtype
assert result == expected


@pytest.mark.parametrize("values", [[1, 2, 2, 3], [3, 3]])
def test_intersection_duplicates(values):
# GH#31326
a = pd.Index(values)
b = pd.Index([3, 3])
result = a.intersection(b)
expected = pd.Index([3])
tm.assert_index_equal(result, expected)
2 changes: 1 addition & 1 deletion pandas/tests/reshape/merge/test_merge.py
@@ -742,7 +742,7 @@ def test_overlapping_columns_error_message(self):

# #2649, #10639
df2.columns = ["key1", "foo", "foo"]
msg = r"Data columns not unique: Index\(\['foo', 'foo'\], dtype='object'\)"
msg = r"Data columns not unique: Index\(\['foo'\], dtype='object'\)"
with pytest.raises(MergeError, match=msg):
merge(df, df2)
