Skip to content

Commit cb25e38

Browse files
committed
Try return_array from _finalize_results
1 parent ee6be26 commit cb25e38

File tree

1 file changed

+43
-16
lines changed

1 file changed

+43
-16
lines changed

flox/core.py

Lines changed: 43 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -803,10 +803,15 @@ def _aggregate(
803803
keepdims,
804804
fill_value: Any,
805805
reindex: bool,
806+
return_array: bool,
806807
) -> FinalResultsDict:
807808
"""Final aggregation step of tree reduction"""
808809
results = combine(x_chunk, agg, axis, keepdims, is_aggregate=True)
809-
return _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
810+
finalized = _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
811+
if return_array:
812+
return finalized[agg.name]
813+
else:
814+
return finalized
810815

811816

812817
def _expand_dims(results: IntermediateDict) -> IntermediateDict:
@@ -1287,6 +1292,7 @@ def dask_groupby_agg(
12871292
group_chunks: tuple[tuple[Union[int, float], ...]] = (
12881293
(len(expected_groups),) if expected_groups is not None else (np.nan,),
12891294
)
1295+
groups_are_unknown = is_duck_dask_array(by_input) and expected_groups is None
12901296

12911297
if method in ["map-reduce", "cohorts"]:
12921298
combine: Callable[..., IntermediateDict]
@@ -1316,16 +1322,32 @@ def dask_groupby_agg(
13161322
reduced = tree_reduce(
13171323
intermediate,
13181324
combine=partial(combine, agg=agg),
1319-
aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
1325+
aggregate=partial(
1326+
aggregate,
1327+
expected_groups=expected_groups,
1328+
reindex=reindex,
1329+
return_array=not groups_are_unknown,
1330+
),
13201331
)
1321-
if is_duck_dask_array(by_input) and expected_groups is None:
1332+
if groups_are_unknown:
13221333
groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
1334+
result = dask.array.map_blocks(
1335+
_extract_result,
1336+
reduced,
1337+
chunks=reduced.chunks[: -len(axis)] + group_chunks,
1338+
drop_axis=axis[:-1],
1339+
dtype=agg.dtype[agg.name],
1340+
key=agg.name,
1341+
name=f"{name}-{token}",
1342+
)
1343+
13231344
else:
13241345
if expected_groups is None:
13251346
expected_groups_ = _get_expected_groups(by_input, sort=sort)
13261347
else:
13271348
expected_groups_ = expected_groups
13281349
groups = (expected_groups_.to_numpy(),)
1350+
result = reduced
13291351

13301352
elif method == "cohorts":
13311353
chunks_cohorts = find_group_cohorts(
@@ -1344,12 +1366,14 @@ def dask_groupby_agg(
13441366
tree_reduce(
13451367
reindexed,
13461368
combine=partial(combine, agg=agg, reindex=True),
1347-
aggregate=partial(aggregate, expected_groups=index, reindex=True),
1369+
aggregate=partial(
1370+
aggregate, expected_groups=index, reindex=True, return_array=True
1371+
),
13481372
)
13491373
)
13501374
groups_.append(cohort)
13511375

1352-
reduced = dask.array.concatenate(reduced_, axis=-1)
1376+
result = dask.array.concatenate(reduced_, axis=-1)
13531377
groups = (np.concatenate(groups_),)
13541378
group_chunks = (tuple(len(cohort) for cohort in groups_),)
13551379

@@ -1375,21 +1399,24 @@ def dask_groupby_agg(
13751399
for ax, chunks in zip(axis, group_chunks):
13761400
adjust_chunks[ax] = chunks
13771401

1378-
result = dask.array.blockwise(
1379-
_extract_result,
1380-
inds[: -len(axis)] + (inds[-1],),
1381-
reduced,
1382-
inds,
1383-
adjust_chunks=adjust_chunks,
1384-
dtype=agg.dtype[agg.name],
1385-
key=agg.name,
1386-
name=f"{name}-{token}",
1387-
)
1388-
1402+
# result = dask.array.blockwise(
1403+
# _extract_result,
1404+
# inds[: -len(axis)] + (inds[-1],),
1405+
# reduced,
1406+
# inds,
1407+
# adjust_chunks=adjust_chunks,
1408+
# dtype=agg.dtype[agg.name],
1409+
# key=agg.name,
1410+
# name=f"{name}-{token}",
1411+
# )
13891412
return (result, groups)
13901413

13911414

13921415
def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:
1416+
from dask.array.core import deepfirst
1417+
1418+
if not isinstance(result_dict, dict):
1419+
result_dict = deepfirst(result_dict)
13931420
return result_dict[key]
13941421

13951422

0 commit comments

Comments
 (0)