diff --git a/flox/__init__.py b/flox/__init__.py
index 2ca5fa5ba..839bfb076 100644
--- a/flox/__init__.py
+++ b/flox/__init__.py
@@ -2,7 +2,7 @@
 # flake8: noqa
 """Top-level module for flox."""
 from . import cache
-from .aggregations import Aggregation  # noqa
+from .aggregations import Aggregation, Scan  # noqa
 from .core import groupby_reduce, groupby_scan, rechunk_for_blockwise, rechunk_for_cohorts  # noqa
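The flox/core.py changes below all follow from one constraint: "nanfirst"/"nanlast" pick the first/last non-NaN value, and the simple (reindexed) combine fills groups that are missing from a block with a fill value. Floats can use NaN for that, but integer and other no-NaN dtypes have no sentinel that is distinguishable from real data, so these reductions must fall back to the grouped combine. A minimal sketch in plain numpy (the block values are made up for illustration; this is not flox internals):

import numpy as np

# Float intermediates: NaN can mark "group absent from this block", and the
# combine step keeps the first non-NaN entry per group.
block_a = np.array([np.nan])   # group missing from this block -> NaN fill
block_b = np.array([7.0])      # the group's real first value
combined = np.where(np.isnan(block_a), block_b, block_a)
assert combined[0] == 7.0

# Integer intermediates: any fill value we invent (0, INT_MIN, ...) looks like
# genuine data, so an elementwise combine can silently return the fill.
int_a = np.array([0], dtype=np.int8)   # a fill... or a real zero?
int_b = np.array([7], dtype=np.int8)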
diff --git a/flox/core.py b/flox/core.py
index 1c9599a15..d7fa5f6a0 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -170,7 +170,9 @@ def _is_minmax_reduction(func: T_Agg) -> bool:
 
 
 def _is_first_last_reduction(func: T_Agg) -> bool:
-    return isinstance(func, str) and func in ["nanfirst", "nanlast", "first", "last"]
+    if isinstance(func, Aggregation):
+        func = func.name
+    return func in ["nanfirst", "nanlast", "first", "last"]
 
 
 def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
@@ -1642,7 +1644,12 @@ def dask_groupby_agg(
     # This allows us to discover groups at compute time, support argreductions, lower intermediate
     # memory usage (but method="cohorts" would also work to reduce memory in some cases)
     labels_are_unknown = is_duck_dask_array(by_input) and expected_groups is None
-    do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown
+    do_grouped_combine = (
+        _is_arg_reduction(agg)
+        or labels_are_unknown
+        or (_is_first_last_reduction(agg) and array.dtype.kind != "f")
+    )
+    do_simple_combine = not do_grouped_combine
 
     if method == "blockwise":
         # use the "non dask" code path, but applied blockwise
@@ -1698,7 +1705,7 @@ def dask_groupby_agg(
 
         tree_reduce = partial(
             dask.array.reductions._tree_reduce,
-            name=f"{name}-reduce",
+            name=f"{name}-simple-reduce",
             dtype=array.dtype,
             axis=axis,
             keepdims=True,
@@ -1733,14 +1740,20 @@ def dask_groupby_agg(
         groups_ = []
         for blks, cohort in chunks_cohorts.items():
             cohort_index = pd.Index(cohort)
-            reindexer = partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
+            reindexer = (
+                partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
+                if do_simple_combine
+                else identity
+            )
             reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer)
             # now that we have reindexed, we can set reindex=True explicitly
             reduced_.append(
                 tree_reduce(
                     reindexed,
-                    combine=partial(combine, agg=agg, reindex=True),
-                    aggregate=partial(aggregate, expected_groups=cohort_index, reindex=True),
+                    combine=partial(combine, agg=agg, reindex=do_simple_combine),
+                    aggregate=partial(
+                        aggregate, expected_groups=cohort_index, reindex=do_simple_combine
+                    ),
                 )
             )
         # This is done because pandas promotes to 64-bit types when an Index is created
@@ -1986,8 +1999,13 @@ def _validate_reindex(
     expected_groups,
     any_by_dask: bool,
     is_dask_array: bool,
+    array_dtype: Any,
 ) -> bool | None:
     # logger.debug("Entering _validate_reindex: reindex is {}".format(reindex))  # noqa
+    def first_or_last():
+        return func in ["first", "last"] or (
+            _is_first_last_reduction(func) and array_dtype.kind != "f"
+        )
 
     all_numpy = not is_dask_array and not any_by_dask
     if reindex is True and not all_numpy:
@@ -1997,7 +2015,7 @@ def _validate_reindex(
             raise ValueError(
                 "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
             )
-        if func in ["first", "last"]:
+        if first_or_last():
             raise ValueError("reindex must be None or False when func is 'first' or 'last'.")
 
     if reindex is None:
@@ -2008,9 +2026,10 @@ def _validate_reindex(
         if all_numpy:
             return True
 
-        if func in ["first", "last"]:
+        if first_or_last():
             # have to do the grouped_combine since there's no good fill_value
-            reindex = False
+            # Also needed for nanfirst, nanlast with no-NaN dtypes
+            return False
 
         if method == "blockwise":
             # for grouping by dask arrays, we set reindex=True
@@ -2412,12 +2431,19 @@ def groupby_reduce(
     if method == "cohorts" and any_by_dask:
         raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
 
+    if not is_duck_array(array):
+        array = np.asarray(array)
+
     reindex = _validate_reindex(
-        reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array)
+        reindex,
+        func,
+        method,
+        expected_groups,
+        any_by_dask,
+        is_duck_dask_array(array),
+        array.dtype,
     )
 
-    if not is_duck_array(array):
-        array = np.asarray(array)
     is_bool_array = np.issubdtype(array.dtype, bool)
     array = array.astype(np.intp) if is_bool_array else array
 
@@ -2601,7 +2627,7 @@ def groupby_reduce(
 
     # TODO: clean this up
     reindex = _validate_reindex(
-        reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array)
+        reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array), array.dtype
     )
 
     if TYPE_CHECKING:
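The net effect of the _validate_reindex change is that the reindex decision now depends on the array dtype. A quick sketch of the new behaviour, mirroring the assertions added to tests/test_core.py below (it calls the private flox.core._validate_reindex directly, exactly as those tests do):

import numpy as np
from flox.core import _validate_reindex

kwargs = dict(
    method="blockwise",
    expected_groups=np.array([1, 2, 3]),
    any_by_dask=True,
    is_dask_array=True,
)
# Floats have NaN, so "nanfirst"/"nanlast" may keep the reindexed combine:
assert _validate_reindex(None, "nanfirst", array_dtype=np.dtype("float32"), **kwargs)
# Integers have no NaN sentinel, so flox falls back to the grouped combine:
assert not _validate_reindex(None, "nanfirst", array_dtype=np.dtype("int32"), **kwargs)
# Plain "first"/"last" never reindex here, whatever the dtype:
assert not _validate_reindex(None, "first", array_dtype=np.dtype("float32"), **kwargs)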
"map-reduce", expected_groups, any_by_dask=False, is_dask_array=False + reindex, + func, + "map-reduce", + expected_groups, + any_by_dask=False, + is_dask_array=False, + array_dtype=np.dtype("int32"), ) assert actual actual = _validate_reindex( - True, func, "map-reduce", expected_groups, any_by_dask=False, is_dask_array=False + True, + func, + "map-reduce", + expected_groups, + any_by_dask=False, + is_dask_array=False, + array_dtype=np.dtype("int32"), ) assert actual @@ -1584,19 +1629,37 @@ def test_validate_reindex() -> None: for method in methods: with pytest.raises(NotImplementedError): _validate_reindex( - True, "argmax", method, expected_groups=None, any_by_dask=False, is_dask_array=True + True, + "argmax", + method, + expected_groups=None, + any_by_dask=False, + is_dask_array=True, + array_dtype=np.dtype("int32"), ) methods: list[T_Method] = ["blockwise", "cohorts"] for method in methods: with pytest.raises(ValueError): _validate_reindex( - True, "sum", method, expected_groups=None, any_by_dask=False, is_dask_array=True + True, + "sum", + method, + expected_groups=None, + any_by_dask=False, + is_dask_array=True, + array_dtype=np.dtype("int32"), ) for func in ["sum", "argmax"]: actual = _validate_reindex( - None, func, method, expected_groups=None, any_by_dask=False, is_dask_array=True + None, + func, + method, + expected_groups=None, + any_by_dask=False, + is_dask_array=True, + array_dtype=np.dtype("int32"), ) assert actual is False @@ -1608,6 +1671,7 @@ def test_validate_reindex() -> None: expected_groups=np.array([1, 2, 3]), any_by_dask=False, is_dask_array=True, + array_dtype=np.dtype("int32"), ) assert _validate_reindex( @@ -1617,6 +1681,7 @@ def test_validate_reindex() -> None: expected_groups=np.array([1, 2, 3]), any_by_dask=True, is_dask_array=True, + array_dtype=np.dtype("int32"), ) assert _validate_reindex( None, @@ -1625,8 +1690,24 @@ def test_validate_reindex() -> None: expected_groups=np.array([1, 2, 3]), any_by_dask=True, is_dask_array=True, + array_dtype=np.dtype("int32"), + ) + + kwargs = dict( + method="blockwise", + expected_groups=np.array([1, 2, 3]), + any_by_dask=True, + is_dask_array=True, ) + for func in ["nanfirst", "nanlast"]: + assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs) # type: ignore[arg-type] + assert _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs) # type: ignore[arg-type] + + for func in ["first", "last"]: + assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs) # type: ignore[arg-type] + assert not _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs) # type: ignore[arg-type] + @requires_dask def test_1d_blockwise_sort_optimization(): diff --git a/tests/test_properties.py b/tests/test_properties.py index 6fef85b3a..c032f0742 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -9,6 +9,7 @@ pytest.importorskip("cftime") import dask +import hypothesis.extra.numpy as npst import hypothesis.strategies as st import numpy as np from hypothesis import assume, given, note @@ -19,6 +20,7 @@ from . 
diff --git a/tests/test_core.py b/tests/test_core.py
index e12e695db..540e32c03 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -613,6 +613,33 @@ def test_dask_reduce_axis_subset():
     )
 
 
+@pytest.mark.parametrize("group_idx", [[0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 1, 0]])
+@pytest.mark.parametrize(
+    "func",
+    [
+        # "first", "last",
+        "nanfirst",
+        "nanlast",
+    ],
+)
+@pytest.mark.parametrize(
+    "chunks",
+    [
+        None,
+        pytest.param(1, marks=pytest.mark.skipif(not has_dask, reason="no dask")),
+        pytest.param(2, marks=pytest.mark.skipif(not has_dask, reason="no dask")),
+        pytest.param(3, marks=pytest.mark.skipif(not has_dask, reason="no dask")),
+    ],
+)
+def test_first_last_useless(func, chunks, group_idx):
+    array = np.array([[0, 0, 0], [0, 0, 0]], dtype=np.int8)
+    if chunks is not None:
+        array = dask.array.from_array(array, chunks=chunks)
+    actual, _ = groupby_reduce(array, np.array(group_idx), func=func, engine="numpy")
+    expected = np.array([[0, 0], [0, 0]], dtype=np.int8)
+    assert_equal(actual, expected)
+
+
 @pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
 @pytest.mark.parametrize("axis", [(0, 1)])
 def test_first_last_disallowed(axis, func):
@@ -1563,18 +1590,36 @@ def test_validate_reindex_map_reduce(
     dask_expected, reindex, func, expected_groups, any_by_dask
 ) -> None:
     actual = _validate_reindex(
-        reindex, func, "map-reduce", expected_groups, any_by_dask, is_dask_array=True
+        reindex,
+        func,
+        "map-reduce",
+        expected_groups,
+        any_by_dask,
+        is_dask_array=True,
+        array_dtype=np.dtype("int32"),
     )
     assert actual is dask_expected
 
     # always reindex with all numpy inputs
     actual = _validate_reindex(
-        reindex, func, "map-reduce", expected_groups, any_by_dask=False, is_dask_array=False
+        reindex,
+        func,
+        "map-reduce",
+        expected_groups,
+        any_by_dask=False,
+        is_dask_array=False,
+        array_dtype=np.dtype("int32"),
     )
     assert actual
 
     actual = _validate_reindex(
-        True, func, "map-reduce", expected_groups, any_by_dask=False, is_dask_array=False
+        True,
+        func,
+        "map-reduce",
+        expected_groups,
+        any_by_dask=False,
+        is_dask_array=False,
+        array_dtype=np.dtype("int32"),
     )
     assert actual
 
@@ -1584,19 +1629,37 @@ def test_validate_reindex() -> None:
     for method in methods:
         with pytest.raises(NotImplementedError):
             _validate_reindex(
-                True, "argmax", method, expected_groups=None, any_by_dask=False, is_dask_array=True
+                True,
+                "argmax",
+                method,
+                expected_groups=None,
+                any_by_dask=False,
+                is_dask_array=True,
+                array_dtype=np.dtype("int32"),
             )
 
     methods: list[T_Method] = ["blockwise", "cohorts"]
     for method in methods:
         with pytest.raises(ValueError):
             _validate_reindex(
-                True, "sum", method, expected_groups=None, any_by_dask=False, is_dask_array=True
+                True,
+                "sum",
+                method,
+                expected_groups=None,
+                any_by_dask=False,
+                is_dask_array=True,
+                array_dtype=np.dtype("int32"),
             )
 
         for func in ["sum", "argmax"]:
             actual = _validate_reindex(
-                None, func, method, expected_groups=None, any_by_dask=False, is_dask_array=True
+                None,
+                func,
+                method,
+                expected_groups=None,
+                any_by_dask=False,
+                is_dask_array=True,
+                array_dtype=np.dtype("int32"),
             )
             assert actual is False
 
@@ -1608,6 +1671,7 @@ def test_validate_reindex() -> None:
         expected_groups=np.array([1, 2, 3]),
         any_by_dask=False,
         is_dask_array=True,
+        array_dtype=np.dtype("int32"),
     )
 
     assert _validate_reindex(
@@ -1617,6 +1681,7 @@ def test_validate_reindex() -> None:
         expected_groups=np.array([1, 2, 3]),
         any_by_dask=True,
         is_dask_array=True,
+        array_dtype=np.dtype("int32"),
     )
     assert _validate_reindex(
         None,
@@ -1625,8 +1690,24 @@ def test_validate_reindex() -> None:
         expected_groups=np.array([1, 2, 3]),
         any_by_dask=True,
         is_dask_array=True,
+        array_dtype=np.dtype("int32"),
+    )
+
+    kwargs = dict(
+        method="blockwise",
+        expected_groups=np.array([1, 2, 3]),
+        any_by_dask=True,
+        is_dask_array=True,
     )
+    for func in ["nanfirst", "nanlast"]:
+        assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs)  # type: ignore[arg-type]
+        assert _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs)  # type: ignore[arg-type]
+
+    for func in ["first", "last"]:
+        assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs)  # type: ignore[arg-type]
+        assert not _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs)  # type: ignore[arg-type]
+
 
 @requires_dask
 def test_1d_blockwise_sort_optimization():
diff --git a/tests/test_properties.py b/tests/test_properties.py
index 6fef85b3a..c032f0742 100644
--- a/tests/test_properties.py
+++ b/tests/test_properties.py
@@ -9,6 +9,7 @@
 pytest.importorskip("cftime")
 
 import dask
+import hypothesis.extra.numpy as npst
 import hypothesis.strategies as st
 import numpy as np
 from hypothesis import assume, given, note
@@ -19,6 +20,7 @@
 
 from . import assert_equal
 from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays
+from .strategies import chunks as chunks_strategy
 
 dask.config.set(scheduler="sync")
 
@@ -208,3 +210,16 @@ def test_first_last(data, array: dask.array.Array, func: str) -> None:
     first, *_ = groupby_reduce(array, by, func=func, engine="flox")
     second, *_ = groupby_reduce(array, by, func=mate, engine="flox")
     assert_equal(first, second)
+
+
+@given(data=st.data(), func=st.sampled_from(["nanfirst", "nanlast"]))
+def test_first_last_useless(data, func):
+    shape = data.draw(npst.array_shapes())
+    by = data.draw(by_arrays(shape=shape[slice(-1, None)]))
+    chunks = data.draw(chunks_strategy(shape=shape))
+    array = np.zeros(shape, dtype=np.int8)
+    if chunks is not None:
+        array = dask.array.from_array(array, chunks=chunks)
+    actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy")
+    expected = np.zeros(shape[:-1] + (len(groups),), dtype=array.dtype)
+    assert_equal(actual, expected)
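Finally, a user-level sketch of the scenario the new tests pin down: an integer array where the chunked "nanfirst" result must match the single-block numpy path. The shapes, chunking, and group labels here are arbitrary:

import dask.array
import numpy as np

from flox import groupby_reduce

array = np.zeros((2, 3), dtype=np.int8)
by = np.array([0, 1, 0])

# Reference result from the plain numpy path.
expected, _ = groupby_reduce(array, by, func="nanfirst", engine="numpy")
# Chunked path: every element in its own block forces a tree combine.
actual, _ = groupby_reduce(
    dask.array.from_array(array, chunks=1), by, func="nanfirst", engine="numpy"
)
np.testing.assert_array_equal(expected, np.asarray(actual))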