68
68
from numpy .core .numeric import normalize_axis_tuple # type: ignore[no-redef]
69
69
70
70
HAS_NUMBAGG = module_available ("numbagg" , minversion = "0.3.0" )
71
+ HAS_SPARSE = module_available ("sparse" )
71
72
72
73
if TYPE_CHECKING :
73
74
try :
@@ -255,6 +256,12 @@ def _is_bool_supported_reduction(func: T_Agg) -> bool:
255
256
)
256
257
257
258
259
def _is_sparse_supported_reduction(func: T_Agg) -> bool:
    """Report whether ``func`` can be used when reindexing to a sparse array.

    ``func`` may be an aggregation name or an ``Aggregation`` instance, in
    which case its ``name`` attribute is inspected. When the ``sparse``
    package is not available (``HAS_SPARSE`` is false) every reduction is
    reported as supported.
    """
    name = func.name if isinstance(func, Aggregation) else func
    if not HAS_SPARSE:
        return True
    # Substring match: any name containing one of these tokens is unsupported.
    unsupported = ("first", "last", "prod", "var", "std")
    return not any(token in name for token in unsupported)
263
+
264
+
258
265
def _get_expected_groups (by : T_By , sort : bool ) -> T_ExpectIndex :
259
266
if is_duck_dask_array (by ):
260
267
raise ValueError ("Please provide expected_groups if not grouping by a numpy array." )
@@ -736,12 +743,12 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
736
743
return array .rechunk ({axis : newchunks })
737
744
738
745
739
- def reindex_numpy (array , from_ , to , fill_value , dtype , axis ):
746
+ def reindex_numpy (array , from_ : pd . Index , to : pd . Index , fill_value , dtype , axis : int ):
740
747
idx = from_ .get_indexer (to )
741
748
indexer = [slice (None , None )] * array .ndim
742
749
indexer [axis ] = idx
743
750
reindexed = array [tuple (indexer )]
744
- if any (idx == - 1 ):
751
+ if (idx == - 1 ). any ( ):
745
752
if fill_value is None :
746
753
raise ValueError ("Filling is required. fill_value cannot be None." )
747
754
indexer [axis ] = idx == - 1
@@ -750,25 +757,43 @@ def reindex_numpy(array, from_, to, fill_value, dtype, axis):
750
757
return reindexed
751
758
752
759
753
def reindex_pydata_sparse_coo(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int):
    """Reindex the last axis of ``array`` from ``from_`` to ``to`` as a ``sparse.COO``.

    Parameters
    ----------
    array
        Either a ``sparse.COO`` array or an array supporting ``...``/mask
        indexing and ``reshape`` (the non-COO branch treats every element as
        stored data).
    from_ : pd.Index
        Labels currently along the last axis of ``array``.
    to : pd.Index
        Labels wanted along the last axis of the result.
    fill_value
        Value for positions of ``to`` that are absent from ``from_``. May be
        ``None`` only when every label of ``to`` is present in ``from_``.
    dtype
        dtype the stored data is cast to.
    axis : int
        Must be ``-1``; only the last axis is supported.

    Returns
    -------
    sparse.COO
        Array of shape ``(*array.shape[:-1], to.size)``.

    Raises
    ------
    ValueError
        If some label of ``to`` is missing from ``from_`` and ``fill_value``
        is ``None``.
    """
    import sparse

    assert axis == -1

    # Filling is needed exactly when a *target* label cannot be found in the
    # source labels. This is the same check ``reindex_numpy`` performs.
    # (Labels of ``from_`` that are absent from ``to`` are simply dropped and
    # never require a fill value.)
    needs_reindex = (from_.get_indexer(to) == -1).any()
    if needs_reindex and fill_value is None:
        raise ValueError("Filling is required. fill_value cannot be None.")

    # idx[i] is the position of from_[i] in ``to``, or -1 if it is dropped.
    idx = to.get_indexer(from_)
    mask = idx != -1  # indices along last axis to keep
    if mask.all():
        # Nothing is dropped: a full slice avoids boolean-mask indexing.
        mask = slice(None)
    shape = array.shape

    if isinstance(array, sparse.COO):
        subset = array[..., mask]
        data = subset.data
        # Copy before relabelling so the caller's array is never modified in
        # place (NOTE(review): indexing may hand back shared storage when
        # nothing is dropped -- confirm against the sparse indexing code).
        coords = subset.coords.copy()
        if subset.nnz > 0:
            # Map positions within the kept subset to positions in ``to``.
            coords[-1, :] = idx[mask][coords[-1, :]]
        if fill_value is None:
            # no reindexing is actually needed (dense case)
            # preserve the fill_value
            fill_value = array.fill_value
    else:
        # Every kept element becomes a stored entry: build the full
        # coordinate grid with the last axis relabelled to positions in ``to``.
        ranges = np.broadcast_arrays(
            *np.ix_(*(tuple(np.arange(size) for size in shape[:axis]) + (idx[mask],)))
        )
        coords = np.stack(ranges, axis=0).reshape(array.ndim, -1)
        data = array[..., mask].reshape(-1)

    reindexed = sparse.COO(
        coords=coords,
        data=data.astype(dtype, copy=False),
        shape=(*array.shape[:axis], to.size),
        fill_value=fill_value,
    )

    return reindexed
@@ -795,7 +820,11 @@ def reindex_(
795
820
796
821
if array .shape [axis ] == 0 :
797
822
# all groups were NaN
798
- reindexed = np .full (array .shape [:- 1 ] + (len (to ),), fill_value , dtype = array .dtype )
823
+ shape = array .shape [:- 1 ] + (len (to ),)
824
+ if array_type in (ReindexArrayType .AUTO , ReindexArrayType .NUMPY ):
825
+ reindexed = np .full (shape , fill_value , dtype = array .dtype )
826
+ else :
827
+ raise NotImplementedError
799
828
return reindexed
800
829
801
830
from_ = pd .Index (from_ )
@@ -1044,7 +1073,7 @@ def chunk_argreduce(
1044
1073
sort = sort ,
1045
1074
user_dtype = user_dtype ,
1046
1075
)
1047
- if not isnull (results ["groups" ]). all ( ):
1076
+ if not all ( isnull (results ["groups" ])):
1048
1077
idx = np .broadcast_to (idx , array .shape )
1049
1078
1050
1079
# array, by get flattened to 1D before passing to npg
@@ -1288,7 +1317,7 @@ def _finalize_results(
1288
1317
fill_value = agg .fill_value ["user" ]
1289
1318
if min_count > 0 :
1290
1319
count_mask = counts < min_count
1291
- if count_mask .any ():
1320
+ if count_mask .any () or reindex . array_type is ReindexArrayType . SPARSE_COO :
1292
1321
# For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
1293
1322
# necessary
1294
1323
if fill_value is None :
@@ -2815,6 +2844,12 @@ def groupby_reduce(
2815
2844
array .dtype ,
2816
2845
)
2817
2846
2847
+ if reindex .array_type is ReindexArrayType .SPARSE_COO and not _is_sparse_supported_reduction (func ):
2848
+ raise NotImplementedError (
2849
+ f"Aggregation { func = !r} is not supported when reindexing to a sparse array. "
2850
+ "Please raise an issue"
2851
+ )
2852
+
2818
2853
if TYPE_CHECKING :
2819
2854
assert isinstance (reindex , ReindexStrategy )
2820
2855
assert method is not None
0 commit comments