diff --git a/flox/core.py b/flox/core.py index 6bd390137..1437e506c 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1388,27 +1388,35 @@ def dask_groupby_agg( return (result, groups) -def _validate_reindex(reindex: bool | None, func, method: T_Method, expected_groups) -> bool | None: +def _validate_reindex( + reindex: bool | None, func, method: T_Method, expected_groups, by_is_dask: bool +) -> bool: if reindex is True: if _is_arg_reduction(func): raise NotImplementedError if method == "blockwise": raise NotImplementedError - if method == "blockwise" or _is_arg_reduction(func): - reindex = False + if reindex is None: + if method == "blockwise" or _is_arg_reduction(func): + reindex = False - if reindex is None and expected_groups is not None: - reindex = True + elif expected_groups is not None: + reindex = True + + elif method in ["split-reduce", "cohorts"]: + reindex = True + + elif method == "map-reduce": + if expected_groups is None and by_is_dask: + reindex = False + else: + reindex = True if method in ["split-reduce", "cohorts"] and reindex is False: raise NotImplementedError - if method in ["split-reduce", "cohorts"] and reindex is None: - reindex = True - - # TODO: Should reindex be a bool-only at this point? Would've been nice but - # None's are relied on after this function as well. + assert isinstance(reindex, bool) return reindex @@ -1597,7 +1605,6 @@ def groupby_reduce( "argreductions not supported for engine='flox' yet." "Try engine='numpy' or engine='numba' instead." ) - reindex = _validate_reindex(reindex, func, method, expected_groups) bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by) nby = len(bys) @@ -1606,6 +1613,8 @@ def groupby_reduce( if method in ["split-reduce", "cohorts"] and by_is_dask: raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.") + reindex = _validate_reindex(reindex, func, method, expected_groups, by_is_dask) + if not is_duck_array(array): array = np.asarray(array) is_bool_array = np.issubdtype(array.dtype, bool) diff --git a/tests/test_core.py b/tests/test_core.py index e31f11e56..d7cbc4d8d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,6 +1,6 @@ from __future__ import annotations -from functools import reduce +from functools import partial, reduce from typing import TYPE_CHECKING import numpy as np @@ -13,6 +13,7 @@ _convert_expected_groups_to_index, _get_optimal_chunks_for_groups, _normalize_indexes, + _validate_reindex, factorize_, find_group_cohorts, groupby_reduce, @@ -221,14 +222,26 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): if not has_dask: continue for method in ["map-reduce", "cohorts", "split-reduce"]: - if "arg" in func and method != "map-reduce": - continue - actual, *groups = groupby_reduce(array, *by, method=method, **flox_kwargs) - for actual_group, expect in zip(groups, expected_groups): - assert_equal(actual_group, expect, tolerance) - if "arg" in func: - assert actual.dtype.kind == "i" - assert_equal(actual, expected, tolerance) + if method == "map-reduce": + reindexes = [True, False, None] + else: + reindexes = [None] + for reindex in reindexes: + call = partial( + groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs + ) + if "arg" in func: + if method != "map-reduce" or reindex is True: + with pytest.raises(NotImplementedError): + call() + continue + + actual, *groups = call() + for actual_group, expect in zip(groups, expected_groups): + assert_equal(actual_group, expect, tolerance) + if "arg" in func: + assert actual.dtype.kind == "i" + assert_equal(actual, expected, tolerance) @requires_dask @@ -1125,3 +1138,33 @@ def test_subset_block_2d(flatblocks, expectidx): subset = subset_to_blocks(array, flatblocks) assert len(subset.dask.layers) == 2 assert_equal(subset, array.compute()[expectidx]) + + +@pytest.mark.parametrize("method", ["map-reduce", "cohorts"]) +@pytest.mark.parametrize( + "expected, reindex, func, expected_groups, by_is_dask", + [ + # argmax only False + [False, None, "argmax", None, False], + # True when by is numpy but expected is None + [True, None, "sum", None, False], + # False when by is dask but expected is None + [False, None, "sum", None, True], + # if expected_groups then always True + [True, None, "sum", [1, 2, 3], False], + [True, None, "sum", ([1], [2]), False], + [True, None, "sum", ([1], [2]), True], + [True, None, "sum", ([1], None), False], + [True, None, "sum", ([1], None), True], + ], +) +def test_validate_reindex(expected, reindex, func, method, expected_groups, by_is_dask): + if by_is_dask and method == "cohorts": + # This should error elsewhere + pytest.skip() + call = partial(_validate_reindex, reindex, func, method, expected_groups, by_is_dask) + if "arg" in func and method == "cohorts": + with pytest.raises(NotImplementedError): + call() + else: + assert call() == expected