Skip to content

Commit 8c4ef78

Browse files
authored
annotate (#37917)
1 parent e0547d1 commit 8c4ef78

File tree

5 files changed

+16
-11
lines changed

5 files changed

+16
-11
lines changed

pandas/core/indexes/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2734,14 +2734,16 @@ def _union(self, other, sort):
27342734
stacklevel=3,
27352735
)
27362736

2737-
return self._shallow_copy(result)
2737+
return result
27382738

27392739
@final
27402740
def _wrap_setop_result(self, other, result):
27412741
if isinstance(self, (ABCDatetimeIndex, ABCTimedeltaIndex)) and isinstance(
27422742
result, np.ndarray
27432743
):
27442744
result = type(self._data)._simple_new(result, dtype=self.dtype)
2745+
elif is_categorical_dtype(self.dtype) and isinstance(result, np.ndarray):
2746+
result = Categorical(result, dtype=self.dtype)
27452747

27462748
name = get_op_result_name(self, other)
27472749
if isinstance(result, Index):

pandas/core/indexes/category.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, List
1+
from typing import Any, List, Optional
22
import warnings
33

44
import numpy as np
@@ -227,10 +227,15 @@ def _simple_new(cls, values: Categorical, name: Label = None):
227227
# --------------------------------------------------------------------
228228

229229
@doc(Index._shallow_copy)
230-
def _shallow_copy(self, values=None, name: Label = no_default):
230+
def _shallow_copy(
231+
self, values: Optional[Categorical] = None, name: Label = no_default
232+
):
231233
name = self.name if name is no_default else name
232234

233235
if values is not None:
236+
# In tests we only get here with Categorical objects that
237+
# have matching .ordered, and values.categories a subset of
238+
# our own. However we do _not_ have a dtype match in general.
234239
values = Categorical(values, dtype=self.dtype)
235240

236241
return super()._shallow_copy(values=values, name=name)

pandas/core/indexes/datetimelike.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -765,11 +765,7 @@ def intersection(self, other, sort=False):
765765
start = right[0]
766766

767767
if end < start:
768-
# pandas\core\indexes\datetimelike.py:758: error: Unexpected
769-
# keyword argument "freq" for "DatetimeTimedeltaMixin" [call-arg]
770-
result = type(self)(
771-
data=[], dtype=self.dtype, freq=self.freq # type: ignore[call-arg]
772-
)
768+
result = self[:0]
773769
else:
774770
lslice = slice(*left.slice_locs(start, end))
775771
left_chunk = left._values[lslice]

pandas/core/indexes/range.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from pandas.core.construction import extract_array
3030
import pandas.core.indexes.base as ibase
3131
from pandas.core.indexes.base import _index_shared_docs, maybe_extract_name
32-
from pandas.core.indexes.numeric import Int64Index
32+
from pandas.core.indexes.numeric import Float64Index, Int64Index
3333
from pandas.core.ops.common import unpack_zerodim_and_defer
3434

3535
_empty_range = range(0)
@@ -397,6 +397,8 @@ def _shallow_copy(self, values=None, name: Label = no_default):
397397
name = self.name if name is no_default else name
398398

399399
if values is not None:
400+
if values.dtype.kind == "f":
401+
return Float64Index(values, name=name)
400402
return Int64Index._simple_new(values, name=name)
401403

402404
result = self._simple_new(self._range, name=name)

pandas/tests/indexes/base_class/test_setops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ def test_setops_preserve_object_dtype(self):
3131

3232
result = idx._union(idx[1:], sort=None)
3333
expected = idx
34-
tm.assert_index_equal(result, expected)
34+
tm.assert_numpy_array_equal(result, expected.values)
3535

3636
result = idx.union(idx[1:], sort=None)
3737
tm.assert_index_equal(result, expected)
3838

3939
# if other is not monotonic increasing, _union goes through
4040
# a different route
4141
result = idx._union(idx[1:][::-1], sort=None)
42-
tm.assert_index_equal(result, expected)
42+
tm.assert_numpy_array_equal(result, expected.values)
4343

4444
result = idx.union(idx[1:][::-1], sort=None)
4545
tm.assert_index_equal(result, expected)

0 commit comments

Comments
 (0)