Skip to content

ENH: cast instead of raise for IntervalIndex setops with different closed #39267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ Interval
^^^^^^^^
- Bug in :meth:`IntervalIndex.intersection` and :meth:`IntervalIndex.symmetric_difference` always returning object-dtype when operating with :class:`CategoricalIndex` (:issue:`38653`, :issue:`38741`)
- Bug in :meth:`IntervalIndex.intersection` returning duplicates when at least one of the Indexes has duplicates which are present in the other (:issue:`38743`)
-
- :meth:`IntervalIndex.union`, :meth:`IntervalIndex.intersection`, :meth:`IntervalIndex.difference`, and :meth:`IntervalIndex.symmetric_difference` now cast to the appropriate dtype instead of raising ``TypeError`` when operating with another :class:`IntervalIndex` with incompatible dtype (:issue:`39267`)

Indexing
^^^^^^^^
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3003,6 +3003,7 @@ def _intersection(self, other, sort=False):

return result

@final
def difference(self, other, sort=None):
"""
Return a new Index with elements of index not in `other`.
Expand Down Expand Up @@ -3127,7 +3128,7 @@ def symmetric_difference(self, other, result_name=None, sort=None):
result_name = result_name_update

if not self._should_compare(other):
return self.union(other).rename(result_name)
return self.union(other, sort=sort).rename(result_name)
elif not is_dtype_equal(self.dtype, other.dtype):
dtype = find_common_type([self.dtype, other.dtype])
this = self.astype(dtype, copy=False)
Expand Down Expand Up @@ -6236,7 +6237,7 @@ def _maybe_cast_data_without_dtype(subarr):
try:
data = IntervalArray._from_sequence(subarr, copy=False)
return data
except ValueError:
except (ValueError, TypeError):
# GH27172: mixed closed Intervals --> object dtype
pass
elif inferred == "boolean":
Expand Down
5 changes: 2 additions & 3 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,9 +660,8 @@ def is_type_compatible(self, kind: str) -> bool:
# --------------------------------------------------------------------
# Set Operation Methods

@Appender(Index.difference.__doc__)
def difference(self, other, sort=None):
new_idx = super().difference(other, sort=sort)._with_freq(None)
def _difference(self, other, sort=None):
new_idx = super()._difference(other, sort=sort)._with_freq(None)
return new_idx

def _intersection(self, other: Index, sort=False) -> Index:
Expand Down
29 changes: 11 additions & 18 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,17 @@ def setop_check(method):
def wrapped(self, other, sort=False):
self._validate_sort_keyword(sort)
self._assert_can_do_setop(other)
other, _ = self._convert_can_do_setop(other)
other, result_name = self._convert_can_do_setop(other)

if not isinstance(other, IntervalIndex):
result = getattr(self.astype(object), op_name)(other)
if op_name in ("difference",):
result = result.astype(self.dtype)
return result
if op_name == "difference":
if not isinstance(other, IntervalIndex):
result = getattr(self.astype(object), op_name)(other, sort=sort)
return result.astype(self.dtype)

elif not self._should_compare(other):
# GH#19016: ensure set op will not return a prohibited dtype
result = getattr(self.astype(object), op_name)(other, sort=sort)
return result.astype(self.dtype)

return method(self, other, sort)

Expand Down Expand Up @@ -912,17 +916,6 @@ def _format_space(self) -> str:
# --------------------------------------------------------------------
# Set Operations

def _assert_can_do_setop(self, other):
super()._assert_can_do_setop(other)

if isinstance(other, IntervalIndex) and not self._should_compare(other):
# GH#19016: ensure set op will not return a prohibited dtype
raise TypeError(
"can only do set operations between two IntervalIndex "
"objects that are closed on the same side "
"and have compatible dtypes"
)

def _intersection(self, other, sort):
"""
intersection specialized to the case with matching dtypes.
Expand Down Expand Up @@ -1014,7 +1007,7 @@ def func(self, other, sort=sort):
return setop_check(func)

_union = _setop("union")
difference = _setop("difference")
_difference = _setop("difference")

# --------------------------------------------------------------------

Expand Down
10 changes: 5 additions & 5 deletions pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,14 +632,14 @@ def _union(self, other, sort):
return type(self)(start_r, end_r + step_o, step_o)
return self._int64index._union(other, sort=sort)

def difference(self, other, sort=None):
def _difference(self, other, sort=None):
# optimized set operation if we have another RangeIndex
self._validate_sort_keyword(sort)
self._assert_can_do_setop(other)
other, result_name = self._convert_can_do_setop(other)

if not isinstance(other, RangeIndex):
return super().difference(other, sort=sort)
return super()._difference(other, sort=sort)

res_name = ops.get_op_result_name(self, other)

Expand All @@ -654,11 +654,11 @@ def difference(self, other, sort=None):
return self[:0].rename(res_name)
if not isinstance(overlap, RangeIndex):
# We won't end up with RangeIndex, so fall back
return super().difference(other, sort=sort)
return super()._difference(other, sort=sort)
if overlap.step != first.step:
# In some cases we might be able to get a RangeIndex back,
# but not worth the effort.
return super().difference(other, sort=sort)
return super()._difference(other, sort=sort)

if overlap[0] == first.start:
# The difference is everything after the intersection
Expand All @@ -668,7 +668,7 @@ def difference(self, other, sort=None):
new_rng = range(first.start, overlap[0], first.step)
else:
# The difference is not range-like
return super().difference(other, sort=sort)
return super()._difference(other, sort=sort)

new_index = type(self)._simple_new(new_rng, name=res_name)
if first is not self._range:
Expand Down
26 changes: 12 additions & 14 deletions pandas/tests/indexes/interval/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,21 +178,19 @@ def test_set_incompatible_types(self, closed, op_name, sort):
result = set_op(Index([1, 2, 3]), sort=sort)
tm.assert_index_equal(result, expected)

# mixed closed
msg = (
"can only do set operations between two IntervalIndex objects "
"that are closed on the same side and have compatible dtypes"
)
# mixed closed -> cast to object
for other_closed in {"right", "left", "both", "neither"} - {closed}:
other = monotonic_index(0, 11, closed=other_closed)
with pytest.raises(TypeError, match=msg):
set_op(other, sort=sort)
expected = getattr(index.astype(object), op_name)(other, sort=sort)
if op_name == "difference":
expected = index
result = set_op(other, sort=sort)
tm.assert_index_equal(result, expected)

# GH 19016: incompatible dtypes
# GH 19016: incompatible dtypes -> cast to object
other = interval_range(Timestamp("20180101"), periods=9, closed=closed)
msg = (
"can only do set operations between two IntervalIndex objects "
"that are closed on the same side and have compatible dtypes"
)
with pytest.raises(TypeError, match=msg):
set_op(other, sort=sort)
expected = getattr(index.astype(object), op_name)(other, sort=sort)
if op_name == "difference":
expected = index
result = set_op(other, sort=sort)
tm.assert_index_equal(result, expected)