Skip to content

Commit a3ec4e5

Browse files
committed
Add more numeric tests for NumIndex
1 parent 0539d8c commit a3ec4e5

File tree

4 files changed

+74
-20
lines changed

4 files changed

+74
-20
lines changed

pandas/core/indexes/base.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def __new__(
442442
return Index._simple_new(data, name=name)
443443

444444
# index-like
445-
elif type(data) is NumIndex and dtype is None:
445+
elif isinstance(data, NumIndex) and data._is_num_index() and dtype is None:
446446
return NumIndex(data, name=name, copy=copy)
447447
elif isinstance(data, (np.ndarray, Index, ABCSeries)):
448448

@@ -2407,6 +2407,26 @@ def is_all_dates(self) -> bool:
24072407
)
24082408
return self._is_all_dates
24092409

2410+
def _is_num_index(self) -> bool:
2411+
"""
2412+
Whether self is a NumIndex, but not *not* Int64Index, UInt64Index, FloatIndex.
2413+
2414+
Typically used to check if an operation should return NumIndex or plain Index.
2415+
"""
2416+
from pandas.core.indexes.numeric import (
2417+
Float64Index,
2418+
Int64Index,
2419+
NumIndex,
2420+
UInt64Index,
2421+
)
2422+
2423+
if not isinstance(self, NumIndex):
2424+
return False
2425+
elif isinstance(self, (Int64Index, UInt64Index, Float64Index)):
2426+
return False
2427+
else:
2428+
return True
2429+
24102430
# --------------------------------------------------------------------
24112431
# Pickle Methods
24122432

@@ -5488,8 +5508,8 @@ def map(self, mapper, na_action=None):
54885508
# empty
54895509
attributes["dtype"] = self.dtype
54905510

5491-
if type(self) is NumIndex:
5492-
return type(self)(new_values, **attributes)
5511+
if self._is_num_index() and issubclass(new_values.dtype.type, np.number):
5512+
return NumIndex(new_values, **attributes)
54935513

54945514
return Index(new_values, **attributes)
54955515

pandas/core/indexes/numeric.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -288,11 +288,11 @@ def astype(self, dtype, copy=True):
288288
# TODO(jreback); this can change once we have an EA Index type
289289
# GH 13149
290290
arr = astype_nansafe(self._values, dtype=dtype)
291-
if isinstance(self, Float64Index):
291+
if not self._is_num_index():
292292
return Int64Index(arr, name=self.name)
293293
else:
294294
return NumIndex(arr, name=self.name, dtype=dtype)
295-
elif is_categorical_dtype(dtype):
295+
if is_categorical_dtype(dtype):
296296
from pandas import CategoricalIndex
297297

298298
return CategoricalIndex(self, name=self.name, dtype=dtype, copy=copy)

pandas/tests/indexes/common.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pandas.core.dtypes.dtypes import CategoricalDtype
1313

1414
import pandas as pd
15-
from pandas import ( # noqa
15+
from pandas import (
1616
CategoricalIndex,
1717
DatetimeIndex,
1818
Float64Index,
@@ -29,6 +29,7 @@
2929
)
3030
import pandas._testing as tm
3131
from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin
32+
from pandas.core.indexes.numeric import NumIndex
3233

3334

3435
class Base:
@@ -343,12 +344,13 @@ def test_numpy_argsort(self, index):
343344
def test_repeat(self, simple_index):
344345
rep = 2
345346
idx = simple_index.copy()
346-
expected = Index(idx.values.repeat(rep), name=idx.name)
347+
new_index_cls = type(idx) if not isinstance(idx, RangeIndex) else Int64Index
348+
expected = new_index_cls(idx.values.repeat(rep), name=idx.name)
347349
tm.assert_index_equal(idx.repeat(rep), expected)
348350

349351
idx = simple_index
350352
rep = np.arange(len(idx))
351-
expected = Index(idx.values.repeat(rep), name=idx.name)
353+
expected = new_index_cls(idx.values.repeat(rep), name=idx.name)
352354
tm.assert_index_equal(idx.repeat(rep), expected)
353355

354356
def test_numpy_repeat(self, simple_index):
@@ -649,7 +651,12 @@ def test_map_dictlike(self, mapper, simple_index):
649651
tm.assert_index_equal(result, expected)
650652

651653
# empty mappable
652-
expected = Index([np.nan] * len(idx))
654+
if idx._is_num_index():
655+
new_index_cls = NumIndex
656+
else:
657+
new_index_cls = Float64Index
658+
659+
expected = new_index_cls([np.nan] * len(idx))
653660
result = idx.map(mapper(expected, idx))
654661
tm.assert_index_equal(result, expected)
655662

pandas/tests/indexes/numeric/test_numeric.py

+38-11
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,25 @@
1313
UInt64Index,
1414
)
1515
import pandas._testing as tm
16+
from pandas.core.indexes.numeric import NumIndex
1617
from pandas.tests.indexes.common import NumericBase
1718

1819

1920
class TestFloat64Index(NumericBase):
2021
_index_cls = Float64Index
2122
_dtype = np.float64
2223

23-
@pytest.fixture
24-
def simple_index(self) -> Index:
25-
values = np.arange(5, dtype=self._dtype)
26-
return self._index_cls(values)
24+
@pytest.fixture(
25+
params=[
26+
(Float64Index, None),
27+
(NumIndex, np.float64),
28+
(NumIndex, np.float32),
29+
],
30+
)
31+
def simple_index(self, request) -> Index:
32+
index_cls, dtype = request.param
33+
values = np.arange(5, dtype=dtype)
34+
return index_cls(values)
2735

2836
@pytest.fixture(
2937
params=[
@@ -392,9 +400,19 @@ class TestInt64Index(NumericInt):
392400
_index_cls = Int64Index
393401
_dtype = np.int64
394402

395-
@pytest.fixture
396-
def simple_index(self) -> Index:
397-
return self._index_cls(range(0, 20, 2), dtype=self._dtype)
403+
@pytest.fixture(
404+
params=[
405+
(Int64Index, None),
406+
(NumIndex, np.int64),
407+
(NumIndex, np.int32),
408+
(NumIndex, np.int16),
409+
(NumIndex, np.int8),
410+
],
411+
)
412+
def simple_index(self, request) -> Index:
413+
index_cls, dtype = request.param
414+
values = np.arange(5, dtype=dtype)
415+
return index_cls(values)
398416

399417
@pytest.fixture(
400418
params=[range(0, 20, 2), range(19, -1, -1)], ids=["index_inc", "index_dec"]
@@ -490,10 +508,19 @@ class TestUInt64Index(NumericInt):
490508
_index_cls = UInt64Index
491509
_dtype = np.uint64
492510

493-
@pytest.fixture
494-
def simple_index(self) -> Index:
495-
# compat with shared Int64/Float64 tests
496-
return self._index_cls(np.arange(5, dtype=self._dtype))
511+
@pytest.fixture(
512+
params=[
513+
(UInt64Index, None),
514+
(NumIndex, np.uint64),
515+
(NumIndex, np.uint32),
516+
(NumIndex, np.uint16),
517+
(NumIndex, np.uint8),
518+
],
519+
)
520+
def simple_index(self, request) -> Index:
521+
index_cls, dtype = request.param
522+
values = np.arange(5, dtype=dtype)
523+
return index_cls(values)
497524

498525
@pytest.fixture(
499526
params=[

0 commit comments

Comments
 (0)