Skip to content

Support quantile, median, mode with method="blockwise". #269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions flox/aggregate_npg.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,51 @@ def _len(group_idx, array, engine, *, func, axis=-1, size=None, fill_value=None,

len = partial(_len, func="len")
nanlen = partial(_len, func="nanlen")


def median(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
return npg.aggregate_numpy.aggregate(
group_idx,
array,
func=np.median,
axis=axis,
size=size,
fill_value=fill_value,
dtype=dtype,
)


def nanmedian(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
return npg.aggregate_numpy.aggregate(
group_idx,
array,
func=np.nanmedian,
axis=axis,
size=size,
fill_value=fill_value,
dtype=dtype,
)


def quantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None, dtype=None):
return npg.aggregate_numpy.aggregate(
group_idx,
array,
func=partial(np.quantile, q=q),
axis=axis,
size=size,
fill_value=fill_value,
dtype=dtype,
)


def nanquantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None, dtype=None):
return npg.aggregate_numpy.aggregate(
group_idx,
array,
func=partial(np.nanquantile, q=q),
axis=axis,
size=size,
fill_value=fill_value,
dtype=dtype,
)
10 changes: 8 additions & 2 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,10 @@ def _pick_second(*x):

# numpy_groupies does not support median
# And the dask version is really hard!
# median = Aggregation("median", chunk=None, combine=None, fill_value=None)
# nanmedian = Aggregation("nanmedian", chunk=None, combine=None, fill_value=None)
median = Aggregation(name="median", fill_value=-1, chunk=None, combine=None)
nanmedian = Aggregation(name="nanmedian", fill_value=-1, chunk=None, combine=None)
quantile = Aggregation(name="quantile", fill_value=-1, chunk=None, combine=None)
nanquantile = Aggregation(name="nanquantile", fill_value=-1, chunk=None, combine=None)

aggregations = {
"any": any_,
Expand Down Expand Up @@ -496,6 +498,10 @@ def _pick_second(*x):
"nanfirst": nanfirst,
"last": last,
"nanlast": nanlast,
"median": median,
"nanmedian": nanmedian,
"quantile": quantile,
"nanquantile": nanquantile,
}


Expand Down
28 changes: 17 additions & 11 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,9 +1307,6 @@ def dask_groupby_agg(
assert isinstance(axis, Sequence)
assert all(ax >= 0 for ax in axis)

if method == "blockwise" and not isinstance(by, np.ndarray):
raise NotImplementedError

inds = tuple(range(array.ndim))
name = f"groupby_{agg.name}"
token = dask.base.tokenize(array, by, agg, expected_groups, axis)
Expand Down Expand Up @@ -1471,11 +1468,15 @@ def dask_groupby_agg(
# Here one input chunk → one output chunks
# find number of groups in each chunk, this is needed for output chunks
# along the reduced axis
slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
groups_in_block = tuple(_unique(by_input[slc]) for slc in slices)
groups = (np.concatenate(groups_in_block),)
ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
group_chunks = (ngroups_per_block,)
if isinstance(by, dask.array.Array):
groups = (expected_groups,)
group_chunks = ((len(expected_groups),),)
else:
slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
groups_in_block = tuple(_unique(by_input[slc]) for slc in slices)
groups = (np.concatenate(groups_in_block),)
ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
group_chunks = (ngroups_per_block,)
else:
raise ValueError(f"Unknown method={method}.")

Expand Down Expand Up @@ -1544,6 +1545,7 @@ def _validate_reindex(
is_dask_array: bool,
) -> bool:
all_numpy = not is_dask_array and not any_by_dask

if reindex is True and not all_numpy:
if _is_arg_reduction(func):
raise NotImplementedError
Expand All @@ -1562,7 +1564,11 @@ def _validate_reindex(
# have to do the grouped_combine since there's no good fill_value
reindex = False

if method == "blockwise" or _is_arg_reduction(func):
if method == "blockwise":
# for grouping by dask arrays, we set reindex=True
reindex = any_by_dask

elif _is_arg_reduction(func):
reindex = False

elif method == "cohorts":
Expand Down Expand Up @@ -1835,7 +1841,7 @@ def groupby_reduce(
boost in computation speed. For cases like time grouping, this may result in large intermediates relative to the
original block size. Avoid that by using ``method="cohorts"``. By default, it is turned off for argreductions.
finalize_kwargs : dict, optional
Kwargs passed to finalize the reduction such as ``ddof`` for var, std.
Kwargs passed to finalize the reduction such as ``ddof`` for var, std or ``q`` for quantile.

Returns
-------
Expand Down Expand Up @@ -2023,7 +2029,7 @@ def groupby_reduce(
result, groups = partial_agg(
array,
by_,
expected_groups=None if method == "blockwise" else expected_groups,
expected_groups=expected_groups,
agg=agg,
reindex=reindex,
method=method,
Expand Down
2 changes: 1 addition & 1 deletion flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def xarray_reduce(
boost in computation speed. For cases like time grouping, this may result in large intermediates relative to the
original block size. Avoid that by using method="cohorts". By default, it is turned off for arg reductions.
**finalize_kwargs
kwargs passed to the finalize function, like ``ddof`` for var, std.
kwargs passed to the finalize function, like ``ddof`` for var, std or ``q`` for quantile.

Returns
-------
Expand Down