@@ -820,6 +820,18 @@ def _expand_dims(results: IntermediateDict) -> IntermediateDict:
820
820
return results
821
821
822
822
823
+ def _find_unique_groups (x_chunk ):
824
+ from dask .base import flatten
825
+ from dask .utils import deepmap
826
+
827
+ unique_groups = _unique (tuple (flatten (deepmap (listify_groups , x_chunk ))))
828
+ unique_groups = unique_groups [~ isnull (unique_groups )]
829
+
830
+ if len (unique_groups ) == 0 :
831
+ unique_groups = [np .nan ]
832
+ return unique_groups
833
+
834
+
823
835
def _simple_combine (
824
836
x_chunk ,
825
837
agg : Aggregation ,
@@ -839,18 +851,12 @@ def _simple_combine(
839
851
4. At the final agggregate step, we squeeze out DUMMY_AXIS
840
852
"""
841
853
from dask .array .core import deepfirst
842
- from dask .base import flatten
843
854
from dask .utils import deepmap
844
855
845
856
if not reindex :
846
857
# We didn't reindex at the blockwise step
847
858
# So now reindex before combining by reducing along DUMMY_AXIS
848
- unique_groups = _unique (tuple (flatten (deepmap (listify_groups , x_chunk ))))
849
- # print(f"unique_groups: {unique_groups}")
850
- unique_groups = unique_groups [~ isnull (unique_groups )]
851
-
852
- if len (unique_groups ) == 0 :
853
- unique_groups = [np .nan ]
859
+ unique_groups = _find_unique_groups (x_chunk )
854
860
x_chunk = deepmap (
855
861
partial (reindex_intermediates , agg = agg , unique_groups = unique_groups ), x_chunk
856
862
)
@@ -912,7 +918,6 @@ def _grouped_combine(
912
918
sort : bool = True ,
913
919
) -> IntermediateDict :
914
920
"""Combine intermediates step of tree reduction."""
915
- from dask .base import flatten
916
921
from dask .utils import deepmap
917
922
918
923
if isinstance (x_chunk , dict ):
@@ -923,11 +928,7 @@ def _grouped_combine(
923
928
# when there's only a single axis of reduction, we can just concatenate later,
924
929
# reindexing is unnecessary
925
930
# I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
926
- unique_groups = _unique (tuple (flatten (deepmap (listify_groups , x_chunk ))))
927
- unique_groups = unique_groups [~ isnull (unique_groups )]
928
- if len (unique_groups ) == 0 :
929
- unique_groups = [np .nan ]
930
-
931
+ unique_groups = _find_unique_groups (x_chunk )
931
932
x_chunk = deepmap (
932
933
partial (reindex_intermediates , agg = agg , unique_groups = unique_groups ), x_chunk
933
934
)
0 commit comments