Skip to content

Commit 3fe6149

Browse files
authored
REF: remove test_accumulate_series_raises (#54367)
* REF: remove test_accumulate_series_raises * mypy fixup
1 parent 263828c commit 3fe6149

File tree

6 files changed

+41
-58
lines changed

6 files changed

+41
-58
lines changed

pandas/tests/extension/base/accumulate.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,32 @@ class BaseAccumulateTests(BaseExtensionTests):
1111
make sense for numeric/boolean operations.
1212
"""
1313

14-
def check_accumulate(self, s, op_name, skipna):
15-
result = getattr(s, op_name)(skipna=skipna)
14+
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
15+
# Do we expect this accumulation to be supported for this dtype?
16+
# We default to assuming "no"; subclass authors should override here.
17+
return False
18+
19+
def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
20+
alt = ser.astype("float64")
21+
result = getattr(ser, op_name)(skipna=skipna)
1622

1723
if result.dtype == pd.Float32Dtype() and op_name == "cumprod" and skipna:
24+
# TODO: avoid special-casing here
1825
pytest.skip(
1926
f"Float32 precision lead to large differences with op {op_name} "
2027
f"and skipna={skipna}"
2128
)
2229

23-
expected = getattr(s.astype("float64"), op_name)(skipna=skipna)
30+
expected = getattr(alt, op_name)(skipna=skipna)
2431
tm.assert_series_equal(result, expected, check_dtype=False)
2532

26-
@pytest.mark.parametrize("skipna", [True, False])
27-
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
28-
op_name = all_numeric_accumulations
29-
ser = pd.Series(data)
30-
31-
with pytest.raises(NotImplementedError):
32-
getattr(ser, op_name)(skipna=skipna)
33-
3433
@pytest.mark.parametrize("skipna", [True, False])
3534
def test_accumulate_series(self, data, all_numeric_accumulations, skipna):
3635
op_name = all_numeric_accumulations
3736
ser = pd.Series(data)
38-
self.check_accumulate(ser, op_name, skipna)
37+
38+
if self._supports_accumulation(ser, op_name):
39+
self.check_accumulate(ser, op_name, skipna)
40+
else:
41+
with pytest.raises(NotImplementedError):
42+
getattr(ser, op_name)(skipna=skipna)

pandas/tests/extension/test_arrow.py

+18-33
Original file line numberDiff line numberDiff line change
@@ -354,47 +354,32 @@ def check_accumulate(self, ser, op_name, skipna):
354354
expected = getattr(ser.astype("Float64"), op_name)(skipna=skipna)
355355
tm.assert_series_equal(result, expected, check_dtype=False)
356356

357-
@pytest.mark.parametrize("skipna", [True, False])
358-
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
359-
pa_type = data.dtype.pyarrow_dtype
360-
if (
361-
(
362-
pa.types.is_integer(pa_type)
363-
or pa.types.is_floating(pa_type)
364-
or pa.types.is_duration(pa_type)
365-
)
366-
and all_numeric_accumulations == "cumsum"
367-
and not pa_version_under9p0
368-
):
369-
pytest.skip("These work, are tested by test_accumulate_series.")
357+
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
358+
# error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no
359+
# attribute "pyarrow_dtype"
360+
pa_type = ser.dtype.pyarrow_dtype # type: ignore[union-attr]
370361

371-
op_name = all_numeric_accumulations
372-
ser = pd.Series(data)
373-
374-
with pytest.raises(NotImplementedError):
375-
getattr(ser, op_name)(skipna=skipna)
376-
377-
@pytest.mark.parametrize("skipna", [True, False])
378-
def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request):
379-
pa_type = data.dtype.pyarrow_dtype
380-
op_name = all_numeric_accumulations
381-
ser = pd.Series(data)
382-
383-
do_skip = False
384362
if pa.types.is_string(pa_type) or pa.types.is_binary(pa_type):
385363
if op_name in ["cumsum", "cumprod"]:
386-
do_skip = True
364+
return False
387365
elif pa.types.is_temporal(pa_type) and not pa.types.is_duration(pa_type):
388366
if op_name in ["cumsum", "cumprod"]:
389-
do_skip = True
367+
return False
390368
elif pa.types.is_duration(pa_type):
391369
if op_name == "cumprod":
392-
do_skip = True
370+
return False
371+
return True
393372

394-
if do_skip:
395-
pytest.skip(
396-
f"{op_name} should *not* work, we test in "
397-
"test_accumulate_series_raises that these correctly raise."
373+
@pytest.mark.parametrize("skipna", [True, False])
374+
def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request):
375+
pa_type = data.dtype.pyarrow_dtype
376+
op_name = all_numeric_accumulations
377+
ser = pd.Series(data)
378+
379+
if not self._supports_accumulation(ser, op_name):
380+
# The base class test will check that we raise
381+
return super().test_accumulate_series(
382+
data, all_numeric_accumulations, skipna
398383
)
399384

400385
if all_numeric_accumulations != "cumsum" or pa_version_under9p0:

pandas/tests/extension/test_boolean.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,9 @@ class TestUnaryOps(base.BaseUnaryOpsTests):
274274

275275

276276
class TestAccumulation(base.BaseAccumulateTests):
277+
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
278+
return True
279+
277280
def check_accumulate(self, s, op_name, skipna):
278281
length = 64
279282
if not IS64 or is_platform_windows():
@@ -288,10 +291,6 @@ def check_accumulate(self, s, op_name, skipna):
288291
expected = expected.astype("boolean")
289292
tm.assert_series_equal(result, expected)
290293

291-
@pytest.mark.parametrize("skipna", [True, False])
292-
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
293-
pass
294-
295294

296295
class TestParsing(base.BaseParsingTests):
297296
pass

pandas/tests/extension/test_categorical.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,7 @@ class TestReduce(base.BaseNoReduceTests):
157157

158158

159159
class TestAccumulate(base.BaseAccumulateTests):
160-
@pytest.mark.parametrize("skipna", [True, False])
161-
def test_accumulate_series(self, data, all_numeric_accumulations, skipna):
162-
pass
160+
pass
163161

164162

165163
class TestMethods(base.BaseMethodsTests):

pandas/tests/extension/test_masked_numeric.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,8 @@ class TestBooleanReduce(base.BaseBooleanReduceTests):
290290

291291

292292
class TestAccumulation(base.BaseAccumulateTests):
293-
@pytest.mark.parametrize("skipna", [True, False])
294-
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
295-
pass
293+
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
294+
return True
296295

297296
def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
298297
# overwrite to ensure pd.NA is tested instead of np.nan

pandas/tests/extension/test_sparse.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,4 @@ def test_EA_types(self, engine, data):
476476

477477

478478
class TestNoNumericAccumulations(base.BaseAccumulateTests):
479-
@pytest.mark.parametrize("skipna", [True, False])
480-
def test_accumulate_series(self, data, all_numeric_accumulations, skipna):
481-
pass
479+
pass

0 commit comments

Comments
 (0)