Skip to content

Commit 01cc1ee

Browse files
authored
ENH: retain RangeIndex in RangeIndex.difference (#44093)
1 parent 3a6d4cd commit 01cc1ee

File tree

2 files changed

+128
-22
lines changed

2 files changed

+128
-22
lines changed

pandas/core/indexes/range.py

+47-20
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,9 @@ def _difference(self, other, sort=None):
675675
if not isinstance(other, RangeIndex):
676676
return super()._difference(other, sort=sort)
677677

678+
if sort is None and self.step < 0:
679+
return self[::-1]._difference(other)
680+
678681
res_name = ops.get_op_result_name(self, other)
679682

680683
first = self._range[::-1] if self.step < 0 else self._range
@@ -683,36 +686,60 @@ def _difference(self, other, sort=None):
683686
overlap = overlap[::-1]
684687

685688
if len(overlap) == 0:
686-
result = self.rename(name=res_name)
687-
if sort is None and self.step < 0:
688-
result = result[::-1]
689-
return result
689+
return self.rename(name=res_name)
690690
if len(overlap) == len(self):
691691
return self[:0].rename(res_name)
692-
if not isinstance(overlap, RangeIndex):
693-
# We won't end up with RangeIndex, so fall back
694-
return super()._difference(other, sort=sort)
695-
if overlap.step != first.step:
696-
# In some cases we might be able to get a RangeIndex back,
697-
# but not worth the effort.
698-
return super()._difference(other, sort=sort)
699692

700-
if overlap[0] == first.start:
701-
# The difference is everything after the intersection
702-
new_rng = range(overlap[-1] + first.step, first.stop, first.step)
703-
elif overlap[-1] == first[-1]:
704-
# The difference is everything before the intersection
705-
new_rng = range(first.start, overlap[0], first.step)
693+
# overlap.step will always be a multiple of self.step (see _intersection)
694+
695+
if len(overlap) == 1:
696+
if overlap[0] == self[0]:
697+
return self[1:]
698+
699+
elif overlap[0] == self[-1]:
700+
return self[:-1]
701+
702+
elif len(self) == 3 and overlap[0] == self[1]:
703+
return self[::2]
704+
705+
else:
706+
return super()._difference(other, sort=sort)
707+
708+
if overlap.step == first.step:
709+
if overlap[0] == first.start:
710+
# The difference is everything after the intersection
711+
new_rng = range(overlap[-1] + first.step, first.stop, first.step)
712+
elif overlap[-1] == first[-1]:
713+
# The difference is everything before the intersection
714+
new_rng = range(first.start, overlap[0], first.step)
715+
else:
716+
# The difference is not range-like
717+
# e.g. range(1, 10, 1) and range(3, 7, 1)
718+
return super()._difference(other, sort=sort)
719+
706720
else:
707-
# The difference is not range-like
721+
# We must have len(self) > 1, bc we ruled out above
722+
# len(overlap) == 0 and len(overlap) == len(self)
723+
assert len(self) > 1
724+
725+
if overlap.step == first.step * 2:
726+
if overlap[0] == first[0] and overlap[-1] in (first[-1], first[-2]):
727+
# e.g. range(1, 10, 1) and range(1, 10, 2)
728+
return self[1::2]
729+
730+
elif overlap[0] == first[1] and overlap[-1] in (first[-1], first[-2]):
731+
# e.g. range(1, 10, 1) and range(2, 10, 2)
732+
return self[::2]
733+
734+
# We can get here with e.g. range(20) and range(0, 10, 2)
735+
736+
# e.g. range(10) and range(0, 10, 3)
708737
return super()._difference(other, sort=sort)
709738

710739
new_index = type(self)._simple_new(new_rng, name=res_name)
711740
if first is not self._range:
712741
new_index = new_index[::-1]
713742

714-
if sort is None and new_index.step < 0:
715-
new_index = new_index[::-1]
716743
return new_index
717744

718745
def symmetric_difference(self, other, result_name: Hashable = None, sort=None):

pandas/tests/indexes/ranges/test_setops.py

+81-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
timedelta,
44
)
55

6+
from hypothesis import (
7+
assume,
8+
given,
9+
strategies as st,
10+
)
611
import numpy as np
712
import pytest
813

@@ -359,11 +364,44 @@ def test_difference_mismatched_step(self):
359364
obj = RangeIndex.from_range(range(1, 10), name="foo")
360365

361366
result = obj.difference(obj[::2])
362-
expected = Int64Index(obj[1::2]._values, name=obj.name)
367+
expected = obj[1::2]
363368
tm.assert_index_equal(result, expected, exact=True)
364369

370+
result = obj[::-1].difference(obj[::2], sort=False)
371+
tm.assert_index_equal(result, expected[::-1], exact=True)
372+
365373
result = obj.difference(obj[1::2])
366-
expected = Int64Index(obj[::2]._values, name=obj.name)
374+
expected = obj[::2]
375+
tm.assert_index_equal(result, expected, exact=True)
376+
377+
result = obj[::-1].difference(obj[1::2], sort=False)
378+
tm.assert_index_equal(result, expected[::-1], exact=True)
379+
380+
def test_difference_interior_non_preserving(self):
    """
    Differences whose result is not evenly spaced must fall back to
    Int64Index instead of being kept as a RangeIndex.
    """
    idx = Index(range(10))
    obj = Index(range(20))

    # (left operand, right operand, expected remaining values)
    cases = [
        # intersection of length 1 that sits in the interior of ``idx``
        (idx, idx[3:4], [0, 1, 2, 4, 5, 6, 7, 8, 9]),
        # other.step / self.step > 2
        (idx, idx[::3], [1, 2, 4, 5, 7, 8]),
        # step ratio of 2, but the overlap reaches only one end of left
        (obj, obj[:10:2], [1, 3, 5, 7, 9] + list(range(10, 20))),
        (obj, obj[1:11:2], [0, 2, 4, 6, 8, 10] + list(range(11, 20))),
    ]

    for left, other, remaining in cases:
        result = left.difference(other)
        expected = Int64Index(remaining)
        tm.assert_index_equal(result, expected, exact=True)
368406

369407
def test_symmetric_difference(self):
@@ -391,3 +429,44 @@ def test_symmetric_difference(self):
391429
result = left.symmetric_difference(right[1:])
392430
expected = Int64Index([1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14])
393431
tm.assert_index_equal(result, expected)
432+
433+
434+
def assert_range_or_not_is_rangelike(index):
    """
    Check that we either have a RangeIndex or that this index *cannot*
    be represented as a RangeIndex.

    Parameters
    ----------
    index : Index
        Result of a set operation to validate.

    Raises
    ------
    AssertionError
        If ``index`` is not a RangeIndex although its values are evenly
        spaced (including the trivial single-element case).
    """
    # RangeIndex results and empty results are always acceptable.
    if isinstance(index, RangeIndex) or len(index) == 0:
        return

    # A single element is trivially range-like.  Without this guard the
    # empty ``diff`` below would raise IndexError on ``diff[0]`` instead of
    # reporting a proper assertion failure.
    assert len(index) > 1, "length-1 index can be represented as a RangeIndex"

    # Constant spacing between neighbours means the values form a range.
    diff = index[:-1] - index[1:]
    assert not (diff == diff[0]).all()
442+
443+
444+
@given(*(st.integers(-20, 20) for _ in range(6)))
def test_range_difference(start1, stop1, step1, start2, stop2, step2):
    """
    Property test for RangeIndex.difference.

    For both ``sort=None`` and ``sort=False`` the result must
    a) match Int64Index.difference and
    b) be a RangeIndex whenever it is possible to return one.
    """
    # A zero step is not a valid range; discard those examples.
    assume(step1 != 0)
    assume(step2 != 0)

    left = RangeIndex(start1, stop1, step1)
    right = RangeIndex(start2, stop2, step2)

    for sort in (None, False):
        result = left.difference(right, sort=sort)
        assert_range_or_not_is_rangelike(result)

        alt = Int64Index(left).difference(Int64Index(right), sort=sort)
        tm.assert_index_equal(result, alt, exact="equiv")

0 commit comments

Comments
 (0)