@@ -2644,31 +2644,37 @@ def __post_init__(self):
2644
2644
assert self .array .shape [- 1 ] == self .group_idx .size
2645
2645
2646
2646
2647
def grouped_scan(
    inp: AlignedArrays, *, func: str, axis, fill_value=None, dtype=None, keepdims=None
) -> AlignedArrays:
    """Run the named scan ``func`` over ``inp.array``, segmented by ``inp.group_idx``.

    Parameters
    ----------
    inp:
        Pair of data array and aligned integer group labels.
    func:
        Name of the scan understood by ``generic_aggregate`` (e.g. a cumulative op).
    axis:
        Axis to scan along; must be the trailing axis of ``inp.array``.
    fill_value, dtype:
        Forwarded to ``generic_aggregate``.
    keepdims:
        Ignored; present only to satisfy the dask ``scan`` callback signature.

    Returns
    -------
    AlignedArrays
        Scanned values paired with the unchanged group labels.
    """
    # Scans are only implemented along the last axis.
    assert axis == inp.array.ndim - 1

    scanned = generic_aggregate(
        inp.group_idx,
        inp.array,
        engine="numpy",
        func=func,
        axis=axis,
        dtype=dtype,
        fill_value=fill_value,
    )
    # The scan preserves shape, so the original labels still align.
    return AlignedArrays(group_idx=inp.group_idx, array=scanned)
2653
2661
2654
2662
2655
def grouped_reduce(inp: AlignedArrays, *, agg: Scan, axis: int, keepdims=None) -> AlignedArrays:
    """Blockwise per-group reduction used as the pre-op of the Blelloch scan.

    Parameters
    ----------
    inp:
        Pair of data array and aligned integer group labels.
    agg:
        Scan spec; supplies the reduction name and the identity of its binary op.
    axis:
        Axis to reduce along; must be the trailing axis of ``inp.array``.
    keepdims:
        Ignored; present only to satisfy the dask callback signature.

    Returns
    -------
    AlignedArrays
        One reduced value per group, labelled with consecutive integer ids.
    """
    # Reductions are only implemented along the last axis.
    assert axis == inp.array.ndim - 1

    per_group = generic_aggregate(
        inp.group_idx,
        inp.array,
        engine="numpy",
        axis=axis,
        func=agg.reduction,
        dtype=inp.array.dtype,
        # Empty groups receive the identity so they are no-ops in the later binop.
        fill_value=agg.binary_op.identity,
    )
    # generic_aggregate emits one slot per group id, in order, so relabel 0..n-1.
    fresh_labels = np.arange(per_group.shape[-1])
    return AlignedArrays(group_idx=fresh_labels, array=per_group)
2669
2675
2670
2676
2671
- def grouped_binop (left : AlignedArrays , right : AlignedArrays , op : np . ufunc ) -> AlignedArrays :
2677
+ def grouped_binop (left : AlignedArrays , right : AlignedArrays , op : Callable ) -> AlignedArrays :
2672
2678
reindexed = reindex_ (
2673
2679
left .array ,
2674
2680
from_ = pd .Index (left .group_idx ),
@@ -2708,26 +2714,39 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan):
2708
2714
_zip , by , array , dtype = array .dtype , meta = array ._meta , name = "groupby-scan-preprocess"
2709
2715
)
2710
2716
2717
+ # TODO: move to aggregate_npg.py
2718
+ if agg .name in ["cumsum" , "nancumsum" ]:
2719
+ # https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html
2720
+ # it defaults to the dtype of a, unless a
2721
+ # has an integer dtype with a precision less than that of the default platform integer.
2722
+ if array .dtype .kind == "i" :
2723
+ agg .dtype = np .result_type (array .dtype , np .intp )
2724
+ elif array .dtype .kind == "u" :
2725
+ agg .dtype = np .result_type (array .dtype , np .uintp )
2726
+ else :
2727
+ agg .dtype = array .dtype
2728
+ else :
2729
+ agg .dtype = array .dtype
2730
+
2731
+ scan_ = partial (grouped_scan , func = agg .scan , fill_value = agg .identity )
2711
2732
# dask tokenizing error workaround
2712
- scan_ = partial (grouped_scan , func = agg .scan )
2713
2733
scan_ .__name__ = scan_ .func .__name__
2714
2734
2715
2735
# 2. Run the scan
2716
2736
accumulated = scan (
2717
2737
func = scan_ ,
2718
- binop = partial (grouped_binop , op = agg .ufunc ),
2719
- ident = agg .ufunc . identity ,
2738
+ binop = partial (grouped_binop , op = agg .binary_op ),
2739
+ ident = agg .identity ,
2720
2740
x = zipped ,
2721
2741
axis = axis ,
2722
2742
method = "blelloch" ,
2723
- preop = partial (grouped_reduce , func = agg . reduction , fill_value = agg . ufunc . identity ),
2724
- dtype = array .dtype ,
2743
+ preop = partial (grouped_reduce , agg = agg ),
2744
+ dtype = agg .dtype ,
2725
2745
)
2726
2746
2727
2747
# 3. Unzip and extract the final result array, discard groups
2728
- result = map_blocks (extract_array , accumulated , dtype = array .dtype )
2748
+ result = map_blocks (extract_array , accumulated , dtype = agg .dtype )
2729
2749
2730
- assert result .dtype == array .dtype
2731
2750
assert result .chunks == array .chunks
2732
2751
2733
2752
return result
0 commit comments