diff --git a/flox/core.py b/flox/core.py
index 933ef8c22..07f6d0e69 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -1760,7 +1760,7 @@ def groupby_reduce(
         assert len(groups) == 1
         sorted_idx = np.argsort(groups[0])
         # This optimization helps specifically with resampling
-        if not (sorted_idx[1:] <= sorted_idx[:-1]).all():
+        if not (sorted_idx[:-1] <= sorted_idx[1:]).all():
             result = result[..., sorted_idx]
             groups = (groups[0][sorted_idx],)
 
diff --git a/tests/test_core.py b/tests/test_core.py
index d1db6fab9..2fc534d83 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -1183,3 +1183,23 @@ def test_validate_reindex():
     for func in ["sum", "argmax"]:
         actual = _validate_reindex(None, func, method, expected_groups=None, by_is_dask=False)
         assert actual is False
+
+
+@requires_dask
+def test_1d_blockwise_sort_optimization():
+    # Make sure for resampling problems sorting isn't done.
+    time = pd.Series(pd.date_range("2020-09-01", "2020-12-31 23:59", freq="3H"))
+    array = dask.array.ones((len(time),), chunks=(224,))
+
+    actual, _ = groupby_reduce(array, time.dt.dayofyear.values, method="blockwise", func="count")
+    assert all("getitem" not in k for k in actual.dask)
+
+    actual, _ = groupby_reduce(
+        array, time.dt.dayofyear.values[::-1], sort=True, method="blockwise", func="count"
+    )
+    assert any("getitem" in k for k in actual.dask.layers)
+
+    actual, _ = groupby_reduce(
+        array, time.dt.dayofyear.values[::-1], sort=False, method="blockwise", func="count"
+    )
+    assert all("getitem" not in k for k in actual.dask.layers)
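
Why the swapped comparison matters: `sorted_idx = np.argsort(groups[0])` is the identity permutation `[0, 1, 2, ...]` exactly when the groups are already sorted, which is the common resampling case. The old predicate `(sorted_idx[1:] <= sorted_idx[:-1]).all()` tests whether `sorted_idx` is *descending*, so it is False for already-sorted input and the code took the `getitem` reindex branch anyway; the fixed predicate tests *ascending* and correctly skips it. A minimal standalone sketch of the two predicates (the `needs_sort_*` helper names are illustrative only, not part of flox):

```python
import numpy as np

def needs_sort_old(groups):
    sorted_idx = np.argsort(groups)
    # Old predicate: True unless sorted_idx is descending,
    # so already-sorted input still triggered a getitem.
    return not (sorted_idx[1:] <= sorted_idx[:-1]).all()

def needs_sort_new(groups):
    sorted_idx = np.argsort(groups)
    # Fixed predicate: True only when sorted_idx is not ascending,
    # i.e. only when `groups` actually needs re-sorting.
    return not (sorted_idx[:-1] <= sorted_idx[1:]).all()

already_sorted = np.array([1, 2, 3, 4])  # typical resampling groups
shuffled = np.array([3, 1, 4, 2])

assert needs_sort_old(already_sorted)      # old code re-sorts needlessly
assert not needs_sort_new(already_sorted)  # fixed code skips the getitem
assert needs_sort_new(shuffled)            # unsorted input still re-sorts
```

The new test exercises exactly these three situations: sorted `dayofyear` labels must produce no `getitem` layer in the dask graph, while reversed labels produce one only when `sort=True`.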