|
3 | 3 | timedelta,
|
4 | 4 | )
|
5 | 5 |
|
| 6 | +from hypothesis import ( |
| 7 | + assume, |
| 8 | + given, |
| 9 | + strategies as st, |
| 10 | +) |
6 | 11 | import numpy as np
|
7 | 12 | import pytest
|
8 | 13 |
|
@@ -359,11 +364,44 @@ def test_difference_mismatched_step(self):
|
359 | 364 | obj = RangeIndex.from_range(range(1, 10), name="foo")
|
360 | 365 |
|
361 | 366 | result = obj.difference(obj[::2])
|
362 |
| - expected = Int64Index(obj[1::2]._values, name=obj.name) |
| 367 | + expected = obj[1::2] |
363 | 368 | tm.assert_index_equal(result, expected, exact=True)
|
364 | 369 |
|
| 370 | + result = obj[::-1].difference(obj[::2], sort=False) |
| 371 | + tm.assert_index_equal(result, expected[::-1], exact=True) |
| 372 | + |
365 | 373 | result = obj.difference(obj[1::2])
|
366 |
| - expected = Int64Index(obj[::2]._values, name=obj.name) |
| 374 | + expected = obj[::2] |
| 375 | + tm.assert_index_equal(result, expected, exact=True) |
| 376 | + |
| 377 | + result = obj[::-1].difference(obj[1::2], sort=False) |
| 378 | + tm.assert_index_equal(result, expected[::-1], exact=True) |
| 379 | + |
| 380 | + def test_difference_interior_non_preserving(self): |
| 381 | + # case with intersection of length 1 but RangeIndex is not preserved |
| 382 | + idx = Index(range(10)) |
| 383 | + |
| 384 | + other = idx[3:4] |
| 385 | + result = idx.difference(other) |
| 386 | + expected = Int64Index([0, 1, 2, 4, 5, 6, 7, 8, 9]) |
| 387 | + tm.assert_index_equal(result, expected, exact=True) |
| 388 | + |
| 389 | + # case with other.step / self.step > 2 |
| 390 | + other = idx[::3] |
| 391 | + result = idx.difference(other) |
| 392 | + expected = Int64Index([1, 2, 4, 5, 7, 8]) |
| 393 | + tm.assert_index_equal(result, expected, exact=True) |
| 394 | + |
| 395 | + # cases with only reaching one end of left |
| 396 | + obj = Index(range(20)) |
| 397 | + other = obj[:10:2] |
| 398 | + result = obj.difference(other) |
| 399 | + expected = Int64Index([1, 3, 5, 7, 9] + list(range(10, 20))) |
| 400 | + tm.assert_index_equal(result, expected, exact=True) |
| 401 | + |
| 402 | + other = obj[1:11:2] |
| 403 | + result = obj.difference(other) |
| 404 | + expected = Int64Index([0, 2, 4, 6, 8, 10] + list(range(11, 20))) |
367 | 405 | tm.assert_index_equal(result, expected, exact=True)
|
368 | 406 |
|
369 | 407 | def test_symmetric_difference(self):
|
@@ -391,3 +429,44 @@ def test_symmetric_difference(self):
|
391 | 429 | result = left.symmetric_difference(right[1:])
|
392 | 430 | expected = Int64Index([1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14])
|
393 | 431 | tm.assert_index_equal(result, expected)
|
| 432 | + |
| 433 | + |
| 434 | +def assert_range_or_not_is_rangelike(index): |
| 435 | + """ |
| 436 | + Check that we either have a RangeIndex or that this index *cannot* |
| 437 | + be represented as a RangeIndex. |
| 438 | + """ |
| 439 | + if not isinstance(index, RangeIndex) and len(index) > 0: |
| 440 | + diff = index[:-1] - index[1:] |
| 441 | + assert not (diff == diff[0]).all() |
| 442 | + |
| 443 | + |
| 444 | +@given( |
| 445 | + st.integers(-20, 20), |
| 446 | + st.integers(-20, 20), |
| 447 | + st.integers(-20, 20), |
| 448 | + st.integers(-20, 20), |
| 449 | + st.integers(-20, 20), |
| 450 | + st.integers(-20, 20), |
| 451 | +) |
| 452 | +def test_range_difference(start1, stop1, step1, start2, stop2, step2): |
| 453 | + # test that |
| 454 | + # a) we match Int64Index.difference and |
| 455 | + # b) we return RangeIndex whenever it is possible to do so. |
| 456 | + assume(step1 != 0) |
| 457 | + assume(step2 != 0) |
| 458 | + |
| 459 | + left = RangeIndex(start1, stop1, step1) |
| 460 | + right = RangeIndex(start2, stop2, step2) |
| 461 | + |
| 462 | + result = left.difference(right, sort=None) |
| 463 | + assert_range_or_not_is_rangelike(result) |
| 464 | + |
| 465 | + alt = Int64Index(left).difference(Int64Index(right), sort=None) |
| 466 | + tm.assert_index_equal(result, alt, exact="equiv") |
| 467 | + |
| 468 | + result = left.difference(right, sort=False) |
| 469 | + assert_range_or_not_is_rangelike(result) |
| 470 | + |
| 471 | + alt = Int64Index(left).difference(Int64Index(right), sort=False) |
| 472 | + tm.assert_index_equal(result, alt, exact="equiv") |
0 commit comments