68
68
from numpy .core .numeric import normalize_axis_tuple # type: ignore[no-redef]
69
69
70
70
HAS_NUMBAGG = module_available ("numbagg" , minversion = "0.3.0" )
71
+ HAS_SPARSE = module_available ("sparse" )
71
72
72
73
if TYPE_CHECKING :
73
74
try :
@@ -255,6 +256,12 @@ def _is_bool_supported_reduction(func: T_Agg) -> bool:
255
256
)
256
257
257
258
259
+ def _is_sparse_supported_reduction (func : T_Agg ) -> bool :
260
+ if isinstance (func , Aggregation ):
261
+ func = func .name
262
+ return not HAS_SPARSE or all (f not in func for f in ["first" , "last" , "prod" , "var" , "std" ])
263
+
264
+
258
265
def _get_expected_groups (by : T_By , sort : bool ) -> T_ExpectIndex :
259
266
if is_duck_dask_array (by ):
260
267
raise ValueError ("Please provide expected_groups if not grouping by a numpy array." )
@@ -741,7 +748,7 @@ def reindex_numpy(array, from_, to, fill_value, dtype, axis):
741
748
indexer = [slice (None , None )] * array .ndim
742
749
indexer [axis ] = idx
743
750
reindexed = array [tuple (indexer )]
744
- if any (idx == - 1 ):
751
+ if (idx == - 1 ). any ( ):
745
752
if fill_value is None :
746
753
raise ValueError ("Filling is required. fill_value cannot be None." )
747
754
indexer [axis ] = idx == - 1
@@ -755,20 +762,36 @@ def reindex_pydata_sparse_coo(array, from_, to, fill_value, dtype, axis):
755
762
756
763
assert axis == - 1
757
764
758
- if fill_value is None :
759
- raise ValueError ("Filling is required for sparse arrays. fill_value cannot be None." )
765
+ needs_reindex = (from_ .get_indexer (to ) == - 1 ).any ()
766
+ if needs_reindex and fill_value is None :
767
+ raise ValueError ("Filling is required. fill_value cannot be None." )
768
+
760
769
idx = to .get_indexer (from_ )
761
- mask = idx != - 1
770
+ mask = idx != - 1 # indices along last axis to keep
762
771
shape = array .shape
763
- ranges = np .broadcast_arrays (* np .ix_ (* (tuple (np .arange (size ) for size in shape [:axis ]) + (idx [mask ],))))
764
- coords = np .stack (ranges , axis = 0 ).reshape (array .ndim , - 1 )
765
772
766
- data = array [..., mask ].data if isinstance (array , sparse .COO ) else array [..., mask ].reshape (- 1 )
773
+ if isinstance (array , sparse .COO ):
774
+ subset = array [..., mask ]
775
+ data = subset .data
776
+ coords = subset .coords
777
+ if subset .nnz > 0 :
778
+ coords [- 1 , :] = coords [- 1 , :][idx [mask ]]
779
+ if fill_value is None :
780
+ # no reindexing is actually needed (dense case)
781
+ # preserve the fill_value
782
+ fill_value = array .fill_value
783
+ else :
784
+ ranges = np .broadcast_arrays (
785
+ * np .ix_ (* (tuple (np .arange (size ) for size in shape [:axis ]) + (idx [mask ],)))
786
+ )
787
+ coords = np .stack (ranges , axis = 0 ).reshape (array .ndim , - 1 )
788
+ data = array [..., mask ].reshape (- 1 )
767
789
768
790
reindexed = sparse .COO (
769
791
coords = coords ,
770
792
data = data .astype (dtype , copy = False ),
771
793
shape = (* array .shape [:axis ], to .size ),
794
+ fill_value = fill_value ,
772
795
)
773
796
774
797
return reindexed
@@ -795,7 +818,11 @@ def reindex_(
795
818
796
819
if array .shape [axis ] == 0 :
797
820
# all groups were NaN
798
- reindexed = np .full (array .shape [:- 1 ] + (len (to ),), fill_value , dtype = array .dtype )
821
+ shape = array .shape [:- 1 ] + (len (to ),)
822
+ if array_type in (ReindexArrayType .AUTO , ReindexArrayType .NUMPY ):
823
+ reindexed = np .full (shape , fill_value , dtype = array .dtype )
824
+ else :
825
+ raise NotImplementedError
799
826
return reindexed
800
827
801
828
from_ = pd .Index (from_ )
@@ -1288,7 +1315,7 @@ def _finalize_results(
1288
1315
fill_value = agg .fill_value ["user" ]
1289
1316
if min_count > 0 :
1290
1317
count_mask = counts < min_count
1291
- if count_mask .any ():
1318
+ if count_mask .any () or reindex . array_type is ReindexArrayType . SPARSE_COO :
1292
1319
# For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
1293
1320
# necessary
1294
1321
if fill_value is None :
@@ -2815,6 +2842,12 @@ def groupby_reduce(
2815
2842
array .dtype ,
2816
2843
)
2817
2844
2845
+ if reindex .array_type is ReindexArrayType .SPARSE_COO and not _is_sparse_supported_reduction (func ):
2846
+ raise NotImplementedError (
2847
+ f"Aggregation { func = !r} is not supported when reindexing to a sparse array. "
2848
+ "Please raise an issue"
2849
+ )
2850
+
2818
2851
if TYPE_CHECKING :
2819
2852
assert isinstance (reindex , ReindexStrategy )
2820
2853
assert method is not None
0 commit comments