
Commit c10c999

Small cleanup
1 parent 4b82d8e commit c10c999

File tree

1 file changed (+14, -13 lines)


flox/core.py

Lines changed: 14 additions & 13 deletions
@@ -820,6 +820,18 @@ def _expand_dims(results: IntermediateDict) -> IntermediateDict:
     return results
 
 
+def _find_unique_groups(x_chunk):
+    from dask.base import flatten
+    from dask.utils import deepmap
+
+    unique_groups = _unique(tuple(flatten(deepmap(listify_groups, x_chunk))))
+    unique_groups = unique_groups[~isnull(unique_groups)]
+
+    if len(unique_groups) == 0:
+        unique_groups = [np.nan]
+    return unique_groups
+
+
 def _simple_combine(
     x_chunk,
     agg: Aggregation,
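
Note: for context, here is a minimal runnable sketch of what the extracted helper computes. The stand-in definitions of listify_groups, _unique, and isnull below are assumptions for illustration, not flox's actual implementations; only _find_unique_groups itself mirrors the diff.

import numpy as np
import pandas as pd
from dask.base import flatten
from dask.utils import deepmap

# Stand-ins (assumed): flox's real helpers live in flox/core.py
def listify_groups(intermediate):
    # assume each intermediate dict carries a "groups" array of labels seen in that chunk
    return list(np.atleast_1d(intermediate["groups"]))

def _unique(a):
    return np.unique(np.asarray(a))

isnull = pd.isnull

def _find_unique_groups(x_chunk):
    # x_chunk is an arbitrarily nested list of per-chunk intermediates;
    # deepmap applies listify_groups at the leaves, flatten collapses the nesting
    unique_groups = _unique(tuple(flatten(deepmap(listify_groups, x_chunk))))
    unique_groups = unique_groups[~isnull(unique_groups)]
    if len(unique_groups) == 0:
        # every label was NaN: fall back to a single NaN group
        unique_groups = [np.nan]
    return unique_groups

chunks = [[{"groups": np.array([1.0, 2.0])}], [{"groups": np.array([2.0, 3.0, np.nan])}]]
print(_find_unique_groups(chunks))  # [1. 2. 3.]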
@@ -839,18 +851,12 @@ def _simple_combine(
     4. At the final aggregate step, we squeeze out DUMMY_AXIS
     """
     from dask.array.core import deepfirst
-    from dask.base import flatten
     from dask.utils import deepmap
 
     if not reindex:
         # We didn't reindex at the blockwise step
         # So now reindex before combining by reducing along DUMMY_AXIS
-        unique_groups = _unique(tuple(flatten(deepmap(listify_groups, x_chunk))))
-        # print(f"unique_groups: {unique_groups}")
-        unique_groups = unique_groups[~isnull(unique_groups)]
-
-        if len(unique_groups) == 0:
-            unique_groups = [np.nan]
+        unique_groups = _find_unique_groups(x_chunk)
         x_chunk = deepmap(
             partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
         )
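
Note: the call site needs the union of labels because each chunk's intermediate only covers the groups that chunk actually saw; before reducing along DUMMY_AXIS, every intermediate must be expanded onto the common label set. Below is a minimal sketch of that expansion using a hypothetical reindex_last_axis helper; this is an illustration of the idea, not flox's reindex_intermediates.

import numpy as np

def reindex_last_axis(values, from_groups, to_groups, fill_value):
    # place each label's column into its slot in the union of labels,
    # filling labels this chunk never saw with the reduction's identity value
    out = np.full(values.shape[:-1] + (len(to_groups),), fill_value, dtype=values.dtype)
    positions = {g: i for i, g in enumerate(to_groups)}
    for j, g in enumerate(from_groups):
        out[..., positions[g]] = values[..., j]
    return out

chunk_sums = np.array([[10.0, 20.0]])  # sums for labels [1, 2] seen in one chunk
print(reindex_last_axis(chunk_sums, [1, 2], [1, 2, 3], fill_value=0.0))
# [[10. 20.  0.]] -- label 3 gets the identity value for "sum"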
@@ -912,7 +918,6 @@ def _grouped_combine(
     sort: bool = True,
 ) -> IntermediateDict:
     """Combine intermediates step of tree reduction."""
-    from dask.base import flatten
     from dask.utils import deepmap
 
     if isinstance(x_chunk, dict):
@@ -923,11 +928,7 @@ def _grouped_combine(
         # when there's only a single axis of reduction, we can just concatenate later,
         # reindexing is unnecessary
         # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
-        unique_groups = _unique(tuple(flatten(deepmap(listify_groups, x_chunk))))
-        unique_groups = unique_groups[~isnull(unique_groups)]
-        if len(unique_groups) == 0:
-            unique_groups = [np.nan]
-
+        unique_groups = _find_unique_groups(x_chunk)
         x_chunk = deepmap(
             partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
         )
