@@ -354,47 +354,32 @@ def check_accumulate(self, ser, op_name, skipna):
354
354
expected = getattr (ser .astype ("Float64" ), op_name )(skipna = skipna )
355
355
tm .assert_series_equal (result , expected , check_dtype = False )
356
356
357
- @pytest .mark .parametrize ("skipna" , [True , False ])
358
- def test_accumulate_series_raises (self , data , all_numeric_accumulations , skipna ):
359
- pa_type = data .dtype .pyarrow_dtype
360
- if (
361
- (
362
- pa .types .is_integer (pa_type )
363
- or pa .types .is_floating (pa_type )
364
- or pa .types .is_duration (pa_type )
365
- )
366
- and all_numeric_accumulations == "cumsum"
367
- and not pa_version_under9p0
368
- ):
369
- pytest .skip ("These work, are tested by test_accumulate_series." )
357
+ def _supports_accumulation (self , ser : pd .Series , op_name : str ) -> bool :
358
+ # error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no
359
+ # attribute "pyarrow_dtype"
360
+ pa_type = ser .dtype .pyarrow_dtype # type: ignore[union-attr]
370
361
371
- op_name = all_numeric_accumulations
372
- ser = pd .Series (data )
373
-
374
- with pytest .raises (NotImplementedError ):
375
- getattr (ser , op_name )(skipna = skipna )
376
-
377
- @pytest .mark .parametrize ("skipna" , [True , False ])
378
- def test_accumulate_series (self , data , all_numeric_accumulations , skipna , request ):
379
- pa_type = data .dtype .pyarrow_dtype
380
- op_name = all_numeric_accumulations
381
- ser = pd .Series (data )
382
-
383
- do_skip = False
384
362
if pa .types .is_string (pa_type ) or pa .types .is_binary (pa_type ):
385
363
if op_name in ["cumsum" , "cumprod" ]:
386
- do_skip = True
364
+ return False
387
365
elif pa .types .is_temporal (pa_type ) and not pa .types .is_duration (pa_type ):
388
366
if op_name in ["cumsum" , "cumprod" ]:
389
- do_skip = True
367
+ return False
390
368
elif pa .types .is_duration (pa_type ):
391
369
if op_name == "cumprod" :
392
- do_skip = True
370
+ return False
371
+ return True
393
372
394
- if do_skip :
395
- pytest .skip (
396
- f"{ op_name } should *not* work, we test in "
397
- "test_accumulate_series_raises that these correctly raise."
373
+ @pytest .mark .parametrize ("skipna" , [True , False ])
374
+ def test_accumulate_series (self , data , all_numeric_accumulations , skipna , request ):
375
+ pa_type = data .dtype .pyarrow_dtype
376
+ op_name = all_numeric_accumulations
377
+ ser = pd .Series (data )
378
+
379
+ if not self ._supports_accumulation (ser , op_name ):
380
+ # The base class test will check that we raise
381
+ return super ().test_accumulate_series (
382
+ data , all_numeric_accumulations , skipna
398
383
)
399
384
400
385
if all_numeric_accumulations != "cumsum" or pa_version_under9p0 :
0 commit comments