Skip to content

Commit cd49372

Browse files
authored
REF: de-duplicate algos.try_sort calls (#39330)
1 parent c322b24 commit cd49372

File tree

3 files changed

+21
-24
lines changed

3 files changed

+21
-24
lines changed

pandas/core/indexes/base.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2882,15 +2882,7 @@ def _union(self, other, sort):
28822882
else:
28832883
result = lvals
28842884

2885-
if sort is None:
2886-
try:
2887-
result = algos.safe_sort(result)
2888-
except TypeError as err:
2889-
warnings.warn(
2890-
f"{err}, sort order is undefined for incomparable objects",
2891-
RuntimeWarning,
2892-
stacklevel=3,
2893-
)
2885+
result = _maybe_try_sort(result, sort)
28942886

28952887
return result
28962888

@@ -2998,9 +2990,7 @@ def _intersection(self, other, sort=False):
29982990
indexer = indexer.take(mask.nonzero()[0])
29992991

30002992
result = other.take(indexer).unique()._values
3001-
3002-
if sort is None:
3003-
result = algos.safe_sort(result)
2993+
result = _maybe_try_sort(result, sort)
30042994

30052995
# Intersection has to be unique
30062996
assert Index(result).is_unique
@@ -3070,11 +3060,7 @@ def _difference(self, other, sort):
30703060

30713061
label_diff = np.setdiff1d(np.arange(this.size), indexer, assume_unique=True)
30723062
the_diff = this._values.take(label_diff)
3073-
if sort is None:
3074-
try:
3075-
the_diff = algos.safe_sort(the_diff)
3076-
except TypeError:
3077-
pass
3063+
the_diff = _maybe_try_sort(the_diff, sort)
30783064

30793065
return the_diff
30803066

@@ -3155,11 +3141,7 @@ def symmetric_difference(self, other, result_name=None, sort=None):
31553141
right_diff = other._values.take(right_indexer)
31563142

31573143
the_diff = concat_compat([left_diff, right_diff])
3158-
if sort is None:
3159-
try:
3160-
the_diff = algos.safe_sort(the_diff)
3161-
except TypeError:
3162-
pass
3144+
the_diff = _maybe_try_sort(the_diff, sort)
31633145

31643146
return Index(the_diff, name=result_name)
31653147

@@ -6354,3 +6336,16 @@ def unpack_nested_dtype(other: Index) -> Index:
63546336
# here too.
63556337
return dtype.categories
63566338
return other
6339+
6340+
6341+
def _maybe_try_sort(result, sort):
6342+
if sort is None:
6343+
try:
6344+
result = algos.safe_sort(result)
6345+
except TypeError as err:
6346+
warnings.warn(
6347+
f"{err}, sort order is undefined for incomparable objects",
6348+
RuntimeWarning,
6349+
stacklevel=4,
6350+
)
6351+
return result

pandas/tests/indexes/interval/test_setops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def test_symmetric_difference(self, closed, sort):
162162
expected = empty_index(dtype="float64", closed=closed)
163163
tm.assert_index_equal(result, expected)
164164

165+
@pytest.mark.filterwarnings("ignore:'<' not supported between:RuntimeWarning")
165166
@pytest.mark.parametrize(
166167
"op_name", ["union", "intersection", "difference", "symmetric_difference"]
167168
)

pandas/tests/indexes/test_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -925,8 +925,9 @@ def test_difference_incomparable(self, opname):
925925
b = Index([2, Timestamp("1999"), 1])
926926
op = operator.methodcaller(opname, b)
927927

928-
# sort=None, the default
929-
result = op(a)
928+
with tm.assert_produces_warning(RuntimeWarning):
929+
# sort=None, the default
930+
result = op(a)
930931
expected = Index([3, Timestamp("2000"), 2, Timestamp("1999")])
931932
if opname == "difference":
932933
expected = expected[:2]

0 commit comments

Comments
 (0)