Skip to content

Fix first, last again #381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# flake8: noqa
"""Top-level module for flox ."""
from . import cache
from .aggregations import Aggregation # noqa
from .aggregations import Aggregation, Scan # noqa
from .core import groupby_reduce, groupby_scan, rechunk_for_blockwise, rechunk_for_cohorts # noqa


Expand Down
52 changes: 39 additions & 13 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ def _is_minmax_reduction(func: T_Agg) -> bool:


def _is_first_last_reduction(func: T_Agg) -> bool:
    """Return True when ``func`` denotes a first/last-style reduction.

    Parameters
    ----------
    func
        Either the reduction name as a string or an ``Aggregation``
        instance, in which case its ``name`` attribute is inspected.

    Returns
    -------
    bool
        True only for "nanfirst", "nanlast", "first" or "last".
    """
    # Resolve Aggregation objects to their name *before* testing membership so
    # that both spellings are recognised by a single check.
    name = func.name if isinstance(func, Aggregation) else func
    # Guard on str so arbitrary objects (e.g. callables) simply yield False.
    return isinstance(name, str) and name in ("nanfirst", "nanlast", "first", "last")


def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
Expand Down Expand Up @@ -1642,7 +1644,12 @@ def dask_groupby_agg(
# This allows us to discover groups at compute time, support argreductions, lower intermediate
# memory usage (but method="cohorts" would also work to reduce memory in some cases)
labels_are_unknown = is_duck_dask_array(by_input) and expected_groups is None
do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown
do_grouped_combine = (
_is_arg_reduction(agg)
or labels_are_unknown
or (_is_first_last_reduction(agg) and array.dtype.kind != "f")
)
do_simple_combine = not do_grouped_combine

if method == "blockwise":
# use the "non dask" code path, but applied blockwise
Expand Down Expand Up @@ -1698,7 +1705,7 @@ def dask_groupby_agg(

tree_reduce = partial(
dask.array.reductions._tree_reduce,
name=f"{name}-reduce",
name=f"{name}-simple-reduce",
dtype=array.dtype,
axis=axis,
keepdims=True,
Expand Down Expand Up @@ -1733,14 +1740,20 @@ def dask_groupby_agg(
groups_ = []
for blks, cohort in chunks_cohorts.items():
cohort_index = pd.Index(cohort)
reindexer = partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
reindexer = (
partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
if do_simple_combine
else identity
)
reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer)
# now that we have reindexed, we can set reindex=True explicitly
reduced_.append(
tree_reduce(
reindexed,
combine=partial(combine, agg=agg, reindex=True),
aggregate=partial(aggregate, expected_groups=cohort_index, reindex=True),
combine=partial(combine, agg=agg, reindex=do_simple_combine),
aggregate=partial(
aggregate, expected_groups=cohort_index, reindex=do_simple_combine
),
)
)
# This is done because pandas promotes to 64-bit types when an Index is created
Expand Down Expand Up @@ -1986,8 +1999,13 @@ def _validate_reindex(
expected_groups,
any_by_dask: bool,
is_dask_array: bool,
array_dtype: Any,
) -> bool | None:
# logger.debug("Entering _validate_reindex: reindex is {}".format(reindex)) # noqa
def first_or_last():
return func in ["first", "last"] or (
_is_first_last_reduction(func) and array_dtype.kind != "f"
)

all_numpy = not is_dask_array and not any_by_dask
if reindex is True and not all_numpy:
Expand All @@ -1997,7 +2015,7 @@ def _validate_reindex(
raise ValueError(
"reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
)
if func in ["first", "last"]:
if first_or_last():
raise ValueError("reindex must be None or False when func is 'first' or 'last.")

if reindex is None:
Expand All @@ -2008,9 +2026,10 @@ def _validate_reindex(
if all_numpy:
return True

if func in ["first", "last"]:
if first_or_last():
# have to do the grouped_combine since there's no good fill_value
reindex = False
# Also needed for nanfirst, nanlast with no-NaN dtypes
return False

if method == "blockwise":
# for grouping by dask arrays, we set reindex=True
Expand Down Expand Up @@ -2412,12 +2431,19 @@ def groupby_reduce(
if method == "cohorts" and any_by_dask:
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")

if not is_duck_array(array):
array = np.asarray(array)

reindex = _validate_reindex(
reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array)
reindex,
func,
method,
expected_groups,
any_by_dask,
is_duck_dask_array(array),
array.dtype,
)

if not is_duck_array(array):
array = np.asarray(array)
is_bool_array = np.issubdtype(array.dtype, bool)
array = array.astype(np.intp) if is_bool_array else array

Expand Down Expand Up @@ -2601,7 +2627,7 @@ def groupby_reduce(

# TODO: clean this up
reindex = _validate_reindex(
reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array)
reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array), array.dtype
)

if TYPE_CHECKING:
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.too_slow],
)
settings.register_profile(
"local",
"default",
max_examples=300,
suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.too_slow],
verbosity=Verbosity.verbose,
)
settings.load_profile("default")


@pytest.fixture(
Expand Down
93 changes: 87 additions & 6 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,33 @@ def test_dask_reduce_axis_subset():
)


@pytest.mark.parametrize("group_idx", [[0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 1, 0]])
@pytest.mark.parametrize(
    "func",
    # plain "first" / "last" are not covered by this test
    ["nanfirst", "nanlast"],
)
@pytest.mark.parametrize(
    "chunks",
    [
        None,
        # dask-backed variants; each is skipped when dask is unavailable
        *(
            pytest.param(size, marks=pytest.mark.skipif(not has_dask, reason="no dask"))
            for size in (1, 2, 3)
        ),
    ],
)
def test_first_last_useless(func, chunks, group_idx):
    """Check nanfirst/nanlast on an all-zero int8 array.

    The input is a 2x3 int8 array of zeros (optionally wrapped in a dask
    array with the given ``chunks``), grouped by ``group_idx``; the result
    is expected to be a 2x2 int8 array of zeros.
    """
    zeros = np.zeros((2, 3), dtype=np.int8)
    data = zeros if chunks is None else dask.array.from_array(zeros, chunks=chunks)
    labels = np.array(group_idx)
    actual, _ = groupby_reduce(data, labels, func=func, engine="numpy")
    assert_equal(actual, np.zeros((2, 2), dtype=np.int8))


@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
@pytest.mark.parametrize("axis", [(0, 1)])
def test_first_last_disallowed(axis, func):
Expand Down Expand Up @@ -1563,18 +1590,36 @@ def test_validate_reindex_map_reduce(
dask_expected, reindex, func, expected_groups, any_by_dask
) -> None:
actual = _validate_reindex(
reindex, func, "map-reduce", expected_groups, any_by_dask, is_dask_array=True
reindex,
func,
"map-reduce",
expected_groups,
any_by_dask,
is_dask_array=True,
array_dtype=np.dtype("int32"),
)
assert actual is dask_expected

# always reindex with all numpy inputs
actual = _validate_reindex(
reindex, func, "map-reduce", expected_groups, any_by_dask=False, is_dask_array=False
reindex,
func,
"map-reduce",
expected_groups,
any_by_dask=False,
is_dask_array=False,
array_dtype=np.dtype("int32"),
)
assert actual

actual = _validate_reindex(
True, func, "map-reduce", expected_groups, any_by_dask=False, is_dask_array=False
True,
func,
"map-reduce",
expected_groups,
any_by_dask=False,
is_dask_array=False,
array_dtype=np.dtype("int32"),
)
assert actual

Expand All @@ -1584,19 +1629,37 @@ def test_validate_reindex() -> None:
for method in methods:
with pytest.raises(NotImplementedError):
_validate_reindex(
True, "argmax", method, expected_groups=None, any_by_dask=False, is_dask_array=True
True,
"argmax",
method,
expected_groups=None,
any_by_dask=False,
is_dask_array=True,
array_dtype=np.dtype("int32"),
)

methods: list[T_Method] = ["blockwise", "cohorts"]
for method in methods:
with pytest.raises(ValueError):
_validate_reindex(
True, "sum", method, expected_groups=None, any_by_dask=False, is_dask_array=True
True,
"sum",
method,
expected_groups=None,
any_by_dask=False,
is_dask_array=True,
array_dtype=np.dtype("int32"),
)

for func in ["sum", "argmax"]:
actual = _validate_reindex(
None, func, method, expected_groups=None, any_by_dask=False, is_dask_array=True
None,
func,
method,
expected_groups=None,
any_by_dask=False,
is_dask_array=True,
array_dtype=np.dtype("int32"),
)
assert actual is False

Expand All @@ -1608,6 +1671,7 @@ def test_validate_reindex() -> None:
expected_groups=np.array([1, 2, 3]),
any_by_dask=False,
is_dask_array=True,
array_dtype=np.dtype("int32"),
)

assert _validate_reindex(
Expand All @@ -1617,6 +1681,7 @@ def test_validate_reindex() -> None:
expected_groups=np.array([1, 2, 3]),
any_by_dask=True,
is_dask_array=True,
array_dtype=np.dtype("int32"),
)
assert _validate_reindex(
None,
Expand All @@ -1625,8 +1690,24 @@ def test_validate_reindex() -> None:
expected_groups=np.array([1, 2, 3]),
any_by_dask=True,
is_dask_array=True,
array_dtype=np.dtype("int32"),
)

kwargs = dict(
method="blockwise",
expected_groups=np.array([1, 2, 3]),
any_by_dask=True,
is_dask_array=True,
)

for func in ["nanfirst", "nanlast"]:
assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs) # type: ignore[arg-type]
assert _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs) # type: ignore[arg-type]

for func in ["first", "last"]:
assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs) # type: ignore[arg-type]
assert not _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs) # type: ignore[arg-type]


@requires_dask
def test_1d_blockwise_sort_optimization():
Expand Down
15 changes: 15 additions & 0 deletions tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
pytest.importorskip("cftime")

import dask
import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
import numpy as np
from hypothesis import assume, given, note
Expand All @@ -19,6 +20,7 @@

from . import assert_equal
from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays
from .strategies import chunks as chunks_strategy

dask.config.set(scheduler="sync")

Expand Down Expand Up @@ -208,3 +210,16 @@ def test_first_last(data, array: dask.array.Array, func: str) -> None:
first, *_ = groupby_reduce(array, by, func=func, engine="flox")
second, *_ = groupby_reduce(array, by, func=mate, engine="flox")
assert_equal(first, second)


@given(data=st.data(), func=st.sampled_from(["nanfirst", "nanlast"]))
def test_first_last_useless(data, func):
    """Property test: nanfirst/nanlast over all-zero int8 input yields zeros.

    A shape, a ``by`` array matching the last axis, and a chunking are drawn
    (in that order); the reduction along the last axis must equal an int8
    zero array whose last dimension is the number of groups returned.
    """
    shape = data.draw(npst.array_shapes())
    by = data.draw(by_arrays(shape=shape[-1:]))
    chunks = data.draw(chunks_strategy(shape=shape))

    zeros = np.zeros(shape, dtype=np.int8)
    array = zeros if chunks is None else dask.array.from_array(zeros, chunks=chunks)

    actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy")
    expected_shape = (*shape[:-1], len(groups))
    assert_equal(actual, np.zeros(expected_shape, dtype=array.dtype))
Loading